feat: implement linear model and autodetect parameters

This commit is contained in:
2022-06-10 13:50:22 +02:00
parent 1e6966495c
commit acf6066fea
5 changed files with 223 additions and 5 deletions

View File

@ -19,9 +19,10 @@ import (
"time"
)
func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
return &Part{
client: client,
modelType: modelType,
modelPath: modelPath,
steeringTopic: steeringTopic,
cameraTopic: cameraTopic,
@ -42,6 +43,7 @@ type Part struct {
options *tflite.InterpreterOptions
interpreter *tflite.Interpreter
modelType tools.ModelType
modelPath string
model *tflite.Model
edgeVebosity int
@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
output := p.interpreter.GetOutputTensor(0).UInt8s()
zap.L().Debug("raw steering", zap.Uint8s("result", output))
steering, score := tools.LinearBin(output, 15, -1, 2.0)
var steering, score float64
switch p.modelType {
case tools.ModelTypeCategorical:
steering, score = tools.LinearBin(output, 15, -1, 2.0)
case tools.ModelTypeLinear:
steering = 2*(float64(output[0])/255.) - 1.
score = 0.6
}
zap.L().Debug("found steering",
zap.Float64("steering", steering),
zap.Float64("score", score),

View File

@ -4,6 +4,37 @@ import (
"fmt"
"go.uber.org/zap"
"sort"
"strings"
)
type ModelType int
func ParseModelType(s string) ModelType {
switch strings.ToLower(s) {
case "categorical":
return ModelTypeCategorical
case "linear":
return ModelTypeLinear
default:
return ModelTypeUnknown
}
}
func (m ModelType) String() string {
switch m {
case ModelTypeCategorical:
return "categorical"
case ModelTypeLinear:
return "linear"
default:
return "unknown"
}
}
const (
ModelTypeUnknown ModelType = iota
ModelTypeCategorical
ModelTypeLinear
)
// LinearBin perform inverse linear_bin, taking

View File

@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) {
})
}
}
func TestParseModelType(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want ModelType
}{
{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical},
{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical},
{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear},
{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear},
{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown},
{name: "empty", args: args{s: ""}, want: ModelTypeUnknown},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ParseModelType(tt.args.s); got != tt.want {
t.Errorf("ParseModelType() = %v, want %v", got, tt.want)
}
})
}
}