feat: implement linear model and autodetect parameters
This commit is contained in:
		@@ -3,18 +3,26 @@ package main
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"flag"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/cyrilix/robocar-base/cli"
 | 
			
		||||
	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
 | 
			
		||||
	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
 | 
			
		||||
	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	DefaultClientId = "robocar-steering-tflite-edgetpu"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	modelNameRegex = regexp.MustCompile(".*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\\d+)x(?P<imgHeight>\\d+)h(?P<horizon>\\d+)_edgetpu.tflite$")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	var mqttBroker, username, password, clientId string
 | 
			
		||||
	var cameraTopic, steeringTopic string
 | 
			
		||||
@@ -31,8 +39,8 @@ func main() {
 | 
			
		||||
	flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
 | 
			
		||||
	flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
 | 
			
		||||
	flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
 | 
			
		||||
	flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)")
 | 
			
		||||
	flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)")
 | 
			
		||||
	flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model")
 | 
			
		||||
	flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model")
 | 
			
		||||
	flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
 | 
			
		||||
	logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
 | 
			
		||||
	flag.Parse()
 | 
			
		||||
@@ -63,19 +71,37 @@ func main() {
 | 
			
		||||
		flag.PrintDefaults()
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
	modelType, width, height, horizonFromName, err := parseModelName(modelPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
 | 
			
		||||
	}
 | 
			
		||||
	if imgWidth == 0 {
 | 
			
		||||
		imgWidth = width
 | 
			
		||||
	}
 | 
			
		||||
	if imgHeight == 0 {
 | 
			
		||||
		imgHeight = height
 | 
			
		||||
	}
 | 
			
		||||
	if horizonFromName == 0 {
 | 
			
		||||
		horizon = horizonFromName
 | 
			
		||||
	}
 | 
			
		||||
	if imgWidth <= 0 || imgHeight <= 0 {
 | 
			
		||||
		zap.L().Error("img-width and img-height are mandatory")
 | 
			
		||||
		flag.PrintDefaults()
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	zap.S().Infof("model type            : %v", modelType)
 | 
			
		||||
	zap.S().Infof("model for image width : %v", imgWidth)
 | 
			
		||||
	zap.S().Infof("model for image height: %v", imgHeight)
 | 
			
		||||
	zap.S().Infof("model with horizon    : %v", horizon)
 | 
			
		||||
 | 
			
		||||
	client, err := cli.Connect(mqttBroker, username, password, clientId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
 | 
			
		||||
	}
 | 
			
		||||
	defer client.Disconnect(50)
 | 
			
		||||
 | 
			
		||||
	p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
 | 
			
		||||
	p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
 | 
			
		||||
	defer p.Stop()
 | 
			
		||||
 | 
			
		||||
	cli.HandleExit(p)
 | 
			
		||||
@@ -85,3 +111,33 @@ func main() {
 | 
			
		||||
		zap.L().Fatal("unable to start service", zap.Error(err))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
 | 
			
		||||
	match := modelNameRegex.FindStringSubmatch(modelPath)
 | 
			
		||||
 | 
			
		||||
	results := map[string]string{}
 | 
			
		||||
	for i, name := range match {
 | 
			
		||||
		results[modelNameRegex.SubexpNames()[i]] = name
 | 
			
		||||
	}
 | 
			
		||||
	modelType = tools.ParseModelType(results["type"])
 | 
			
		||||
	if modelType == tools.ModelTypeUnknown {
 | 
			
		||||
		err = fmt.Errorf("unknown model type '%v'", results["type"])
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	imgWidth, err = strconv.Atoi(results["imgWidth"])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	imgHeight, err = strconv.Atoi(results["imgHeight"])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	horizon, err = strconv.Atoi(results["horizon"])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										97
									
								
								cmd/rc-steering/rc-steering_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								cmd/rc-steering/rc-steering_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,97 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Test_parseModelName(t *testing.T) {
 | 
			
		||||
	type args struct {
 | 
			
		||||
		modelPath string
 | 
			
		||||
	}
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name          string
 | 
			
		||||
		args          args
 | 
			
		||||
		wantModelType tools.ModelType
 | 
			
		||||
		wantImgWidth  int
 | 
			
		||||
		wantImgHeight int
 | 
			
		||||
		wantHorizon   int
 | 
			
		||||
		wantErr       bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:          "categorical",
 | 
			
		||||
			args:          args{modelPath: "/tmp/model_categorical_120x160h10.tflite"},
 | 
			
		||||
			wantModelType: tools.ModelTypeCategorical,
 | 
			
		||||
			wantImgWidth:  120,
 | 
			
		||||
			wantImgHeight: 160,
 | 
			
		||||
			wantHorizon:   10,
 | 
			
		||||
			wantErr:       false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "linear",
 | 
			
		||||
			args:          args{modelPath: "/tmp/model_linear_120x160h10.tflite"},
 | 
			
		||||
			wantModelType: tools.ModelTypeLinear,
 | 
			
		||||
			wantImgWidth:  120,
 | 
			
		||||
			wantImgHeight: 160,
 | 
			
		||||
			wantHorizon:   10,
 | 
			
		||||
			wantErr:       false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "bad-model",
 | 
			
		||||
			args:          args{modelPath: "/tmp/model_123_120x160h10.tflite"},
 | 
			
		||||
			wantModelType: tools.ModelTypeUnknown,
 | 
			
		||||
			wantImgWidth:  0,
 | 
			
		||||
			wantImgHeight: 0,
 | 
			
		||||
			wantHorizon:   0,
 | 
			
		||||
			wantErr:       true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "bad-width",
 | 
			
		||||
			args:          args{modelPath: "/tmp/model_categorical_ax160h10.tflite"},
 | 
			
		||||
			wantModelType: tools.ModelTypeUnknown,
 | 
			
		||||
			wantImgWidth:  0,
 | 
			
		||||
			wantImgHeight: 0,
 | 
			
		||||
			wantHorizon:   0,
 | 
			
		||||
			wantErr:       true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "bad-height",
 | 
			
		||||
			args:          args{modelPath: "/tmp/model_categorical_120xh10.tflite"},
 | 
			
		||||
			wantModelType: tools.ModelTypeUnknown,
 | 
			
		||||
			wantImgWidth:  0,
 | 
			
		||||
			wantImgHeight: 0,
 | 
			
		||||
			wantHorizon:   0,
 | 
			
		||||
			wantErr:       true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "bad-horizon",
 | 
			
		||||
			args:          args{modelPath: "/tmp/model_categorical_120x160h.tflite"},
 | 
			
		||||
			wantModelType: tools.ModelTypeUnknown,
 | 
			
		||||
			wantImgWidth:  0,
 | 
			
		||||
			wantImgHeight: 0,
 | 
			
		||||
			wantHorizon:   0,
 | 
			
		||||
			wantErr:       true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath)
 | 
			
		||||
			if (err != nil) != tt.wantErr {
 | 
			
		||||
				t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if gotModelType != tt.wantModelType {
 | 
			
		||||
				t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType)
 | 
			
		||||
			}
 | 
			
		||||
			if gotImgWidth != tt.wantImgWidth {
 | 
			
		||||
				t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth)
 | 
			
		||||
			}
 | 
			
		||||
			if gotImgHeight != tt.wantImgHeight {
 | 
			
		||||
				t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight)
 | 
			
		||||
			}
 | 
			
		||||
			if gotHorizon != tt.wantHorizon {
 | 
			
		||||
				t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user