diff --git a/cmd/rc-steering/rc-steering.go b/cmd/rc-steering/rc-steering.go index 9b93a16..be35ecf 100644 --- a/cmd/rc-steering/rc-steering.go +++ b/cmd/rc-steering/rc-steering.go @@ -3,18 +3,26 @@ package main import ( "context" "flag" + "fmt" "github.com/cyrilix/robocar-base/cli" "github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics" "github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering" + "github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools" "go.uber.org/zap" "log" "os" + "regexp" + "strconv" ) const ( DefaultClientId = "robocar-steering-tflite-edgetpu" ) +var ( + modelNameRegex = regexp.MustCompile(".*model_(?P(categorical)|(linear))_(?P\\d+)x(?P\\d+)h(?P\\d+)_edgetpu.tflite$") +) + func main() { var mqttBroker, username, password, clientId string var cameraTopic, steeringTopic string @@ -31,8 +39,8 @@ func main() { flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set") flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set") flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity") - flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)") - flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)") + flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model") + flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model") flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'") logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level") flag.Parse() @@ -63,19 +71,37 @@ func main() { flag.PrintDefaults() os.Exit(1) } + modelType, width, height, horizonFromName, err := parseModelName(modelPath) + if err != nil { + zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err) + } + if imgWidth == 0 { + imgWidth = width + } + if imgHeight == 0 { + imgHeight = height + } + if horizonFromName == 0 { + horizon = horizonFromName + } if imgWidth <= 0 || imgHeight <= 0 { zap.L().Error("img-width and img-height are mandatory") flag.PrintDefaults() os.Exit(1) } + zap.S().Infof("model type : %v", modelType) + zap.S().Infof("model for image width : %v", imgWidth) + zap.S().Infof("model for image height: %v", imgHeight) + zap.S().Infof("model with horizon : %v", horizon) + client, err := cli.Connect(mqttBroker, username, password, clientId) if err != nil { zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err)) } defer client.Disconnect(50) - p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon) + p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon) defer p.Stop() cli.HandleExit(p) @@ -85,3 +111,33 @@ func main() { zap.L().Fatal("unable to start service", zap.Error(err)) } } + +func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) { + match := modelNameRegex.FindStringSubmatch(modelPath) + + results := map[string]string{} + for i, name := range match { + results[modelNameRegex.SubexpNames()[i]] = name + } + modelType = tools.ParseModelType(results["type"]) + if modelType == tools.ModelTypeUnknown { + err = fmt.Errorf("unknown model type '%v'", results["type"]) + return + } + imgWidth, err = strconv.Atoi(results["imgWidth"]) + if err != nil { + err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err) + return + } + imgHeight, err = strconv.Atoi(results["imgHeight"]) + if err != nil { + err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err) + return + } + horizon, err = strconv.Atoi(results["horizon"]) + if err != nil { + err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err) + return + } + return +} diff --git a/cmd/rc-steering/rc-steering_test.go b/cmd/rc-steering/rc-steering_test.go new file mode 100644 index 0000000..0688723 --- /dev/null +++ b/cmd/rc-steering/rc-steering_test.go @@ -0,0 +1,97 @@ +package main + +import ( + "github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools" + "testing" +) + +func Test_parseModelName(t *testing.T) { + type args struct { + modelPath string + } + tests := []struct { + name string + args args + wantModelType tools.ModelType + wantImgWidth int + wantImgHeight int + wantHorizon int + wantErr bool + }{ + { + name: "categorical", + args: args{modelPath: "/tmp/model_categorical_120x160h10.tflite"}, + wantModelType: tools.ModelTypeCategorical, + wantImgWidth: 120, + wantImgHeight: 160, + wantHorizon: 10, + wantErr: false, + }, + { + name: "linear", + args: args{modelPath: "/tmp/model_linear_120x160h10.tflite"}, + wantModelType: tools.ModelTypeLinear, + wantImgWidth: 120, + wantImgHeight: 160, + wantHorizon: 10, + wantErr: false, + }, + { + name: "bad-model", + args: args{modelPath: "/tmp/model_123_120x160h10.tflite"}, + wantModelType: tools.ModelTypeUnknown, + wantImgWidth: 0, + wantImgHeight: 0, + wantHorizon: 0, + wantErr: true, + }, + { + name: "bad-width", + args: args{modelPath: "/tmp/model_categorical_ax160h10.tflite"}, + wantModelType: tools.ModelTypeUnknown, + wantImgWidth: 0, + wantImgHeight: 0, + wantHorizon: 0, + wantErr: true, + }, + { + name: "bad-height", + args: args{modelPath: "/tmp/model_categorical_120xh10.tflite"}, + wantModelType: tools.ModelTypeUnknown, + wantImgWidth: 0, + wantImgHeight: 0, + wantHorizon: 0, + wantErr: true, + }, + { + name: "bad-horizon", + args: args{modelPath: "/tmp/model_categorical_120x160h.tflite"}, + wantModelType: tools.ModelTypeUnknown, + wantImgWidth: 0, + wantImgHeight: 0, + wantHorizon: 0, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath) + if (err != nil) != tt.wantErr { + t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotModelType != tt.wantModelType { + t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType) + } + if gotImgWidth != tt.wantImgWidth { + t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth) + } + if gotImgHeight != tt.wantImgHeight { + t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight) + } + if gotHorizon != tt.wantHorizon { + t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon) + } + }) + } +} diff --git a/pkg/steering/steering.go b/pkg/steering/steering.go index 43915ad..c15bebb 100644 --- a/pkg/steering/steering.go +++ b/pkg/steering/steering.go @@ -19,9 +19,10 @@ import ( "time" ) -func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part { +func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part { return &Part{ client: client, + modelType: modelType, modelPath: modelPath, steeringTopic: steeringTopic, cameraTopic: cameraTopic, @@ -42,6 +43,7 @@ type Part struct { options *tflite.InterpreterOptions interpreter *tflite.Interpreter + modelType tools.ModelType modelPath string model *tflite.Model edgeVebosity int @@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) { output := p.interpreter.GetOutputTensor(0).UInt8s() zap.L().Debug("raw steering", zap.Uint8s("result", output)) - steering, score := tools.LinearBin(output, 15, -1, 2.0) + var steering, score float64 + switch p.modelType { + case tools.ModelTypeCategorical: + steering, score = tools.LinearBin(output, 15, -1, 2.0) + case tools.ModelTypeLinear: + steering = 2*(float64(output[0])/255.) - 1. + score = 0.6 + } zap.L().Debug("found steering", zap.Float64("steering", steering), zap.Float64("score", score), diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index 43e9693..60401d9 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -4,6 +4,37 @@ import ( "fmt" "go.uber.org/zap" "sort" + "strings" +) + +type ModelType int + +func ParseModelType(s string) ModelType { + switch strings.ToLower(s) { + case "categorical": + return ModelTypeCategorical + case "linear": + return ModelTypeLinear + default: + return ModelTypeUnknown + } +} + +func (m ModelType) String() string { + switch m { + case ModelTypeCategorical: + return "categorical" + case ModelTypeLinear: + return "linear" + default: + return "unknown" + } +} + +const ( + ModelTypeUnknown ModelType = iota + ModelTypeCategorical + ModelTypeLinear ) // LinearBin perform inverse linear_bin, taking diff --git a/pkg/tools/tools_test.go b/pkg/tools/tools_test.go index f092ff4..45194e6 100644 --- a/pkg/tools/tools_test.go +++ b/pkg/tools/tools_test.go @@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) { }) } } + +func TestParseModelType(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want ModelType + }{ + {name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical}, + {name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical}, + {name: "linear", args: args{s: "linear"}, want: ModelTypeLinear}, + {name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear}, + {name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown}, + {name: "empty", args: args{s: ""}, want: ModelTypeUnknown}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ParseModelType(tt.args.s); got != tt.want { + t.Errorf("ParseModelType() = %v, want %v", got, tt.want) + } + }) + } +}