6 Commits

11 changed files with 261 additions and 39 deletions

View File

@ -4,8 +4,8 @@ IMAGE_NAME=robocar-steering-tflite-edgetpu
TAG=$(git describe)
FULL_IMAGE_NAME=docker.io/cyrilix/${IMAGE_NAME}:${TAG}
BINARY=rc-steering
TFLITE_VERSION=2.6.0
GOLANG_VERSION=1.18
TFLITE_VERSION=2.10.0
GOLANG_VERSION=1.19
GOTAGS="-tags netgo"
BUILDER_CONTAINER="${IMAGE_NAME}-builder"
@ -35,6 +35,20 @@ image_build_binaries(){
buildah run $containerName ln -s /usr/lib/aarch64-linux-gnu/libedgetpu.so.1 /usr/lib/aarch64-linux-gnu/libedgetpu.so
printf "Compile for linux/amd64\n"
LIB_ARCH=x86_64-linux-gnu
LIB_FLAGS="-L /usr/local/lib/${LIB_ARCH} \
-L/usr/local/lib/${LIB_ARCH}/absl/base -labsl_base -labsl_throw_delegate -labsl_raw_logging_internal -labsl_spinlock_wait -labsl_malloc_internal -labsl_log_severity -labsl_strerror \
-L/usr/local/lib/${LIB_ARCH}/absl/status -labsl_status \
-L/usr/local/lib/${LIB_ARCH}/absl/hash -labsl_hash -labsl_city -labsl_low_level_hash \
-L/usr/local/lib/${LIB_ARCH}/absl/flags -labsl_flags -labsl_flags_internal -labsl_flags_marshalling -labsl_flags_reflection -labsl_flags_config -labsl_flags_program_name -labsl_flags_private_handle_accessor -labsl_flags_commandlineflag -labsl_flags_commandlineflag_internal\
-L/usr/local/lib/${LIB_ARCH}/absl/types -labsl_bad_variant_access -labsl_bad_optional_access -labsl_bad_any_cast_impl \
-L/usr/local/lib/${LIB_ARCH}/absl/strings -labsl_strings -labsl_str_format_internal -labsl_cord -labsl_cordz_info -labsl_cord_internal -labsl_cordz_functions -labsl_cordz_handle -labsl_strings_internal \
-L/usr/local/lib/${LIB_ARCH}/absl/time -labsl_time -labsl_time_zone -labsl_civil_time \
-L/usr/local/lib/${LIB_ARCH}/absl/numeric -labsl_int128 \
-L/usr/local/lib/${LIB_ARCH}/absl/synchronization -labsl_synchronization -labsl_graphcycles_internal\
-L/usr/local/lib/${LIB_ARCH}/absl/debugging -labsl_stacktrace -labsl_symbolize -labsl_debugging_internal -labsl_demangle_internal \
-L/usr/local/lib/${LIB_ARCH}/absl/profiling -labsl_exponential_biased \
-L/usr/local/lib/${LIB_ARCH}/absl/container -labsl_raw_hash_set -labsl_hashtablez_sampler"
buildah run \
--env CGO_ENABLED=1 \
--env CC=gcc \
@ -43,25 +57,27 @@ image_build_binaries(){
--env GOARCH=amd64 \
--env GOARM=${GOARM} \
--env CGO_CPPFLAGS="-I/usr/local/include" \
--env CGO_LDFLAGS="-L /usr/local/lib/x86_64-linux-gnu -L /usr/lib/x86_64-linux-gnu" \
--env CGO_LDFLAGS="${LIB_FLAGS}" \
$containerName \
go build -a -o rc-steering.amd64 ./cmd/rc-steering
#--env CGO_CXXFLAGS="--std=c++1z" \
printf "Compile for linux/arm/v7\n"
buildah run \
--env CGO_ENABLED=1 \
--env CC=arm-linux-gnueabihf-gcc \
--env CXX=arm-linux-gnueabihf-g++ \
--env GOOS=linux \
--env GOARCH=arm \
--env GOARM=7 \
--env CGO_CPPFLAGS="-I/usr/local/include" \
--env CGO_LDFLAGS="-L /usr/lib/arm-linux-gnueabihf -L /usr/local/lib/arm-linux-gnueabihf" \
$containerName \
go build -a -o rc-steering.armhf ./cmd/rc-steering
printf "Compile for linux/arm64\n"
LIB_ARCH=aarch64-linux-gnu
LIB_FLAGS="-L /usr/local/lib/${LIB_ARCH} \
-L/usr/local/lib/${LIB_ARCH}/absl/base -labsl_base -labsl_throw_delegate -labsl_raw_logging_internal -labsl_spinlock_wait -labsl_malloc_internal -labsl_log_severity -labsl_strerror\
-L/usr/local/lib/${LIB_ARCH}/absl/status -labsl_status \
-L/usr/local/lib/${LIB_ARCH}/absl/hash -labsl_hash -labsl_city -labsl_low_level_hash \
-L/usr/local/lib/${LIB_ARCH}/absl/flags -labsl_flags -labsl_flags_internal -labsl_flags_marshalling -labsl_flags_reflection -labsl_flags_config -labsl_flags_program_name -labsl_flags_private_handle_accessor -labsl_flags_commandlineflag -labsl_flags_commandlineflag_internal\
-L/usr/local/lib/${LIB_ARCH}/absl/types -labsl_bad_variant_access -labsl_bad_optional_access -labsl_bad_any_cast_impl \
-L/usr/local/lib/${LIB_ARCH}/absl/strings -labsl_strings -labsl_str_format_internal -labsl_cord -labsl_cordz_info -labsl_cord_internal -labsl_cordz_functions -labsl_cordz_handle -labsl_strings_internal \
-L/usr/local/lib/${LIB_ARCH}/absl/time -labsl_time -labsl_time_zone -labsl_civil_time \
-L/usr/local/lib/${LIB_ARCH}/absl/numeric -labsl_int128 \
-L/usr/local/lib/${LIB_ARCH}/absl/synchronization -labsl_synchronization -labsl_graphcycles_internal\
-L/usr/local/lib/${LIB_ARCH}/absl/debugging -labsl_stacktrace -labsl_symbolize -labsl_debugging_internal -labsl_demangle_internal \
-L/usr/local/lib/${LIB_ARCH}/absl/profiling -labsl_exponential_biased \
-L/usr/local/lib/${LIB_ARCH}/absl/container -labsl_raw_hash_set -labsl_hashtablez_sampler"
buildah run \
--env CGO_ENABLED=1 \
--env CC=aarch64-linux-gnu-gcc \
@ -69,7 +85,7 @@ image_build_binaries(){
--env GOOS=linux \
--env GOARCH=arm64 \
--env CGO_CPPFLAGS="-I/usr/local/include" \
--env CGO_LDFLAGS="-L /usr/lib/aarch64-linux-gnu -L /usr/local/lib/aarch64-linux-gnu" \
--env CGO_LDFLAGS="${LIB_FLAGS}" \
$containerName \
go build -a -o rc-steering.arm64 ./cmd/rc-steering
}
@ -110,7 +126,6 @@ image_build_binaries
image_build linux/amd64
image_build linux/arm64
image_build linux/arm/v7
# push image

View File

@ -3,18 +3,26 @@ package main
import (
"context"
"flag"
"fmt"
"github.com/cyrilix/robocar-base/cli"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
"go.uber.org/zap"
"log"
"os"
"regexp"
"strconv"
)
const (
DefaultClientId = "robocar-steering-tflite-edgetpu"
)
var (
modelNameRegex = regexp.MustCompile(".*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\\d+)x(?P<imgHeight>\\d+)h(?P<horizon>\\d+)_edgetpu.tflite$")
)
func main() {
var mqttBroker, username, password, clientId string
var cameraTopic, steeringTopic string
@ -31,8 +39,8 @@ func main() {
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)")
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)")
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model")
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model")
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
flag.Parse()
@ -63,19 +71,37 @@ func main() {
flag.PrintDefaults()
os.Exit(1)
}
modelType, width, height, horizonFromName, err := parseModelName(modelPath)
if err != nil {
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
}
if imgWidth == 0 {
imgWidth = width
}
if imgHeight == 0 {
imgHeight = height
}
if horizonFromName == 0 {
horizon = horizonFromName
}
if imgWidth <= 0 || imgHeight <= 0 {
zap.L().Error("img-width and img-height are mandatory")
flag.PrintDefaults()
os.Exit(1)
}
zap.S().Infof("model type : %v", modelType)
zap.S().Infof("model for image width : %v", imgWidth)
zap.S().Infof("model for image height: %v", imgHeight)
zap.S().Infof("model with horizon : %v", horizon)
client, err := cli.Connect(mqttBroker, username, password, clientId)
if err != nil {
zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
}
defer client.Disconnect(50)
p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
defer p.Stop()
cli.HandleExit(p)
@ -85,3 +111,33 @@ func main() {
zap.L().Fatal("unable to start service", zap.Error(err))
}
}
func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
match := modelNameRegex.FindStringSubmatch(modelPath)
results := map[string]string{}
for i, name := range match {
results[modelNameRegex.SubexpNames()[i]] = name
}
modelType = tools.ParseModelType(results["type"])
if modelType == tools.ModelTypeUnknown {
err = fmt.Errorf("unknown model type '%v'", results["type"])
return
}
imgWidth, err = strconv.Atoi(results["imgWidth"])
if err != nil {
err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err)
return
}
imgHeight, err = strconv.Atoi(results["imgHeight"])
if err != nil {
err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err)
return
}
horizon, err = strconv.Atoi(results["horizon"])
if err != nil {
err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err)
return
}
return
}

