feat(on-objects): Steering corrector implementation

This commit is contained in:
Cyrille Nofficial 2022-08-23 22:08:07 +02:00
parent 158f08a76f
commit d5d4d908a0
8 changed files with 740 additions and 99 deletions

View File

@ -16,6 +16,10 @@ const (
func main() { func main() {
var mqttBroker, username, password, clientId string var mqttBroker, username, password, clientId string
var steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string var steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string
var imgWidth, imgHeight int
var enableObjectsCorrection, enableObjectsCorrectionOnUserMode bool
var gridMapConfig, objectsMoveFactorsConfig string
var deltaMiddle float64
mqttQos := cli.InitIntFlag("MQTT_QOS", 0) mqttQos := cli.InitIntFlag("MQTT_QOS", 0)
_, mqttRetain := os.LookupEnv("MQTT_RETAIN") _, mqttRetain := os.LookupEnv("MQTT_RETAIN")
@ -27,7 +31,13 @@ func main() {
flag.StringVar(&tfSteeringTopic, "mqtt-topic-tf-steering", os.Getenv("MQTT_TOPIC_TF_STEERING"), "Mqtt topic that contains tenorflow steering value, use MQTT_TOPIC_TF_STEERING if args not set") flag.StringVar(&tfSteeringTopic, "mqtt-topic-tf-steering", os.Getenv("MQTT_TOPIC_TF_STEERING"), "Mqtt topic that contains tenorflow steering value, use MQTT_TOPIC_TF_STEERING if args not set")
flag.StringVar(&driveModeTopic, "mqtt-topic-drive-mode", os.Getenv("MQTT_TOPIC_DRIVE_MODE"), "Mqtt topic that contains DriveMode value, use MQTT_TOPIC_DRIVE_MODE if args not set") flag.StringVar(&driveModeTopic, "mqtt-topic-drive-mode", os.Getenv("MQTT_TOPIC_DRIVE_MODE"), "Mqtt topic that contains DriveMode value, use MQTT_TOPIC_DRIVE_MODE if args not set")
flag.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains Objects from object detection value, use MQTT_TOPIC_OBJECTS if args not set") flag.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains Objects from object detection value, use MQTT_TOPIC_OBJECTS if args not set")
flag.IntVar(&imgWidth, "image-width", 160, "Video pixels width")
flag.IntVar(&imgHeight, "image-height", 128, "Video pixels height")
flag.BoolVar(&enableObjectsCorrection, "enable-objects-correction", false, "Adjust steering to avoid objects")
flag.BoolVar(&enableObjectsCorrectionOnUserMode, "enable-objects-correction-user", false, "Adjust steering to avoid objects on user mode driving")
flag.StringVar(&gridMapConfig, "grid-map-config", "", "Json file path to configure grid object correction")
flag.StringVar(&objectsMoveFactorsConfig, "objects-move-factors-config", "", "Json file path to configure objects move corrections")
flag.Float64Var(&deltaMiddle, "delta-middle", 0.1, "Half Percent zone to interpret as straight")
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level") logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
flag.Parse() flag.Parse()
@ -50,13 +60,36 @@ func main() {
}() }()
zap.ReplaceGlobals(lgr) zap.ReplaceGlobals(lgr)
zap.S().Infof("steering topic : %s", steeringTopic)
zap.S().Infof("rc topic : %s", rcSteeringTopic)
zap.S().Infof("tflite steering topic : %s", tfSteeringTopic)
zap.S().Infof("drive mode topic : %s", driveModeTopic)
zap.S().Infof("objects topic : %s", objectsTopic)
zap.S().Infof("objects correction enabled : %v", enableObjectsCorrection)
zap.S().Infof("objects correction on user mode : %v", enableObjectsCorrectionOnUserMode)
zap.S().Infof("grid map file config : %v", gridMapConfig)
zap.S().Infof("objects move factors grid config: %v", objectsMoveFactorsConfig)
zap.S().Infof("image width x height : %v x %v", imgWidth, imgHeight)
client, err := cli.Connect(mqttBroker, username, password, clientId) client, err := cli.Connect(mqttBroker, username, password, clientId)
if err != nil { if err != nil {
log.Fatalf("unable to connect to mqtt bus: %v", err) log.Fatalf("unable to connect to mqtt bus: %v", err)
} }
defer client.Disconnect(50) defer client.Disconnect(50)
p := steering.NewController(client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic) p := steering.NewController(
client,
steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
steering.WithCorrector(
steering.NewGridCorrector(
steering.WidthDeltaMiddle(deltaMiddle),
steering.WithGridMap(gridMapConfig),
steering.WithObjectMoveFactors(objectsMoveFactorsConfig),
steering.WithImageSize(imgWidth, imgHeight),
),
),
steering.WithObjectsCorrectionEnabled(enableObjectsCorrection, enableObjectsCorrectionOnUserMode),
)
defer p.Stop() defer p.Stop()
cli.HandleExit(p) cli.HandleExit(p)

View File

@ -1,6 +1,7 @@
package steering package steering
import ( import (
"github.com/cyrilix/robocar-protobuf/go/events"
"gocv.io/x/gocv" "gocv.io/x/gocv"
"image" "image"
) )
@ -14,3 +15,43 @@ func GroupBBoxes(bboxes []image.Rectangle) []image.Rectangle {
} }
return gocv.GroupRectangles(bboxes, 1, 0.2) return gocv.GroupRectangles(bboxes, 1, 0.2)
} }
func GroupObjects(objects []*events.Object, imgWidth, imgHeight int) []*events.Object {
if len(objects) == 0 {
return []*events.Object{}
}
if len(objects) == 1 {
return []*events.Object{objects[0]}
}
rectangles := make([]image.Rectangle, 0, len(objects))
for _, o := range objects {
rectangles = append(rectangles, *objectToRect(o, imgWidth, imgHeight))
}
grp := gocv.GroupRectangles(rectangles, 1, 0.2)
result := make([]*events.Object, 0, len(grp))
for _, r := range grp {
result = append(result, rectToObject(&r, imgWidth, imgHeight))
}
return result
}
func objectToRect(object *events.Object, imgWidth, imgHeight int) *image.Rectangle {
r := image.Rect(
int(object.Left*float32(imgWidth)),
int(object.Top*float32(imgHeight)),
int(object.Right*float32(imgWidth)),
int(object.Bottom*float32(imgHeight)),
)
return &r
}
func rectToObject(r *image.Rectangle, imgWidth, imgHeight int) *events.Object {
return &events.Object{
Type: events.TypeObject_ANY,
Left: float32(r.Min.X) / float32(imgWidth),
Top: float32(r.Min.Y) / float32(imgHeight),
Right: float32(r.Max.X) / float32(imgWidth),
Bottom: float32(r.Max.Y) / float32(imgHeight),
Confidence: -1,
}
}

