feat: implement linear model and autodetect parameters

This commit is contained in:
Cyrille Nofficial 2022-06-10 13:50:22 +02:00
parent 1e6966495c
commit acf6066fea
5 changed files with 223 additions and 5 deletions

View File

@ -3,18 +3,26 @@ package main
import ( import (
"context" "context"
"flag" "flag"
"fmt"
"github.com/cyrilix/robocar-base/cli" "github.com/cyrilix/robocar-base/cli"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics" "github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering" "github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
"go.uber.org/zap" "go.uber.org/zap"
"log" "log"
"os" "os"
"regexp"
"strconv"
) )
const ( const (
DefaultClientId = "robocar-steering-tflite-edgetpu" DefaultClientId = "robocar-steering-tflite-edgetpu"
) )
var (
modelNameRegex = regexp.MustCompile(".*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\\d+)x(?P<imgHeight>\\d+)h(?P<horizon>\\d+)_edgetpu.tflite$")
)
func main() { func main() {
var mqttBroker, username, password, clientId string var mqttBroker, username, password, clientId string
var cameraTopic, steeringTopic string var cameraTopic, steeringTopic string
@ -31,8 +39,8 @@ func main() {
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set") flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set") flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity") flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)") flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model")
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)") flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model")
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'") flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level") logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
flag.Parse() flag.Parse()
@ -63,19 +71,37 @@ func main() {
flag.PrintDefaults() flag.PrintDefaults()
os.Exit(1) os.Exit(1)
} }
modelType, width, height, horizonFromName, err := parseModelName(modelPath)
if err != nil {
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
}
if imgWidth == 0 {
imgWidth = width
}
if imgHeight == 0 {
imgHeight = height
}
if horizonFromName == 0 {
horizon = horizonFromName
}
if imgWidth <= 0 || imgHeight <= 0 { if imgWidth <= 0 || imgHeight <= 0 {
zap.L().Error("img-width and img-height are mandatory") zap.L().Error("img-width and img-height are mandatory")
flag.PrintDefaults() flag.PrintDefaults()
os.Exit(1) os.Exit(1)
} }
zap.S().Infof("model type : %v", modelType)
zap.S().Infof("model for image width : %v", imgWidth)
zap.S().Infof("model for image height: %v", imgHeight)
zap.S().Infof("model with horizon : %v", horizon)
client, err := cli.Connect(mqttBroker, username, password, clientId) client, err := cli.Connect(mqttBroker, username, password, clientId)
if err != nil { if err != nil {
zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err)) zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
} }
defer client.Disconnect(50) defer client.Disconnect(50)
p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon) p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
defer p.Stop() defer p.Stop()
cli.HandleExit(p) cli.HandleExit(p)
@ -85,3 +111,33 @@ func main() {
zap.L().Fatal("unable to start service", zap.Error(err)) zap.L().Fatal("unable to start service", zap.Error(err))
} }
} }
func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
match := modelNameRegex.FindStringSubmatch(modelPath)
results := map[string]string{}
for i, name := range match {
results[modelNameRegex.SubexpNames()[i]] = name
}
modelType = tools.ParseModelType(results["type"])
if modelType == tools.ModelTypeUnknown {
err = fmt.Errorf("unknown model type '%v'", results["type"])
return
}
imgWidth, err = strconv.Atoi(results["imgWidth"])
if err != nil {
err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err)
return
}
imgHeight, err = strconv.Atoi(results["imgHeight"])
if err != nil {
err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err)
return
}
horizon, err = strconv.Atoi(results["horizon"])
if err != nil {
err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err)
return
}
return
}

View File

