feat: implement linear model and autodetect parameters
This commit is contained in:
parent
1e6966495c
commit
acf6066fea
@ -3,18 +3,26 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/cyrilix/robocar-base/cli"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||
"go.uber.org/zap"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultClientId = "robocar-steering-tflite-edgetpu"
|
||||
)
|
||||
|
||||
var (
|
||||
modelNameRegex = regexp.MustCompile(".*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\\d+)x(?P<imgHeight>\\d+)h(?P<horizon>\\d+)_edgetpu.tflite$")
|
||||
)
|
||||
|
||||
func main() {
|
||||
var mqttBroker, username, password, clientId string
|
||||
var cameraTopic, steeringTopic string
|
||||
@ -31,8 +39,8 @@ func main() {
|
||||
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
|
||||
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
|
||||
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
|
||||
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)")
|
||||
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)")
|
||||
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model")
|
||||
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model")
|
||||
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
|
||||
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
|
||||
flag.Parse()
|
||||
@ -63,19 +71,37 @@ func main() {
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
modelType, width, height, horizonFromName, err := parseModelName(modelPath)
|
||||
if err != nil {
|
||||
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
|
||||
}
|
||||
if imgWidth == 0 {
|
||||
imgWidth = width
|
||||
}
|
||||
if imgHeight == 0 {
|
||||
imgHeight = height
|
||||
}
|
||||
if horizonFromName == 0 {
|
||||
horizon = horizonFromName
|
||||
}
|
||||
if imgWidth <= 0 || imgHeight <= 0 {
|
||||
zap.L().Error("img-width and img-height are mandatory")
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
zap.S().Infof("model type : %v", modelType)
|
||||
zap.S().Infof("model for image width : %v", imgWidth)
|
||||
zap.S().Infof("model for image height: %v", imgHeight)
|
||||
zap.S().Infof("model with horizon : %v", horizon)
|
||||
|
||||
client, err := cli.Connect(mqttBroker, username, password, clientId)
|
||||
if err != nil {
|
||||
zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
|
||||
}
|
||||
defer client.Disconnect(50)
|
||||
|
||||
p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
|
||||
p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
|
||||
defer p.Stop()
|
||||
|
||||
cli.HandleExit(p)
|
||||
@ -85,3 +111,33 @@ func main() {
|
||||
zap.L().Fatal("unable to start service", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
|
||||
match := modelNameRegex.FindStringSubmatch(modelPath)
|
||||
|
||||
results := map[string]string{}
|
||||
for i, name := range match {
|
||||
results[modelNameRegex.SubexpNames()[i]] = name
|
||||
}
|
||||
modelType = tools.ParseModelType(results["type"])
|
||||
if modelType == tools.ModelTypeUnknown {
|
||||
err = fmt.Errorf("unknown model type '%v'", results["type"])
|
||||
return
|
||||
}
|
||||
imgWidth, err = strconv.Atoi(results["imgWidth"])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err)
|
||||
return
|
||||
}
|
||||
imgHeight, err = strconv.Atoi(results["imgHeight"])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err)
|
||||
return
|
||||
}
|
||||
horizon, err = strconv.Atoi(results["horizon"])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
97
cmd/rc-steering/rc-steering_test.go
Normal file
97
cmd/rc-steering/rc-steering_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_parseModelName(t *testing.T) {
|
||||
type args struct {
|
||||
modelPath string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantModelType tools.ModelType
|
||||
wantImgWidth int
|
||||
wantImgHeight int
|
||||
wantHorizon int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "categorical",
|
||||
args: args{modelPath: "/tmp/model_categorical_120x160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeCategorical,
|
||||
wantImgWidth: 120,
|
||||
wantImgHeight: 160,
|
||||
wantHorizon: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "linear",
|
||||
args: args{modelPath: "/tmp/model_linear_120x160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeLinear,
|
||||
wantImgWidth: 120,
|
||||
wantImgHeight: 160,
|
||||
wantHorizon: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bad-model",
|
||||
args: args{modelPath: "/tmp/model_123_120x160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeUnknown,
|
||||
wantImgWidth: 0,
|
||||
wantImgHeight: 0,
|
||||
wantHorizon: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "bad-width",
|
||||
args: args{modelPath: "/tmp/model_categorical_ax160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeUnknown,
|
||||
wantImgWidth: 0,
|
||||
wantImgHeight: 0,
|
||||
wantHorizon: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "bad-height",
|
||||
args: args{modelPath: "/tmp/model_categorical_120xh10.tflite"},
|
||||
wantModelType: tools.ModelTypeUnknown,
|
||||
wantImgWidth: 0,
|
||||
wantImgHeight: 0,
|
||||
wantHorizon: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "bad-horizon",
|
||||
args: args{modelPath: "/tmp/model_categorical_120x160h.tflite"},
|
||||
wantModelType: tools.ModelTypeUnknown,
|
||||
wantImgWidth: 0,
|
||||
wantImgHeight: 0,
|
||||
wantHorizon: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotModelType != tt.wantModelType {
|
||||
t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType)
|
||||
}
|
||||
if gotImgWidth != tt.wantImgWidth {
|
||||
t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth)
|
||||
}
|
||||
if gotImgHeight != tt.wantImgHeight {
|
||||
t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight)
|
||||
}
|
||||
if gotHorizon != tt.wantHorizon {
|
||||
t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -19,9 +19,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
||||
func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
||||
return &Part{
|
||||
client: client,
|
||||
modelType: modelType,
|
||||
modelPath: modelPath,
|
||||
steeringTopic: steeringTopic,
|
||||
cameraTopic: cameraTopic,
|
||||
@ -42,6 +43,7 @@ type Part struct {
|
||||
|
||||
options *tflite.InterpreterOptions
|
||||
interpreter *tflite.Interpreter
|
||||
modelType tools.ModelType
|
||||
modelPath string
|
||||
model *tflite.Model
|
||||
edgeVebosity int
|
||||
@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
|
||||
output := p.interpreter.GetOutputTensor(0).UInt8s()
|
||||
zap.L().Debug("raw steering", zap.Uint8s("result", output))
|
||||
|
||||
steering, score := tools.LinearBin(output, 15, -1, 2.0)
|
||||
var steering, score float64
|
||||
switch p.modelType {
|
||||
case tools.ModelTypeCategorical:
|
||||
steering, score = tools.LinearBin(output, 15, -1, 2.0)
|
||||
case tools.ModelTypeLinear:
|
||||
steering = 2*(float64(output[0])/255.) - 1.
|
||||
score = 0.6
|
||||
}
|
||||
zap.L().Debug("found steering",
|
||||
zap.Float64("steering", steering),
|
||||
zap.Float64("score", score),
|
||||
|
@ -4,6 +4,37 @@ import (
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelType int
|
||||
|
||||
func ParseModelType(s string) ModelType {
|
||||
switch strings.ToLower(s) {
|
||||
case "categorical":
|
||||
return ModelTypeCategorical
|
||||
case "linear":
|
||||
return ModelTypeLinear
|
||||
default:
|
||||
return ModelTypeUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (m ModelType) String() string {
|
||||
switch m {
|
||||
case ModelTypeCategorical:
|
||||
return "categorical"
|
||||
case ModelTypeLinear:
|
||||
return "linear"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ModelTypeUnknown ModelType = iota
|
||||
ModelTypeCategorical
|
||||
ModelTypeLinear
|
||||
)
|
||||
|
||||
// LinearBin perform inverse linear_bin, taking
|
||||
|
@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelType(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want ModelType
|
||||
}{
|
||||
{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical},
|
||||
{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical},
|
||||
{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear},
|
||||
{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear},
|
||||
{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown},
|
||||
{name: "empty", args: args{s: ""}, want: ModelTypeUnknown},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ParseModelType(tt.args.s); got != tt.want {
|
||||
t.Errorf("ParseModelType() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user