View File

@ -3,6 +3,7 @@ package steering
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/cyrilix/robocar-protobuf/go/events"
"go.uber.org/zap" "go.uber.org/zap"
"gocv.io/x/gocv" "gocv.io/x/gocv"
"image" "image"
@ -25,6 +26,7 @@ type BBox struct {
var ( var (
dataBBoxes map[string][]image.Rectangle dataBBoxes map[string][]image.Rectangle
dataObjects map[string][]*events.Object
dataImages map[string]*gocv.Mat dataImages map[string]*gocv.Mat
) )
@ -32,6 +34,7 @@ func init() {
// TODO: empty img without bbox // TODO: empty img without bbox
dataNames := []string{"01", "02", "03", "04"} dataNames := []string{"01", "02", "03", "04"}
dataBBoxes = make(map[string][]image.Rectangle, len(dataNames)) dataBBoxes = make(map[string][]image.Rectangle, len(dataNames))
dataObjects = make(map[string][]*events.Object, len(dataNames))
dataImages = make(map[string]*gocv.Mat, len(dataNames)) dataImages = make(map[string]*gocv.Mat, len(dataNames))
for _, dataName := range dataNames { for _, dataName := range dataNames {
@ -40,6 +43,7 @@ func init() {
zap.S().Panicf("unable to load data test: %v", err) zap.S().Panicf("unable to load data test: %v", err)
} }
dataBBoxes[dataName] = bboxesToRectangles(bb, img.Cols(), img.Rows()) dataBBoxes[dataName] = bboxesToRectangles(bb, img.Cols(), img.Rows())
dataObjects[dataName] = bboxesToObjects(bb)
dataImages[dataName] = img dataImages[dataName] = img
} }
} }
@ -52,6 +56,20 @@ func bboxesToRectangles(bboxes []BBox, imgWidth, imgHeiht int) []image.Rectangle
return rects return rects
} }
func bboxesToObjects(bboxes []BBox) []*events.Object {
objects := make([]*events.Object, 0, len(bboxes))
for _, bb := range bboxes {
objects = append(objects, &events.Object{
Type: events.TypeObject_ANY,
Left: bb.Left,
Top: bb.Top,
Right: bb.Right,
Bottom: bb.Bottom,
Confidence: bb.Confidence,
})
}
return objects
}
func (bb *BBox) toRect(imgWidth, imgHeight int) image.Rectangle { func (bb *BBox) toRect(imgWidth, imgHeight int) image.Rectangle {
return image.Rect( return image.Rect(
int(bb.Left*float32(imgWidth)), int(bb.Left*float32(imgWidth)),
@ -228,3 +246,59 @@ func TestGroupBBoxes(t *testing.T) {
}) })
} }
} }
func TestGroupObjects(t *testing.T) {
type args struct {
dataName string
}
tests := []struct {
name string
args args
want []*events.Object
}{
{
name: "groupbbox-01",
args: args{
dataName: "01",
},
want: []*events.Object{
{Left: 0.26660156, Top: 0.1706543, Right: 0.5258789, Bottom: 0.47583008, Confidence: 0.4482422},
},
},
{
name: "groupbbox-02",
args: args{
dataName: "02",
},
want: []*events.Object{
{Left: 0.15625, Top: 0.108333334, Right: 0.6875, Bottom: 0.6666667, Confidence: -1},
},
},
{
name: "groupbbox-03",
args: args{
dataName: "03",
},
want: []*events.Object{
{Top: 0.14166667, Right: 0.21875, Bottom: 0.64166665, Confidence: -1},
},
},
{
name: "groupbbox-04",
args: args{
dataName: "04",
},
want: []*events.Object{
{Left: 0.80625, Top: 0.083333336, Right: 0.99375, Bottom: 0.53333336, Confidence: -1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
img := dataImages[tt.args.dataName]
got := GroupObjects(dataObjects[tt.args.dataName], img.Cols(), img.Rows())
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GroupObjects() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -1,6 +1,7 @@
package steering package steering
import ( import (
"fmt"
"github.com/cyrilix/robocar-base/service" "github.com/cyrilix/robocar-base/service"
"github.com/cyrilix/robocar-protobuf/go/events" "github.com/cyrilix/robocar-protobuf/go/events"
mqtt "github.com/eclipse/paho.mqtt.golang" mqtt "github.com/eclipse/paho.mqtt.golang"
@ -9,8 +10,48 @@ import (
"sync" "sync"
) )
func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string) *Controller { var (
return &Controller{ defaultGridMap = GridMap{
DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
Data: [][]float64{
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0., 0.25, -0.25, 0., 0.},
{0., 0.25, 0.5, -0.5, -0.25, 0.},
{0.25, 0.5, 1, -1, -0.5, -0.25},
},
}
defaultObjectFactors = GridMap{
DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
Data: [][]float64{
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0.25, 0, 0, -0.25, 0.},
{0.5, 0.25, 0, 0, -0.5, -0.25},
},
}
)
type Option func(c *Controller)
func WithCorrector(c Corrector) Option {
return func(ctrl *Controller) {
ctrl.corrector = c
}
}
func WithObjectsCorrectionEnabled(enabled, enabledOnUserDrive bool) Option {
return func(ctrl *Controller) {
ctrl.enableCorrection = enabled
ctrl.enableCorrectionOnUser = enabledOnUserDrive
}
}
func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string, options ...Option) *Controller {
c := &Controller{
client: client, client: client,
steeringTopic: steeringTopic, steeringTopic: steeringTopic,
driveModeTopic: driveModeTopic, driveModeTopic: driveModeTopic,
@ -18,8 +59,12 @@ func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteering
tfSteeringTopic: tfSteeringTopic, tfSteeringTopic: tfSteeringTopic,
objectsTopic: objectsTopic, objectsTopic: objectsTopic,
driveMode: events.DriveMode_USER, driveMode: events.DriveMode_USER,
corrector: NewGridCorrector(),
} }
for _, o := range options {
o(c)
}
return c
} }
type Controller struct { type Controller struct {
@ -35,26 +80,28 @@ type Controller struct {
muObjects sync.RWMutex muObjects sync.RWMutex
objects []*events.Object objects []*events.Object
debug bool corrector Corrector
enableCorrection bool
enableCorrectionOnUser bool
} }
func (p *Controller) Start() error { func (c *Controller) Start() error {
if err := registerCallbacks(p); err != nil { if err := registerCallbacks(c); err != nil {
zap.S().Errorf("unable to rgeister callbacks: %v", err) zap.S().Errorf("unable to register callbacks: %v", err)
return err return err
} }
p.cancel = make(chan interface{}) c.cancel = make(chan interface{})
<-p.cancel <-c.cancel
return nil return nil
} }
func (p *Controller) Stop() { func (c *Controller) Stop() {
close(p.cancel) close(c.cancel)
service.StopService("throttle", p.client, p.driveModeTopic, p.rcSteeringTopic, p.tfSteeringTopic) service.StopService("throttle", c.client, c.driveModeTopic, c.rcSteeringTopic, c.tfSteeringTopic)
} }
func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) { func (c *Controller) onObjects(_ mqtt.Client, message mqtt.Message) {
var msg events.ObjectsMessage var msg events.ObjectsMessage
err := proto.Unmarshal(message.Payload(), &msg) err := proto.Unmarshal(message.Payload(), &msg)
if err != nil { if err != nil {
@ -62,12 +109,13 @@ func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) {
return return
} }
p.muObjects.Lock() c.muObjects.Lock()
defer p.muObjects.Unlock() defer c.muObjects.Unlock()
p.objects = msg.GetObjects() c.objects = msg.GetObjects()
zap.S().Debugf("%v object(s) received", len(c.objects))
} }
func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) { func (c *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) {
var msg events.DriveModeMessage var msg events.DriveModeMessage
err := proto.Unmarshal(message.Payload(), &msg) err := proto.Unmarshal(message.Payload(), &msg)
if err != nil { if err != nil {
@ -75,46 +123,89 @@ func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) {
return return
} }
p.muDriveMode.Lock() c.muDriveMode.Lock()
defer p.muDriveMode.Unlock() defer c.muDriveMode.Unlock()
p.driveMode = msg.GetDriveMode() c.driveMode = msg.GetDriveMode()
} }
func (p *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) { func (c *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) {
p.muDriveMode.RLock() c.muDriveMode.RLock()
defer p.muDriveMode.RUnlock() defer c.muDriveMode.RUnlock()
if p.debug {
var evt events.SteeringMessage if c.driveMode != events.DriveMode_USER {
err := proto.Unmarshal(message.Payload(), &evt) return
}
payload := message.Payload()
evt := &events.SteeringMessage{}
err := proto.Unmarshal(payload, evt)
if err != nil { if err != nil {
zap.S().Debugf("unable to unmarshal rc event: %v", err) zap.S().Debugf("unable to unmarshal rc event: %v", err)
} else { } else {
zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering()) zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering())
} }
}
if p.driveMode == events.DriveMode_USER { if c.enableCorrection && c.enableCorrectionOnUser {
// Republish same content payload, err = c.adjustSteering(evt)
payload := message.Payload()
publish(p.client, p.steeringTopic, &payload)
}
}
func (p *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) {
p.muDriveMode.RLock()
defer p.muDriveMode.RUnlock()
if p.debug {
var evt events.SteeringMessage
err := proto.Unmarshal(message.Payload(), &evt)
if err != nil { if err != nil {
zap.S().Debugf("unable to unmarshal tensorflow event: %v", err) zap.S().Errorf("unable to adjust steering, skip message: %v", err)
return
}
}
publish(c.client, c.steeringTopic, &payload)
}
func (c *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) {
c.muDriveMode.RLock()
defer c.muDriveMode.RUnlock()
if c.driveMode != events.DriveMode_PILOT {
// User mode, skip new message
return
}
evt := &events.SteeringMessage{}
err := proto.Unmarshal(message.Payload(), evt)
if err != nil {
zap.S().Errorf("unable to unmarshal tensorflow event: %v", err)
return
} else { } else {
zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering()) zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering())
} }
}
if p.driveMode == events.DriveMode_PILOT {
// Republish same content
payload := message.Payload() payload := message.Payload()
publish(p.client, p.steeringTopic, &payload) if c.enableCorrection {
payload, err = c.adjustSteering(evt)
if err != nil {
zap.S().Errorf("unable to adjust steering, skip message: %v", err)
return
} }
}
publish(c.client, c.steeringTopic, &payload)
}
func (c *Controller) adjustSteering(evt *events.SteeringMessage) ([]byte, error) {
steering := float64(evt.GetSteering())
steering = c.corrector.AdjustFromObjectPosition(steering, c.Objects())
zap.S().Debugf("adjust steering to avoid objects: %v -> %v", evt.GetSteering(), steering)
evt.Steering = float32(steering)
// override payload content
payload, err := proto.Marshal(evt)
if err != nil {
return nil, fmt.Errorf("unable to marshal steering message with new value, skip message: %v", err)
}
return payload, nil
}
func (c *Controller) Objects() []*events.Object {
c.muObjects.RLock()
defer c.muObjects.RUnlock()
res := make([]*events.Object, 0, len(c.objects))
for _, o := range c.objects {
res = append(res, o)
}
zap.S().Debugf("copy object from %v to %v", c.objects, res)
return res
} }
var registerCallbacks = func(p *Controller) error { var registerCallbacks = func(p *Controller) error {

View File

@ -123,3 +123,198 @@ func TestDefaultSteering(t *testing.T) {
} }
} }
} }
type StaticCorrector struct {
delta float64
}
func (s *StaticCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
return s.delta
}
func TestController_Start(t *testing.T) {
oldRegister := registerCallbacks
oldPublish := publish
defer func() {
registerCallbacks = oldRegister
publish = oldPublish
}()
registerCallbacks = func(p *Controller) error {
return nil
}
waitPublish := sync.WaitGroup{}
var muEventsPublished sync.Mutex
eventsPublished := make(map[string][]byte)
publish = func(client mqtt.Client, topic string, payload *[]byte) {
muEventsPublished.Lock()
defer muEventsPublished.Unlock()
eventsPublished[topic] = *payload
waitPublish.Done()
}
steeringTopic := "topic/steering"
driveModeTopic := "topic/driveMode"
rcSteeringTopic := "topic/rcSteering"
tfSteeringTopic := "topic/tfSteering"
objectsTopic := "topic/objects"
type fields struct {
driveMode events.DriveMode
enableCorrection bool
enableCorrectionOnUser bool
}
type msgEvents struct {
driveMode events.DriveModeMessage
rcSteering events.SteeringMessage
tfSteering events.SteeringMessage
expectedSteering events.SteeringMessage
objects events.ObjectsMessage
}
tests := []struct {
name string
fields fields
msgEvents msgEvents
correctionOnObject float64
want events.SteeringMessage
wantErr bool
}{
{
name: "On user drive mode, none correction",
fields: fields{
driveMode: events.DriveMode_USER,
enableCorrection: false,
enableCorrectionOnUser: false,
},
msgEvents: msgEvents{
driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER},
rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
},
correctionOnObject: 0.5,
// Get rc value without correction
want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
},
{
name: "On pilot drive mode, none correction",
fields: fields{
driveMode: events.DriveMode_PILOT,
enableCorrection: false,
enableCorrectionOnUser: false,
},
msgEvents: msgEvents{
driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_PILOT},
rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
},
correctionOnObject: 0.5,
// Get rc value without correction
want: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
},
{
name: "On pilot drive mode, correction enabled",
fields: fields{
driveMode: events.DriveMode_PILOT,
enableCorrection: true,
enableCorrectionOnUser: false,
},
msgEvents: msgEvents{
driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_PILOT},
rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
},
correctionOnObject: 0.5,
// Get rc value without correction
want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0},
},
{
name: "On pilot drive mode, all corrections enabled",
fields: fields{
driveMode: events.DriveMode_PILOT,
enableCorrection: true,
enableCorrectionOnUser: true,
},
msgEvents: msgEvents{
driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_PILOT},
rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
},
correctionOnObject: 0.5,
// Get rc value without correction
want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0},
},
{
name: "On user drive mode, only correction PILOT enabled",
fields: fields{
driveMode: events.DriveMode_PILOT,
enableCorrection: true,
enableCorrectionOnUser: false,
},
msgEvents: msgEvents{
driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER},
rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
},
correctionOnObject: 0.5,
// Get rc value without correction
want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
},
{
name: "On user drive mode, all corrections enabled",
fields: fields{
driveMode: events.DriveMode_USER,
enableCorrection: true,
enableCorrectionOnUser: true,
},
msgEvents: msgEvents{
driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER},
rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
},
correctionOnObject: 0.5,
// Get rc value without correction
want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewController(nil,
steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
WithObjectsCorrectionEnabled(tt.fields.enableCorrection, tt.fields.enableCorrectionOnUser),
WithCorrector(&StaticCorrector{delta: tt.correctionOnObject}),
)
go c.Start()
time.Sleep(1 * time.Millisecond)
// Publish events and wait generation of new steering message
waitPublish.Add(1)
c.onDriveMode(nil, testtools.NewFakeMessageFromProtobuf(driveModeTopic, &tt.msgEvents.driveMode))
c.onRCSteering(nil, testtools.NewFakeMessageFromProtobuf(rcSteeringTopic, &tt.msgEvents.rcSteering))
c.onTFSteering(nil, testtools.NewFakeMessageFromProtobuf(tfSteeringTopic, &tt.msgEvents.tfSteering))
c.onObjects(nil, testtools.NewFakeMessageFromProtobuf(objectsTopic, &tt.msgEvents.objects))
waitPublish.Wait()
var msg events.SteeringMessage
muEventsPublished.Lock()
err := proto.Unmarshal(eventsPublished[steeringTopic], &msg)
if err != nil {
t.Errorf("unable to unmarshall response: %v", err)
t.Fail()
}
muEventsPublished.Unlock()
if msg.GetSteering() != tt.want.GetSteering() {
t.Errorf("bad msg value for mode %v: %v, wants %v", c.driveMode.String(), msg.GetSteering(), tt.want.GetSteering())
}
})
}
}