View File

@ -0,0 +1,97 @@
package main
import (
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
"testing"
)
func Test_parseModelName(t *testing.T) {
type args struct {
modelPath string
}
tests := []struct {
name string
args args
wantModelType tools.ModelType
wantImgWidth int
wantImgHeight int
wantHorizon int
wantErr bool
}{
{
name: "categorical",
args: args{modelPath: "/tmp/model_categorical_120x160h10.tflite"},
wantModelType: tools.ModelTypeCategorical,
wantImgWidth: 120,
wantImgHeight: 160,
wantHorizon: 10,
wantErr: false,
},
{
name: "linear",
args: args{modelPath: "/tmp/model_linear_120x160h10.tflite"},
wantModelType: tools.ModelTypeLinear,
wantImgWidth: 120,
wantImgHeight: 160,
wantHorizon: 10,
wantErr: false,
},
{
name: "bad-model",
args: args{modelPath: "/tmp/model_123_120x160h10.tflite"},
wantModelType: tools.ModelTypeUnknown,
wantImgWidth: 0,
wantImgHeight: 0,
wantHorizon: 0,
wantErr: true,
},
{
name: "bad-width",
args: args{modelPath: "/tmp/model_categorical_ax160h10.tflite"},
wantModelType: tools.ModelTypeUnknown,
wantImgWidth: 0,
wantImgHeight: 0,
wantHorizon: 0,
wantErr: true,
},
{
name: "bad-height",
args: args{modelPath: "/tmp/model_categorical_120xh10.tflite"},
wantModelType: tools.ModelTypeUnknown,
wantImgWidth: 0,
wantImgHeight: 0,
wantHorizon: 0,
wantErr: true,
},
{
name: "bad-horizon",
args: args{modelPath: "/tmp/model_categorical_120x160h.tflite"},
wantModelType: tools.ModelTypeUnknown,
wantImgWidth: 0,
wantImgHeight: 0,
wantHorizon: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath)
if (err != nil) != tt.wantErr {
t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotModelType != tt.wantModelType {
t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType)
}
if gotImgWidth != tt.wantImgWidth {
t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth)
}
if gotImgHeight != tt.wantImgHeight {
t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight)
}
if gotHorizon != tt.wantHorizon {
t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon)
}
})
}
}

