Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
ba9e54d1fb | |||
ce58df8a6d | |||
c5c87da714 | |||
05a0af6b74 | |||
4c3954ec76 | |||
67d1a1d8e5 |
@ -21,8 +21,6 @@ func main() {
|
||||
var modelPath string
|
||||
var edgeVerbosity int
|
||||
var imgWidth, imgHeight, horizon int
|
||||
var debug bool
|
||||
|
||||
|
||||
mqttQos := cli.InitIntFlag("MQTT_QOS", 0)
|
||||
_, mqttRetain := os.LookupEnv("MQTT_RETAIN")
|
||||
@ -36,20 +34,16 @@ func main() {
|
||||
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)")
|
||||
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)")
|
||||
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
|
||||
flag.BoolVar(&debug, "debug", false, "Display debug logs")
|
||||
|
||||
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
|
||||
flag.Parse()
|
||||
|
||||
if len(os.Args) <= 1 {
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
config := zap.NewDevelopmentConfig()
|
||||
if debug {
|
||||
config.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
|
||||
} else {
|
||||
config.Level = zap.NewAtomicLevelAt(zap.InfoLevel)
|
||||
}
|
||||
config.Level = zap.NewAtomicLevelAt(*logLevel)
|
||||
lgr, err := config.Build()
|
||||
if err != nil {
|
||||
log.Fatalf("unable to init logger: %v", err)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/cyrilix/robocar-base/service"
|
||||
"github.com/cyrilix/robocar-protobuf/go/events"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||
"github.com/disintegration/imaging"
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
"github.com/golang/protobuf/proto"
|
||||
@ -15,7 +16,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -137,7 +137,6 @@ func (p *Part) onFrame(_ mqtt.Client, message mqtt.Message) {
|
||||
frameAge := now - msg.Id.CreatedAt.AsTime().UnixMilli()
|
||||
go metrics.FrameAge.Record(context.Background(), frameAge)
|
||||
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(msg.GetFrame()))
|
||||
if err != nil {
|
||||
zap.L().Error("unable to decode frame, skip frame", zap.Error(err))
|
||||
@ -193,13 +192,13 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
|
||||
dx = img.Bounds().Dx()
|
||||
dy = img.Bounds().Dy()
|
||||
|
||||
bb := make([]byte, dx*dy*3)
|
||||
bb := make([]uint8, dx*dy*3)
|
||||
for y := 0; y < dy; y++ {
|
||||
for x := 0; x < dx; x++ {
|
||||
r, g, b, _ := img.At(x, y).RGBA()
|
||||
bb[(y*dx+x)*3+0] = byte(float64(r) / 255.0)
|
||||
bb[(y*dx+x)*3+1] = byte(float64(g) / 255.0)
|
||||
bb[(y*dx+x)*3+2] = byte(float64(b) / 255.0)
|
||||
bb[(y*dx+x)*3+0] = uint8(float64(r) / 257.0)
|
||||
bb[(y*dx+x)*3+1] = uint8(float64(g) / 257.0)
|
||||
bb[(y*dx+x)*3+2] = uint8(float64(b) / 257.0)
|
||||
}
|
||||
}
|
||||
status = input.CopyFromBuffer(bb)
|
||||
@ -212,42 +211,24 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
|
||||
return 0., 0., fmt.Errorf("invoke failed: %v", status)
|
||||
}
|
||||
|
||||
output := p.interpreter.GetOutputTensor(0)
|
||||
outputSize := output.Dim(output.NumDims() - 1)
|
||||
b := make([]byte, outputSize)
|
||||
type result struct {
|
||||
score float64
|
||||
index int
|
||||
}
|
||||
status = output.CopyToBuffer(&b[0])
|
||||
if status != tflite.OK {
|
||||
return 0., 0., fmt.Errorf("output failed: %v", status)
|
||||
}
|
||||
output := p.interpreter.GetOutputTensor(0).UInt8s()
|
||||
zap.L().Debug("raw steering", zap.Uint8s("result", output))
|
||||
|
||||
var results []result
|
||||
minScore := 0.2
|
||||
for i := 0; i < outputSize; i++ {
|
||||
score := float64(b[i]) / 255.0
|
||||
if score < minScore {
|
||||
continue
|
||||
}
|
||||
results = append(results, result{score: score, index: i})
|
||||
}
|
||||
//outputSize := output.Dim(output.NumDims() - 1)
|
||||
|
||||
if len(results) == 0 {
|
||||
zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
|
||||
return 0., 0., nil
|
||||
}
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].score > results[j].score
|
||||
})
|
||||
//b := make([]byte, outputSize)
|
||||
//status = output.CopyToBuffer(&b[0])
|
||||
//if status != tflite.OK {
|
||||
// return 0., 0., fmt.Errorf("output failed: %v", status)
|
||||
//}
|
||||
|
||||
steering := float64(results[0].index)*(2./float64(outputSize)) - 1
|
||||
steering, score := tools.LinearBin(output, 15, -1, 2.0)
|
||||
//steering, score := tools.LinearBin(b, 15, -1, 2.0)
|
||||
zap.L().Debug("found steering",
|
||||
zap.Float64("steering", steering),
|
||||
zap.Float64("score", results[0].score),
|
||||
zap.Float64("score", score),
|
||||
)
|
||||
return float32(steering), float32(results[0].score), nil
|
||||
return float32(steering), float32(score), nil
|
||||
}
|
||||
|
||||
var registerCallbacks = func(p *Part) error {
|
||||
|
40
pkg/tools/tools.go
Normal file
40
pkg/tools/tools.go
Normal file
@ -0,0 +1,40 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// LinearBin perform inverse linear_bin, taking
|
||||
func LinearBin(arr []uint8, n int, offset int, r float64) (float64, float64) {
|
||||
outputSize := len(arr)
|
||||
type result struct {
|
||||
score float64
|
||||
index int
|
||||
}
|
||||
|
||||
var results []result
|
||||
minScore := 0.2
|
||||
for i := 0; i < outputSize; i++ {
|
||||
score := float64(int(arr[i])) / 255.0
|
||||
if score < minScore {
|
||||
continue
|
||||
}
|
||||
results = append(results, result{score: score, index: i})
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
|
||||
return 0., 0.
|
||||
}
|
||||
zap.S().Debugf("raw result: %v", results)
|
||||
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].score > results[j].score
|
||||
})
|
||||
|
||||
b := results[0].index
|
||||
a := float64(b)*(r/(float64(n)+float64(offset))) + float64(offset)
|
||||
return a, results[0].score
|
||||
}
|
77
pkg/tools/tools_test.go
Normal file
77
pkg/tools/tools_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_LinearBin(t *testing.T) {
|
||||
type args struct {
|
||||
arr []byte
|
||||
n int
|
||||
offset int
|
||||
r float64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want float64
|
||||
want1 float64
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
args: args{
|
||||
arr: []byte{0, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0},
|
||||
n: 15,
|
||||
offset: -1,
|
||||
r: 2.0,
|
||||
},
|
||||
want: 0.,
|
||||
want1: 1.0,
|
||||
},
|
||||
{
|
||||
name: "left",
|
||||
args: args{
|
||||
arr: []byte{255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
n: 15,
|
||||
offset: -1,
|
||||
r: 2.0,
|
||||
},
|
||||
want: -1.,
|
||||
want1: 1.0,
|
||||
},
|
||||
{
|
||||
name: "right",
|
||||
args: args{
|
||||
arr: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255},
|
||||
n: 15,
|
||||
offset: -1,
|
||||
r: 2.0,
|
||||
},
|
||||
want: 1.,
|
||||
want1: 1.0,
|
||||
},
|
||||
{
|
||||
name: "right",
|
||||
args: args{
|
||||
arr: []byte{0, 0, 0, 0, 0, 0, 0, 5, 10, 15, 20, 40, 100, 60, 5},
|
||||
n: 15,
|
||||
offset: -1,
|
||||
r: 2.0,
|
||||
},
|
||||
want: 0.7142857142857142,
|
||||
want1: 0.39215686274509803,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := LinearBin(tt.args.arr, tt.args.n, tt.args.offset, tt.args.r)
|
||||
if got != tt.want {
|
||||
t.Errorf("linearBin() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("linearBin() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user