feat: implement linear model and autodetect parameters
This commit is contained in:
@ -19,9 +19,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
||||
func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
||||
return &Part{
|
||||
client: client,
|
||||
modelType: modelType,
|
||||
modelPath: modelPath,
|
||||
steeringTopic: steeringTopic,
|
||||
cameraTopic: cameraTopic,
|
||||
@ -42,6 +43,7 @@ type Part struct {
|
||||
|
||||
options *tflite.InterpreterOptions
|
||||
interpreter *tflite.Interpreter
|
||||
modelType tools.ModelType
|
||||
modelPath string
|
||||
model *tflite.Model
|
||||
edgeVebosity int
|
||||
@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
|
||||
output := p.interpreter.GetOutputTensor(0).UInt8s()
|
||||
zap.L().Debug("raw steering", zap.Uint8s("result", output))
|
||||
|
||||
steering, score := tools.LinearBin(output, 15, -1, 2.0)
|
||||
var steering, score float64
|
||||
switch p.modelType {
|
||||
case tools.ModelTypeCategorical:
|
||||
steering, score = tools.LinearBin(output, 15, -1, 2.0)
|
||||
case tools.ModelTypeLinear:
|
||||
steering = 2*(float64(output[0])/255.) - 1.
|
||||
score = 0.6
|
||||
}
|
||||
zap.L().Debug("found steering",
|
||||
zap.Float64("steering", steering),
|
||||
zap.Float64("score", score),
|
||||
|
@ -4,6 +4,37 @@ import (
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelType int
|
||||
|
||||
func ParseModelType(s string) ModelType {
|
||||
switch strings.ToLower(s) {
|
||||
case "categorical":
|
||||
return ModelTypeCategorical
|
||||
case "linear":
|
||||
return ModelTypeLinear
|
||||
default:
|
||||
return ModelTypeUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (m ModelType) String() string {
|
||||
switch m {
|
||||
case ModelTypeCategorical:
|
||||
return "categorical"
|
||||
case ModelTypeLinear:
|
||||
return "linear"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ModelTypeUnknown ModelType = iota
|
||||
ModelTypeCategorical
|
||||
ModelTypeLinear
|
||||
)
|
||||
|
||||
// LinearBin perform inverse linear_bin, taking
|
||||
|
@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelType(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want ModelType
|
||||
}{
|
||||
{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical},
|
||||
{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical},
|
||||
{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear},
|
||||
{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear},
|
||||
{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown},
|
||||
{name: "empty", args: args{s: ""}, want: ModelTypeUnknown},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ParseModelType(tt.args.s); got != tt.want {
|
||||
t.Errorf("ParseModelType() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user