fix: bad inference result conversion

This commit is contained in:
2021-12-28 18:00:48 +01:00
parent 05a0af6b74
commit c5c87da714
3 changed files with 123 additions and 28 deletions

View File

@ -7,6 +7,7 @@ import (
"github.com/cyrilix/robocar-base/service"
"github.com/cyrilix/robocar-protobuf/go/events"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
"github.com/disintegration/imaging"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/golang/protobuf/proto"
@ -15,7 +16,6 @@ import (
"go.uber.org/zap"
"image"
_ "image/jpeg"
"sort"
"time"
)
@ -212,43 +212,21 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
}
output := p.interpreter.GetOutputTensor(0)
outputSize := output.Dim(output.NumDims() - 1)
b := make([]byte, outputSize)
type result struct {
score float64
index int
}
status = output.CopyToBuffer(&b[0])
if status != tflite.OK {
return 0., 0., fmt.Errorf("output failed: %v", status)
}
var results []result
minScore := 0.2
for i := 0; i < outputSize; i++ {
score := float64(b[i]) / 255.0
if score < minScore {
continue
}
results = append(results, result{score: score, index: i})
}
if len(results) == 0 {
zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
return 0., 0., nil
}
zap.S().Debugf("raw result: %v", results)
sort.Slice(results, func(i, j int) bool {
return results[i].score > results[j].score
})
steering := float64(results[0].index)*(2./float64(outputSize)) - 1
steering, score := tools.LinearBin(b, 15, -1, 2.0)
zap.L().Debug("found steering",
zap.Float64("steering", steering),
zap.Float64("score", results[0].score),
zap.Float64("score", score),
)
return float32(steering), float32(results[0].score), nil
return float32(steering), float32(score), nil
}
var registerCallbacks = func(p *Part) error {