View File

@ -8,8 +8,90 @@ import (
"os" "os"
) )
type Corrector struct { type Corrector interface {
AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64
}
type OptionCorrector func(c *GridCorrector)
func WithGridMap(configPath string) OptionCorrector {
var gm *GridMap
if configPath == "" {
zap.S().Warnf("no configuration defined for grid map, use default")
gm = &defaultGridMap
} else {
var err error
gm, err = loadConfig(configPath)
if err != nil {
zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err)
}
}
return func(c *GridCorrector) {
c.gridMap = gm
}
}
func WithObjectMoveFactors(configPath string) OptionCorrector {
var omf *GridMap
if configPath == "" {
zap.S().Warnf("no configuration defined for objects move factors, use default")
omf = &defaultObjectFactors
} else {
var err error
omf, err = loadConfig(configPath)
if err != nil {
zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err)
}
}
return func(c *GridCorrector) {
c.objectMoveFactors = omf
}
}
func loadConfig(configPath string) (*GridMap, error) {
content, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("unable to load grid-map config from file '%v': %w", configPath, err)
}
var gm GridMap
err = json.Unmarshal(content, &gm)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal json config '%s': %w", configPath, err)
}
return &gm, nil
}
func WithImageSize(width, height int) OptionCorrector {
return func(c *GridCorrector) {
c.imgWidth = width
c.imgHeight = height
}
}
func WidthDeltaMiddle(d float64) OptionCorrector {
return func(c *GridCorrector) {
c.deltaMiddle = d
}
}
func NewGridCorrector(options ...OptionCorrector) *GridCorrector {
c := &GridCorrector{
gridMap: &defaultGridMap,
objectMoveFactors: &defaultObjectFactors,
deltaMiddle: 0.1,
imgWidth: 160,
imgHeight: 120,
}
for _, o := range options {
o(c)
}
return c
}
type GridCorrector struct {
gridMap *GridMap gridMap *GridMap
objectMoveFactors *GridMap
deltaMiddle float64
imgWidth, imgHeight int
} }
/* /*
@ -39,24 +121,64 @@ AdjustFromObjectPosition modify steering value according object positions
3. If current steering != 0 (turn on left or right), shift right and left values proportionnaly to current steering and 3. If current steering != 0 (turn on left or right), shift right and left values proportionnaly to current steering and
apply 2. apply 2.
*/
func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
// TODO, group rectangle
: -1 -0.66 -0.33 0 0.33 0.66 1
0% |-----|-----|-----|-----|-----|-----|
: | 0 | 0 | 0 | 0 | 0 | 0 |
20% |-----|-----|-----|-----|-----|-----|
: | 0.2 | 0.1 | 0 | 0 |-0.1 |-0.2 |
40% |-----|-----|-----|-----|-----|-----|
: | ... | ... | ... | ... | ... | ... |
*/
func (c *GridCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
zap.S().Debugf("%v objects to avoid", len(objects))
if len(objects) == 0 { if len(objects) == 0 {
return currentSteering return currentSteering
} }
grpObjs := GroupObjects(objects, c.imgWidth, c.imgHeight)
// get nearest object // get nearest object
nearest, err := c.nearObject(objects) nearest, err := c.nearObject(grpObjs)
if err != nil { if err != nil {
zap.S().Warnf("unexpected error on nearest seach object, ignore objects: %v", err) zap.S().Warnf("unexpected error on nearest search object, ignore objects: %v", err)
return currentSteering return currentSteering
} }
if currentSteering > -0.1 && currentSteering < 0.1 { if currentSteering > -1*c.deltaMiddle && currentSteering < c.deltaMiddle {
// Straight
return currentSteering + c.computeDeviation(nearest)
} else {
// Turn to right or left, so search to avoid collision with objects on the right
// Apply factor to object to move it at middle. This factor is function of distance
factor, err := c.objectMoveFactors.ValueOf(float64(nearest.Right), float64(nearest.Bottom))
if err != nil {
zap.S().Warnf("unable to compute factor to apply to object: %v", err)
return currentSteering
}
objMoved := events.Object{
Type: nearest.Type,
Left: nearest.Left + float32(currentSteering*factor),
Top: nearest.Top,
Right: nearest.Right + float32(currentSteering*factor),
Bottom: nearest.Bottom,
Confidence: nearest.Confidence,
}
result := currentSteering + c.computeDeviation(&objMoved)
if result < -1. {
result = -1.
}
if result > 1. {
result = 1.
}
return result
}
}
func (c *GridCorrector) computeDeviation(nearest *events.Object) float64 {
var delta float64 var delta float64
var err error
zap.S().Debugf("search delta value for bottom limit: %v", nearest.Bottom)
if nearest.Left < 0 && nearest.Right < 0 { if nearest.Left < 0 && nearest.Right < 0 {
delta, err = c.gridMap.ValueOf(float64(nearest.Right)*2-1., float64(nearest.Bottom)) delta, err = c.gridMap.ValueOf(float64(nearest.Right)*2-1., float64(nearest.Bottom))
} }
@ -69,15 +191,11 @@ func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []
zap.S().Warnf("unable to compute delta to apply to steering, skip correction: %v", err) zap.S().Warnf("unable to compute delta to apply to steering, skip correction: %v", err)
delta = 0 delta = 0
} }
return currentSteering + delta zap.S().Debugf("new deviation computed: %v", delta)
} return delta
// Search if current steering is near of Right or Left
return currentSteering
} }
func (c *Corrector) nearObject(objects []*events.Object) (*events.Object, error) { func (c *GridCorrector) nearObject(objects []*events.Object) (*events.Object, error) {
if len(objects) == 0 { if len(objects) == 0 {
return nil, fmt.Errorf("list objects must contain at least one object") return nil, fmt.Errorf("list objects must contain at least one object")
} }

View File

@ -39,19 +39,21 @@ var (
Bottom: 0.9, Bottom: 0.9,
Confidence: 0.9, Confidence: 0.9,
} }
) objectOnRightNear = events.Object{
Type: events.TypeObject_ANY,
var ( Left: 0.7,
defaultGridMap = GridMap{ Top: 0.8,
DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.}, Right: 0.9,
SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.}, Bottom: 0.9,
Data: [][]float64{ Confidence: 0.9,
{0., 0., 0., 0., 0., 0.}, }
{0., 0., 0., 0., 0., 0.}, objectOnLeftNear = events.Object{
{0., 0., 0.25, -0.25, 0., 0.}, Type: events.TypeObject_ANY,
{0., 0.25, 0.5, -0.5, -0.25, 0.}, Left: 0.1,
{0.25, 0.5, 1, -1, -0.5, -0.25}, Top: 0.8,
}, Right: 0.3,
Bottom: 0.9,
Confidence: 0.9,
} }
) )
@ -128,7 +130,7 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) {
currentSteering: -0.9, currentSteering: -0.9,
objects: []*events.Object{&objectOnMiddleNear}, objects: []*events.Object{&objectOnMiddleNear},
}, },
want: -0.4, want: -1,
}, },
{ {
name: "run to right with 1 near object", name: "run to right with 1 near object",
@ -136,16 +138,28 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) {
currentSteering: 0.9, currentSteering: 0.9,
objects: []*events.Object{&objectOnMiddleNear}, objects: []*events.Object{&objectOnMiddleNear},
}, },
want: 0.4, want: 1.,
},
{
name: "run to right with 1 near object on the right",
args: args{
currentSteering: 0.9,
objects: []*events.Object{&objectOnRightNear},
},
want: 1.,
},
{
name: "run to left with 1 near object on the left",
args: args{
currentSteering: -0.9,
objects: []*events.Object{&objectOnLeftNear},
},
want: -0.65,
}, },
// Todo Object on left/right near/distant
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &Corrector{ c := NewGridCorrector()
gridMap: &defaultGridMap,
}
if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want { if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want {
t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want) t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want)
} }
@ -190,7 +204,7 @@ func TestCorrector_nearObject(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &Corrector{} c := &GridCorrector{}
got, err := c.nearObject(tt.args.objects) got, err := c.nearObject(tt.args.objects)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr)
@ -399,3 +413,67 @@ func TestGridMap_ValueOf(t *testing.T) {
}) })
} }
} }
func TestWithGridMap(t *testing.T) {
type args struct {
config string
}
tests := []struct {
name string
args args
want GridMap
}{
{
name: "default value",
args: args{config: ""},
want: defaultGridMap,
},
{
name: "load config",
args: args{config: "test_data/config.json"},
want: defaultGridMap,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := GridCorrector{}
got := WithGridMap(tt.args.config)
got(&c)
if !reflect.DeepEqual(*c.gridMap, tt.want) {
t.Errorf("WithGridMap() = %v, want %v", *c.gridMap, tt.want)
}
})
}
}
func TestWithObjectMoveFactors(t *testing.T) {
type args struct {
config string
}
tests := []struct {
name string
args args
want GridMap
}{
{
name: "default value",
args: args{config: ""},
want: defaultObjectFactors,
},
{
name: "load config",
args: args{config: "test_data/omf-config.json"},
want: defaultObjectFactors,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := GridCorrector{}
got := WithObjectMoveFactors(tt.args.config)
got(&c)
if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) {
t.Errorf("WithObjectMoveFactors() = %v, want %v", *c.objectMoveFactors, tt.want)
}
})
}
}

View File

@ -0,0 +1,11 @@
{
"steering_steps":[-1, -0.66, -0.33, 0, 0.33, 0.66, 1],
"distance_steps": [0, 0.2, 0.4, 0.6, 0.8, 1],
"data": [
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0.25, 0, 0, -0.25, 0],
[0.5, 0.25, 0, 0, -0.5, -0.25]
]
}