diff --git a/pkg/steering/steering.go b/pkg/steering/steering.go index a243d63..c3e4a34 100644 --- a/pkg/steering/steering.go +++ b/pkg/steering/steering.go @@ -7,6 +7,7 @@ import ( "github.com/cyrilix/robocar-base/service" "github.com/cyrilix/robocar-protobuf/go/events" "github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics" + "github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools" "github.com/disintegration/imaging" mqtt "github.com/eclipse/paho.mqtt.golang" "github.com/golang/protobuf/proto" @@ -15,7 +16,6 @@ import ( "go.uber.org/zap" "image" _ "image/jpeg" - "sort" "time" ) @@ -212,43 +212,21 @@ func (p *Part) Value(img image.Image) (float32, float32, error) { } output := p.interpreter.GetOutputTensor(0) + outputSize := output.Dim(output.NumDims() - 1) + b := make([]byte, outputSize) - type result struct { - score float64 - index int - } status = output.CopyToBuffer(&b[0]) if status != tflite.OK { return 0., 0., fmt.Errorf("output failed: %v", status) } - var results []result - minScore := 0.2 - for i := 0; i < outputSize; i++ { - score := float64(b[i]) / 255.0 - if score < minScore { - continue - } - results = append(results, result{score: score, index: i}) - } - - if len(results) == 0 { - zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore)) - return 0., 0., nil - } - zap.S().Debugf("raw result: %v", results) - - sort.Slice(results, func(i, j int) bool { - return results[i].score > results[j].score - }) - - steering := float64(results[0].index)*(2./float64(outputSize)) - 1 + steering, score := tools.LinearBin(b, 15, -1, 2.0) zap.L().Debug("found steering", zap.Float64("steering", steering), - zap.Float64("score", results[0].score), + zap.Float64("score", score), ) - return float32(steering), float32(results[0].score), nil + return float32(steering), float32(score), nil } var registerCallbacks = func(p *Part) error { diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go new file mode 100644 index 0000000..ca6697a --- /dev/null +++ b/pkg/tools/tools.go @@ -0,0 +1,40 @@ +package tools + +import ( + "fmt" + "go.uber.org/zap" + "sort" +) + +// LinearBin perform inverse linear_bin, taking +func LinearBin(arr []byte, n int, offset int, r float64) (float64, float64) { + outputSize := len(arr) + type result struct { + score float64 + index int + } + + var results []result + minScore := 0.2 + for i := 0; i < outputSize; i++ { + score := float64(arr[i]) / 255.0 + if score < minScore { + continue + } + results = append(results, result{score: score, index: i}) + } + + if len(results) == 0 { + zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore)) + return 0., 0. + } + zap.S().Debugf("raw result: %v", results) + + sort.Slice(results, func(i, j int) bool { + return results[i].score > results[j].score + }) + + b := results[0].index + a := float64(b)*(r/(float64(n)+float64(offset))) + float64(offset) + return a, results[0].score +} diff --git a/pkg/tools/tools_test.go b/pkg/tools/tools_test.go new file mode 100644 index 0000000..f092ff4 --- /dev/null +++ b/pkg/tools/tools_test.go @@ -0,0 +1,77 @@ +package tools + +import ( + "testing" +) + +func Test_LinearBin(t *testing.T) { + type args struct { + arr []byte + n int + offset int + r float64 + } + tests := []struct { + name string + args args + want float64 + want1 float64 + }{ + { + name: "default", + args: args{ + arr: []byte{0, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0}, + n: 15, + offset: -1, + r: 2.0, + }, + want: 0., + want1: 1.0, + }, + { + name: "left", + args: args{ + arr: []byte{255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + n: 15, + offset: -1, + r: 2.0, + }, + want: -1., + want1: 1.0, + }, + { + name: "right", + args: args{ + arr: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255}, + n: 15, + offset: -1, + r: 2.0, + }, + want: 1., + want1: 1.0, + }, + { + name: "right", + args: args{ + arr: []byte{0, 0, 0, 0, 0, 0, 0, 5, 10, 15, 20, 40, 100, 60, 5}, + n: 15, + offset: -1, + r: 2.0, + }, + want: 0.7142857142857142, + want1: 0.39215686274509803, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := LinearBin(tt.args.arr, tt.args.n, tt.args.offset, tt.args.r) + if got != tt.want { + t.Errorf("linearBin() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("linearBin() got1 = %v, want %v", got1, tt.want1) + } + }) + } +}