feat: implement linear model and autodetect parameters
This commit is contained in:
		| @@ -3,18 +3,26 @@ package main | ||||
| import ( | ||||
| 	"context" | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"github.com/cyrilix/robocar-base/cli" | ||||
| 	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics" | ||||
| 	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering" | ||||
| 	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools" | ||||
| 	"go.uber.org/zap" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	DefaultClientId = "robocar-steering-tflite-edgetpu" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	modelNameRegex = regexp.MustCompile(".*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\\d+)x(?P<imgHeight>\\d+)h(?P<horizon>\\d+)_edgetpu.tflite$") | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	var mqttBroker, username, password, clientId string | ||||
| 	var cameraTopic, steeringTopic string | ||||
| @@ -31,8 +39,8 @@ func main() { | ||||
| 	flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set") | ||||
| 	flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set") | ||||
| 	flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity") | ||||
| 	flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)") | ||||
| 	flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)") | ||||
| 	flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model") | ||||
| 	flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model") | ||||
| 	flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'") | ||||
| 	logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level") | ||||
| 	flag.Parse() | ||||
| @@ -63,19 +71,37 @@ func main() { | ||||
| 		flag.PrintDefaults() | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| 	modelType, width, height, horizonFromName, err := parseModelName(modelPath) | ||||
| 	if err != nil { | ||||
| 		zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err) | ||||
| 	} | ||||
| 	if imgWidth == 0 { | ||||
| 		imgWidth = width | ||||
| 	} | ||||
| 	if imgHeight == 0 { | ||||
| 		imgHeight = height | ||||
| 	} | ||||
| 	if horizonFromName == 0 { | ||||
| 		horizon = horizonFromName | ||||
| 	} | ||||
| 	if imgWidth <= 0 || imgHeight <= 0 { | ||||
| 		zap.L().Error("img-width and img-height are mandatory") | ||||
| 		flag.PrintDefaults() | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	zap.S().Infof("model type            : %v", modelType) | ||||
| 	zap.S().Infof("model for image width : %v", imgWidth) | ||||
| 	zap.S().Infof("model for image height: %v", imgHeight) | ||||
| 	zap.S().Infof("model with horizon    : %v", horizon) | ||||
|  | ||||
| 	client, err := cli.Connect(mqttBroker, username, password, clientId) | ||||
| 	if err != nil { | ||||
| 		zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err)) | ||||
| 	} | ||||
| 	defer client.Disconnect(50) | ||||
|  | ||||
| 	p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon) | ||||
| 	p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon) | ||||
| 	defer p.Stop() | ||||
|  | ||||
| 	cli.HandleExit(p) | ||||
| @@ -85,3 +111,33 @@ func main() { | ||||
| 		zap.L().Fatal("unable to start service", zap.Error(err)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) { | ||||
| 	match := modelNameRegex.FindStringSubmatch(modelPath) | ||||
|  | ||||
| 	results := map[string]string{} | ||||
| 	for i, name := range match { | ||||
| 		results[modelNameRegex.SubexpNames()[i]] = name | ||||
| 	} | ||||
| 	modelType = tools.ParseModelType(results["type"]) | ||||
| 	if modelType == tools.ModelTypeUnknown { | ||||
| 		err = fmt.Errorf("unknown model type '%v'", results["type"]) | ||||
| 		return | ||||
| 	} | ||||
| 	imgWidth, err = strconv.Atoi(results["imgWidth"]) | ||||
| 	if err != nil { | ||||
| 		err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err) | ||||
| 		return | ||||
| 	} | ||||
| 	imgHeight, err = strconv.Atoi(results["imgHeight"]) | ||||
| 	if err != nil { | ||||
| 		err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err) | ||||
| 		return | ||||
| 	} | ||||
| 	horizon, err = strconv.Atoi(results["horizon"]) | ||||
| 	if err != nil { | ||||
| 		err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err) | ||||
| 		return | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|   | ||||
							
								
								
									
										97
									
								
								cmd/rc-steering/rc-steering_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								cmd/rc-steering/rc-steering_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,97 @@ | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func Test_parseModelName(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		modelPath string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name          string | ||||
| 		args          args | ||||
| 		wantModelType tools.ModelType | ||||
| 		wantImgWidth  int | ||||
| 		wantImgHeight int | ||||
| 		wantHorizon   int | ||||
| 		wantErr       bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:          "categorical", | ||||
| 			args:          args{modelPath: "/tmp/model_categorical_120x160h10.tflite"}, | ||||
| 			wantModelType: tools.ModelTypeCategorical, | ||||
| 			wantImgWidth:  120, | ||||
| 			wantImgHeight: 160, | ||||
| 			wantHorizon:   10, | ||||
| 			wantErr:       false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "linear", | ||||
| 			args:          args{modelPath: "/tmp/model_linear_120x160h10.tflite"}, | ||||
| 			wantModelType: tools.ModelTypeLinear, | ||||
| 			wantImgWidth:  120, | ||||
| 			wantImgHeight: 160, | ||||
| 			wantHorizon:   10, | ||||
| 			wantErr:       false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "bad-model", | ||||
| 			args:          args{modelPath: "/tmp/model_123_120x160h10.tflite"}, | ||||
| 			wantModelType: tools.ModelTypeUnknown, | ||||
| 			wantImgWidth:  0, | ||||
| 			wantImgHeight: 0, | ||||
| 			wantHorizon:   0, | ||||
| 			wantErr:       true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "bad-width", | ||||
| 			args:          args{modelPath: "/tmp/model_categorical_ax160h10.tflite"}, | ||||
| 			wantModelType: tools.ModelTypeUnknown, | ||||
| 			wantImgWidth:  0, | ||||
| 			wantImgHeight: 0, | ||||
| 			wantHorizon:   0, | ||||
| 			wantErr:       true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "bad-height", | ||||
| 			args:          args{modelPath: "/tmp/model_categorical_120xh10.tflite"}, | ||||
| 			wantModelType: tools.ModelTypeUnknown, | ||||
| 			wantImgWidth:  0, | ||||
| 			wantImgHeight: 0, | ||||
| 			wantHorizon:   0, | ||||
| 			wantErr:       true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "bad-horizon", | ||||
| 			args:          args{modelPath: "/tmp/model_categorical_120x160h.tflite"}, | ||||
| 			wantModelType: tools.ModelTypeUnknown, | ||||
| 			wantImgWidth:  0, | ||||
| 			wantImgHeight: 0, | ||||
| 			wantHorizon:   0, | ||||
| 			wantErr:       true, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 				return | ||||
| 			} | ||||
| 			if gotModelType != tt.wantModelType { | ||||
| 				t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType) | ||||
| 			} | ||||
| 			if gotImgWidth != tt.wantImgWidth { | ||||
| 				t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth) | ||||
| 			} | ||||
| 			if gotImgHeight != tt.wantImgHeight { | ||||
| 				t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight) | ||||
| 			} | ||||
| 			if gotHorizon != tt.wantHorizon { | ||||
| 				t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| @@ -19,9 +19,10 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part { | ||||
| func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part { | ||||
| 	return &Part{ | ||||
| 		client:        client, | ||||
| 		modelType:     modelType, | ||||
| 		modelPath:     modelPath, | ||||
| 		steeringTopic: steeringTopic, | ||||
| 		cameraTopic:   cameraTopic, | ||||
| @@ -42,6 +43,7 @@ type Part struct { | ||||
|  | ||||
| 	options      *tflite.InterpreterOptions | ||||
| 	interpreter  *tflite.Interpreter | ||||
| 	modelType    tools.ModelType | ||||
| 	modelPath    string | ||||
| 	model        *tflite.Model | ||||
| 	edgeVebosity int | ||||
| @@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) { | ||||
| 	output := p.interpreter.GetOutputTensor(0).UInt8s() | ||||
| 	zap.L().Debug("raw steering", zap.Uint8s("result", output)) | ||||
|  | ||||
| 	steering, score := tools.LinearBin(output, 15, -1, 2.0) | ||||
| 	var steering, score float64 | ||||
| 	switch p.modelType { | ||||
| 	case tools.ModelTypeCategorical: | ||||
| 		steering, score = tools.LinearBin(output, 15, -1, 2.0) | ||||
| 	case tools.ModelTypeLinear: | ||||
| 		steering = 2*(float64(output[0])/255.) - 1. | ||||
| 		score = 0.6 | ||||
| 	} | ||||
| 	zap.L().Debug("found steering", | ||||
| 		zap.Float64("steering", steering), | ||||
| 		zap.Float64("score", score), | ||||
|   | ||||
| @@ -4,6 +4,37 @@ import ( | ||||
| 	"fmt" | ||||
| 	"go.uber.org/zap" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type ModelType int | ||||
|  | ||||
| func ParseModelType(s string) ModelType { | ||||
| 	switch strings.ToLower(s) { | ||||
| 	case "categorical": | ||||
| 		return ModelTypeCategorical | ||||
| 	case "linear": | ||||
| 		return ModelTypeLinear | ||||
| 	default: | ||||
| 		return ModelTypeUnknown | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (m ModelType) String() string { | ||||
| 	switch m { | ||||
| 	case ModelTypeCategorical: | ||||
| 		return "categorical" | ||||
| 	case ModelTypeLinear: | ||||
| 		return "linear" | ||||
| 	default: | ||||
| 		return "unknown" | ||||
| 	} | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	ModelTypeUnknown ModelType = iota | ||||
| 	ModelTypeCategorical | ||||
| 	ModelTypeLinear | ||||
| ) | ||||
|  | ||||
| // LinearBin  perform inverse linear_bin, taking | ||||
|   | ||||
| @@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestParseModelType(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		s string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name string | ||||
| 		args args | ||||
| 		want ModelType | ||||
| 	}{ | ||||
| 		{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical}, | ||||
| 		{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical}, | ||||
| 		{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear}, | ||||
| 		{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear}, | ||||
| 		{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown}, | ||||
| 		{name: "empty", args: args{s: ""}, want: ModelTypeUnknown}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			if got := ParseModelType(tt.args.s); got != tt.want { | ||||
| 				t.Errorf("ParseModelType() = %v, want %v", got, tt.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user