feat: implement linear model and autodetect parameters
This commit is contained in:
parent
1e6966495c
commit
acf6066fea
@ -3,18 +3,26 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"github.com/cyrilix/robocar-base/cli"
|
"github.com/cyrilix/robocar-base/cli"
|
||||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
||||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
|
||||||
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DefaultClientId = "robocar-steering-tflite-edgetpu"
|
DefaultClientId = "robocar-steering-tflite-edgetpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
modelNameRegex = regexp.MustCompile(".*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\\d+)x(?P<imgHeight>\\d+)h(?P<horizon>\\d+)_edgetpu.tflite$")
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var mqttBroker, username, password, clientId string
|
var mqttBroker, username, password, clientId string
|
||||||
var cameraTopic, steeringTopic string
|
var cameraTopic, steeringTopic string
|
||||||
@ -31,8 +39,8 @@ func main() {
|
|||||||
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
|
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
|
||||||
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
|
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
|
||||||
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
|
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
|
||||||
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)")
|
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model")
|
||||||
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)")
|
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model")
|
||||||
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
|
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
|
||||||
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
|
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
@ -63,19 +71,37 @@ func main() {
|
|||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
modelType, width, height, horizonFromName, err := parseModelName(modelPath)
|
||||||
|
if err != nil {
|
||||||
|
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
|
||||||
|
}
|
||||||
|
if imgWidth == 0 {
|
||||||
|
imgWidth = width
|
||||||
|
}
|
||||||
|
if imgHeight == 0 {
|
||||||
|
imgHeight = height
|
||||||
|
}
|
||||||
|
if horizonFromName == 0 {
|
||||||
|
horizon = horizonFromName
|
||||||
|
}
|
||||||
if imgWidth <= 0 || imgHeight <= 0 {
|
if imgWidth <= 0 || imgHeight <= 0 {
|
||||||
zap.L().Error("img-width and img-height are mandatory")
|
zap.L().Error("img-width and img-height are mandatory")
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
zap.S().Infof("model type : %v", modelType)
|
||||||
|
zap.S().Infof("model for image width : %v", imgWidth)
|
||||||
|
zap.S().Infof("model for image height: %v", imgHeight)
|
||||||
|
zap.S().Infof("model with horizon : %v", horizon)
|
||||||
|
|
||||||
client, err := cli.Connect(mqttBroker, username, password, clientId)
|
client, err := cli.Connect(mqttBroker, username, password, clientId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
|
zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
|
||||||
}
|
}
|
||||||
defer client.Disconnect(50)
|
defer client.Disconnect(50)
|
||||||
|
|
||||||
p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
|
p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
|
|
||||||
cli.HandleExit(p)
|
cli.HandleExit(p)
|
||||||
@ -85,3 +111,33 @@ func main() {
|
|||||||
zap.L().Fatal("unable to start service", zap.Error(err))
|
zap.L().Fatal("unable to start service", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
|
||||||
|
match := modelNameRegex.FindStringSubmatch(modelPath)
|
||||||
|
|
||||||
|
results := map[string]string{}
|
||||||
|
for i, name := range match {
|
||||||
|
results[modelNameRegex.SubexpNames()[i]] = name
|
||||||
|
}
|
||||||
|
modelType = tools.ParseModelType(results["type"])
|
||||||
|
if modelType == tools.ModelTypeUnknown {
|
||||||
|
err = fmt.Errorf("unknown model type '%v'", results["type"])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
imgWidth, err = strconv.Atoi(results["imgWidth"])
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
imgHeight, err = strconv.Atoi(results["imgHeight"])
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
horizon, err = strconv.Atoi(results["horizon"])
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
97
cmd/rc-steering/rc-steering_test.go
Normal file
97
cmd/rc-steering/rc-steering_test.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseModelName(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
modelPath string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantModelType tools.ModelType
|
||||||
|
wantImgWidth int
|
||||||
|
wantImgHeight int
|
||||||
|
wantHorizon int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "categorical",
|
||||||
|
args: args{modelPath: "/tmp/model_categorical_120x160h10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeCategorical,
|
||||||
|
wantImgWidth: 120,
|
||||||
|
wantImgHeight: 160,
|
||||||
|
wantHorizon: 10,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "linear",
|
||||||
|
args: args{modelPath: "/tmp/model_linear_120x160h10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeLinear,
|
||||||
|
wantImgWidth: 120,
|
||||||
|
wantImgHeight: 160,
|
||||||
|
wantHorizon: 10,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad-model",
|
||||||
|
args: args{modelPath: "/tmp/model_123_120x160h10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeUnknown,
|
||||||
|
wantImgWidth: 0,
|
||||||
|
wantImgHeight: 0,
|
||||||
|
wantHorizon: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad-width",
|
||||||
|
args: args{modelPath: "/tmp/model_categorical_ax160h10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeUnknown,
|
||||||
|
wantImgWidth: 0,
|
||||||
|
wantImgHeight: 0,
|
||||||
|
wantHorizon: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad-height",
|
||||||
|
args: args{modelPath: "/tmp/model_categorical_120xh10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeUnknown,
|
||||||
|
wantImgWidth: 0,
|
||||||
|
wantImgHeight: 0,
|
||||||
|
wantHorizon: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad-horizon",
|
||||||
|
args: args{modelPath: "/tmp/model_categorical_120x160h.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeUnknown,
|
||||||
|
wantImgWidth: 0,
|
||||||
|
wantImgHeight: 0,
|
||||||
|
wantHorizon: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gotModelType != tt.wantModelType {
|
||||||
|
t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType)
|
||||||
|
}
|
||||||
|
if gotImgWidth != tt.wantImgWidth {
|
||||||
|
t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth)
|
||||||
|
}
|
||||||
|
if gotImgHeight != tt.wantImgHeight {
|
||||||
|
t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight)
|
||||||
|
}
|
||||||
|
if gotHorizon != tt.wantHorizon {
|
||||||
|
t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -19,9 +19,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
||||||
return &Part{
|
return &Part{
|
||||||
client: client,
|
client: client,
|
||||||
|
modelType: modelType,
|
||||||
modelPath: modelPath,
|
modelPath: modelPath,
|
||||||
steeringTopic: steeringTopic,
|
steeringTopic: steeringTopic,
|
||||||
cameraTopic: cameraTopic,
|
cameraTopic: cameraTopic,
|
||||||
@ -42,6 +43,7 @@ type Part struct {
|
|||||||
|
|
||||||
options *tflite.InterpreterOptions
|
options *tflite.InterpreterOptions
|
||||||
interpreter *tflite.Interpreter
|
interpreter *tflite.Interpreter
|
||||||
|
modelType tools.ModelType
|
||||||
modelPath string
|
modelPath string
|
||||||
model *tflite.Model
|
model *tflite.Model
|
||||||
edgeVebosity int
|
edgeVebosity int
|
||||||
@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
|
|||||||
output := p.interpreter.GetOutputTensor(0).UInt8s()
|
output := p.interpreter.GetOutputTensor(0).UInt8s()
|
||||||
zap.L().Debug("raw steering", zap.Uint8s("result", output))
|
zap.L().Debug("raw steering", zap.Uint8s("result", output))
|
||||||
|
|
||||||
steering, score := tools.LinearBin(output, 15, -1, 2.0)
|
var steering, score float64
|
||||||
|
switch p.modelType {
|
||||||
|
case tools.ModelTypeCategorical:
|
||||||
|
steering, score = tools.LinearBin(output, 15, -1, 2.0)
|
||||||
|
case tools.ModelTypeLinear:
|
||||||
|
steering = 2*(float64(output[0])/255.) - 1.
|
||||||
|
score = 0.6
|
||||||
|
}
|
||||||
zap.L().Debug("found steering",
|
zap.L().Debug("found steering",
|
||||||
zap.Float64("steering", steering),
|
zap.Float64("steering", steering),
|
||||||
zap.Float64("score", score),
|
zap.Float64("score", score),
|
||||||
|
@ -4,6 +4,37 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelType int
|
||||||
|
|
||||||
|
func ParseModelType(s string) ModelType {
|
||||||
|
switch strings.ToLower(s) {
|
||||||
|
case "categorical":
|
||||||
|
return ModelTypeCategorical
|
||||||
|
case "linear":
|
||||||
|
return ModelTypeLinear
|
||||||
|
default:
|
||||||
|
return ModelTypeUnknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ModelType) String() string {
|
||||||
|
switch m {
|
||||||
|
case ModelTypeCategorical:
|
||||||
|
return "categorical"
|
||||||
|
case ModelTypeLinear:
|
||||||
|
return "linear"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelTypeUnknown ModelType = iota
|
||||||
|
ModelTypeCategorical
|
||||||
|
ModelTypeLinear
|
||||||
)
|
)
|
||||||
|
|
||||||
// LinearBin perform inverse linear_bin, taking
|
// LinearBin perform inverse linear_bin, taking
|
||||||
|
@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseModelType(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
s string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want ModelType
|
||||||
|
}{
|
||||||
|
{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical},
|
||||||
|
{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical},
|
||||||
|
{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear},
|
||||||
|
{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear},
|
||||||
|
{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown},
|
||||||
|
{name: "empty", args: args{s: ""}, want: ModelTypeUnknown},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := ParseModelType(tt.args.s); got != tt.want {
|
||||||
|
t.Errorf("ParseModelType() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user