fix: bad inference result conversion
This commit is contained in:
parent
05a0af6b74
commit
c5c87da714
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/cyrilix/robocar-base/service"
|
"github.com/cyrilix/robocar-base/service"
|
||||||
"github.com/cyrilix/robocar-protobuf/go/events"
|
"github.com/cyrilix/robocar-protobuf/go/events"
|
||||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
||||||
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||||
"github.com/disintegration/imaging"
|
"github.com/disintegration/imaging"
|
||||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
@ -15,7 +16,6 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"image"
|
"image"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
"sort"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -212,43 +212,21 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
output := p.interpreter.GetOutputTensor(0)
|
output := p.interpreter.GetOutputTensor(0)
|
||||||
|
|
||||||
outputSize := output.Dim(output.NumDims() - 1)
|
outputSize := output.Dim(output.NumDims() - 1)
|
||||||
|
|
||||||
b := make([]byte, outputSize)
|
b := make([]byte, outputSize)
|
||||||
type result struct {
|
|
||||||
score float64
|
|
||||||
index int
|
|
||||||
}
|
|
||||||
status = output.CopyToBuffer(&b[0])
|
status = output.CopyToBuffer(&b[0])
|
||||||
if status != tflite.OK {
|
if status != tflite.OK {
|
||||||
return 0., 0., fmt.Errorf("output failed: %v", status)
|
return 0., 0., fmt.Errorf("output failed: %v", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
var results []result
|
steering, score := tools.LinearBin(b, 15, -1, 2.0)
|
||||||
minScore := 0.2
|
|
||||||
for i := 0; i < outputSize; i++ {
|
|
||||||
score := float64(b[i]) / 255.0
|
|
||||||
if score < minScore {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
results = append(results, result{score: score, index: i})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(results) == 0 {
|
|
||||||
zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
|
|
||||||
return 0., 0., nil
|
|
||||||
}
|
|
||||||
zap.S().Debugf("raw result: %v", results)
|
|
||||||
|
|
||||||
sort.Slice(results, func(i, j int) bool {
|
|
||||||
return results[i].score > results[j].score
|
|
||||||
})
|
|
||||||
|
|
||||||
steering := float64(results[0].index)*(2./float64(outputSize)) - 1
|
|
||||||
zap.L().Debug("found steering",
|
zap.L().Debug("found steering",
|
||||||
zap.Float64("steering", steering),
|
zap.Float64("steering", steering),
|
||||||
zap.Float64("score", results[0].score),
|
zap.Float64("score", score),
|
||||||
)
|
)
|
||||||
return float32(steering), float32(results[0].score), nil
|
return float32(steering), float32(score), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var registerCallbacks = func(p *Part) error {
|
var registerCallbacks = func(p *Part) error {
|
||||||
|
40
pkg/tools/tools.go
Normal file
40
pkg/tools/tools.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LinearBin perform inverse linear_bin, taking
|
||||||
|
func LinearBin(arr []byte, n int, offset int, r float64) (float64, float64) {
|
||||||
|
outputSize := len(arr)
|
||||||
|
type result struct {
|
||||||
|
score float64
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []result
|
||||||
|
minScore := 0.2
|
||||||
|
for i := 0; i < outputSize; i++ {
|
||||||
|
score := float64(arr[i]) / 255.0
|
||||||
|
if score < minScore {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
results = append(results, result{score: score, index: i})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) == 0 {
|
||||||
|
zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
|
||||||
|
return 0., 0.
|
||||||
|
}
|
||||||
|
zap.S().Debugf("raw result: %v", results)
|
||||||
|
|
||||||
|
sort.Slice(results, func(i, j int) bool {
|
||||||
|
return results[i].score > results[j].score
|
||||||
|
})
|
||||||
|
|
||||||
|
b := results[0].index
|
||||||
|
a := float64(b)*(r/(float64(n)+float64(offset))) + float64(offset)
|
||||||
|
return a, results[0].score
|
||||||
|
}
|
77
pkg/tools/tools_test.go
Normal file
77
pkg/tools/tools_test.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_LinearBin(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
arr []byte
|
||||||
|
n int
|
||||||
|
offset int
|
||||||
|
r float64
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want float64
|
||||||
|
want1 float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
args: args{
|
||||||
|
arr: []byte{0, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
n: 15,
|
||||||
|
offset: -1,
|
||||||
|
r: 2.0,
|
||||||
|
},
|
||||||
|
want: 0.,
|
||||||
|
want1: 1.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "left",
|
||||||
|
args: args{
|
||||||
|
arr: []byte{255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
n: 15,
|
||||||
|
offset: -1,
|
||||||
|
r: 2.0,
|
||||||
|
},
|
||||||
|
want: -1.,
|
||||||
|
want1: 1.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right",
|
||||||
|
args: args{
|
||||||
|
arr: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255},
|
||||||
|
n: 15,
|
||||||
|
offset: -1,
|
||||||
|
r: 2.0,
|
||||||
|
},
|
||||||
|
want: 1.,
|
||||||
|
want1: 1.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right",
|
||||||
|
args: args{
|
||||||
|
arr: []byte{0, 0, 0, 0, 0, 0, 0, 5, 10, 15, 20, 40, 100, 60, 5},
|
||||||
|
n: 15,
|
||||||
|
offset: -1,
|
||||||
|
r: 2.0,
|
||||||
|
},
|
||||||
|
want: 0.7142857142857142,
|
||||||
|
want1: 0.39215686274509803,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, got1 := LinearBin(tt.args.arr, tt.args.n, tt.args.offset, tt.args.r)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("linearBin() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("linearBin() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user