4
go.mod
View File

@ -1,13 +1,13 @@
module github.com/cyrilix/robocar-steering-tflite-edgetpu
go 1.18
go 1.19
require (
github.com/cyrilix/robocar-base v0.1.7
github.com/cyrilix/robocar-protobuf/go v1.0.5
github.com/disintegration/imaging v1.6.2
github.com/eclipse/paho.mqtt.golang v1.4.1
github.com/mattn/go-tflite v1.0.2
github.com/mattn/go-tflite v1.0.4
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v0.30.0
go.opentelemetry.io/otel/metric v0.30.0
go.opentelemetry.io/otel/sdk/metric v0.30.0

5
go.sum
View File

@ -1,7 +1,6 @@
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/cyrilix/robocar-base v0.1.7 h1:EVzZ0KjigSFpke5f3A/PybEH3WFUEIrYSc3z/dhOZ48=
github.com/cyrilix/robocar-base v0.1.7/go.mod h1:4E11HQSNy2NT8e7MW188y6ST9C0RzarKyn7sK/3V/Lk=
github.com/cyrilix/robocar-protobuf/go v1.0.5 h1:PX1At+pf6G7gJwT4LzJLQu3/LPFTTNNlZmZSYtnSELY=
@ -31,8 +30,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-tflite v1.0.2 h1:P9CKqjyRSRM31SfL65WklD8U5B/iPD4CJQiRkB8K02g=
github.com/mattn/go-tflite v1.0.2/go.mod h1:2NwhEYXoP8vxRIpu95DElqMkZoV39ABRPF3AETN7N1w=
github.com/mattn/go-tflite v1.0.4 h1:wpfNKjMr3IJz4xI+oUeHE70RU6Q5dZc0FK/X8vCWLAo=
github.com/mattn/go-tflite v1.0.4/go.mod h1:j7bVlVHgKURK0p7AQOw3OqlGE2SVXqck7JsJo4wI+bc=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@ -19,9 +19,10 @@ import (
"time"
)
func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
return &Part{
client: client,
modelType: modelType,
modelPath: modelPath,
steeringTopic: steeringTopic,
cameraTopic: cameraTopic,
@ -42,6 +43,7 @@ type Part struct {
options *tflite.InterpreterOptions
interpreter *tflite.Interpreter
modelType tools.ModelType
modelPath string
model *tflite.Model
edgeVebosity int
@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
output := p.interpreter.GetOutputTensor(0).UInt8s()
zap.L().Debug("raw steering", zap.Uint8s("result", output))
steering, score := tools.LinearBin(output, 15, -1, 2.0)
var steering, score float64
switch p.modelType {
case tools.ModelTypeCategorical:
steering, score = tools.LinearBin(output, 15, -1, 2.0)
case tools.ModelTypeLinear:
steering = 2*(float64(output[0])/255.) - 1.
score = 0.6
}
zap.L().Debug("found steering",
zap.Float64("steering", steering),
zap.Float64("score", score),

View File

@ -1,9 +1,39 @@
package tools
import (
"fmt"
"go.uber.org/zap"
"sort"
"strings"
)
type ModelType int
func ParseModelType(s string) ModelType {
switch strings.ToLower(s) {
case "categorical":
return ModelTypeCategorical
case "linear":
return ModelTypeLinear
default:
return ModelTypeUnknown
}
}
func (m ModelType) String() string {
switch m {
case ModelTypeCategorical:
return "categorical"
case ModelTypeLinear:
return "linear"
default:
return "unknown"
}
}
const (
ModelTypeUnknown ModelType = iota
ModelTypeCategorical
ModelTypeLinear
)
// LinearBin perform inverse linear_bin, taking
@ -15,19 +45,11 @@ func LinearBin(arr []uint8, n int, offset int, r float64) (float64, float64) {
}
var results []result
minScore := 0.2
for i := 0; i < outputSize; i++ {
score := float64(int(arr[i])) / 255.0
if score < minScore {
continue
}
results = append(results, result{score: score, index: i})
}
if len(results) == 0 {
zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
return 0., 0.
}
zap.S().Debugf("raw result: %v", results)
sort.Slice(results, func(i, j int) bool {

View File

@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) {
})
}
}
func TestParseModelType(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want ModelType
}{
{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical},
{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical},
{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear},
{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear},
{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown},
{name: "empty", args: args{s: ""}, want: ModelTypeUnknown},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ParseModelType(tt.args.s); got != tt.want {
t.Errorf("ParseModelType() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -92,5 +92,5 @@ rm -Rf edgetpu
MIT
## Author
Yasuhrio Matsumoto (a.k.a. mattn)
Yasuhiro Matsumoto (a.k.a. mattn)

View File

@ -32,7 +32,6 @@ _make_registration(void* o_init, void* o_free, void* o_prepare, void* o_invoke,
}
static void look_context(TfLiteContext *context) {
context->tensors;
TfLiteIntArray *plan = NULL;
context->GetExecutionPlan(context, &plan);
if (plan == NULL) return;

2
vendor/modules.txt vendored
View File

@ -28,7 +28,7 @@ github.com/gorilla/websocket
# github.com/mattn/go-pointer v0.0.1
## explicit
github.com/mattn/go-pointer
# github.com/mattn/go-tflite v1.0.2
# github.com/mattn/go-tflite v1.0.4
## explicit; go 1.13
github.com/mattn/go-tflite
github.com/mattn/go-tflite/delegates