@ -0,0 +1,97 @@
package main
import (
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
"testing"
)
func Test_parseModelName(t *testing.T) {
type args struct {
modelPath string
}
tests := []struct {
name string
args args
wantModelType tools.ModelType
wantImgWidth int
wantImgHeight int
wantHorizon int
wantErr bool
}{
{
name: "categorical",
args: args{modelPath: "/tmp/model_categorical_120x160h10.tflite"},
wantModelType: tools.ModelTypeCategorical,
wantImgWidth: 120,
wantImgHeight: 160,
wantHorizon: 10,
wantErr: false,
},
{
name: "linear",
args: args{modelPath: "/tmp/model_linear_120x160h10.tflite"},
wantModelType: tools.ModelTypeLinear,
wantImgWidth: 120,
wantImgHeight: 160,
wantHorizon: 10,
wantErr: false,
},
{
name: "bad-model",
args: args{modelPath: "/tmp/model_123_120x160h10.tflite"},
wantModelType: tools.ModelTypeUnknown,
wantImgWidth: 0,
wantImgHeight: 0,
wantHorizon: 0,
wantErr: true,
},
{
name: "bad-width",
args: args{modelPath: "/tmp/model_categorical_ax160h10.tflite"},
wantModelType: tools.ModelTypeUnknown,
wantImgWidth: 0,
wantImgHeight: 0,
wantHorizon: 0,
wantErr: true,
},
{
name: "bad-height",
args: args{modelPath: "/tmp/model_categorical_120xh10.tflite"},
wantModelType: tools.ModelTypeUnknown,
wantImgWidth: 0,
wantImgHeight: 0,
wantHorizon: 0,
wantErr: true,
},
{
name: "bad-horizon",
args: args{modelPath: "/tmp/model_categorical_120x160h.tflite"},
wantModelType: tools.ModelTypeUnknown,
wantImgWidth: 0,
wantImgHeight: 0,
wantHorizon: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath)
if (err != nil) != tt.wantErr {
t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotModelType != tt.wantModelType {
t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType)
}
if gotImgWidth != tt.wantImgWidth {
t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth)
}
if gotImgHeight != tt.wantImgHeight {
t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight)
}
if gotHorizon != tt.wantHorizon {
t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon)
}
})
}
}

View File

@ -19,9 +19,10 @@ import (
"time" "time"
) )
func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part { func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
return &Part{ return &Part{
client: client, client: client,
modelType: modelType,
modelPath: modelPath, modelPath: modelPath,
steeringTopic: steeringTopic, steeringTopic: steeringTopic,
cameraTopic: cameraTopic, cameraTopic: cameraTopic,
@ -42,6 +43,7 @@ type Part struct {
options *tflite.InterpreterOptions options *tflite.InterpreterOptions
interpreter *tflite.Interpreter interpreter *tflite.Interpreter
modelType tools.ModelType
modelPath string modelPath string
model *tflite.Model model *tflite.Model
edgeVebosity int edgeVebosity int
@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
output := p.interpreter.GetOutputTensor(0).UInt8s() output := p.interpreter.GetOutputTensor(0).UInt8s()
zap.L().Debug("raw steering", zap.Uint8s("result", output)) zap.L().Debug("raw steering", zap.Uint8s("result", output))
steering, score := tools.LinearBin(output, 15, -1, 2.0) var steering, score float64
switch p.modelType {
case tools.ModelTypeCategorical:
steering, score = tools.LinearBin(output, 15, -1, 2.0)
case tools.ModelTypeLinear:
steering = 2*(float64(output[0])/255.) - 1.
score = 0.6
}
zap.L().Debug("found steering", zap.L().Debug("found steering",
zap.Float64("steering", steering), zap.Float64("steering", steering),
zap.Float64("score", score), zap.Float64("score", score),

View File

@ -4,6 +4,37 @@ import (
"fmt" "fmt"
"go.uber.org/zap" "go.uber.org/zap"
"sort" "sort"
"strings"
)
type ModelType int
func ParseModelType(s string) ModelType {
switch strings.ToLower(s) {
case "categorical":
return ModelTypeCategorical
case "linear":
return ModelTypeLinear
default:
return ModelTypeUnknown
}
}
func (m ModelType) String() string {
switch m {
case ModelTypeCategorical:
return "categorical"
case ModelTypeLinear:
return "linear"
default:
return "unknown"
}
}
const (
ModelTypeUnknown ModelType = iota
ModelTypeCategorical
ModelTypeLinear
) )
// LinearBin perform inverse linear_bin, taking // LinearBin perform inverse linear_bin, taking

View File

@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) {
}) })
} }
} }
func TestParseModelType(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want ModelType
}{
{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical},
{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical},
{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear},
{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear},
{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown},
{name: "empty", args: args{s: ""}, want: ModelTypeUnknown},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ParseModelType(tt.args.s); got != tt.want {
t.Errorf("ParseModelType() = %v, want %v", got, tt.want)
}
})
}
}