From d5d4d908a041ab3e2cad71315dba23ec76f490d2 Mon Sep 17 00:00:00 2001 From: Cyrille Nofficial Date: Tue, 23 Aug 2022 22:08:07 +0200 Subject: [PATCH] feat(on-objects): Steering corrector implementation --- cmd/rc-steering/rc-steering.go | 37 ++++- pkg/steering/bbox.go | 41 ++++++ pkg/steering/bbox_test.go | 78 +++++++++- pkg/steering/controller.go | 185 +++++++++++++++++------ pkg/steering/controller_test.go | 195 +++++++++++++++++++++++++ pkg/steering/corrector.go | 172 ++++++++++++++++++---- pkg/steering/corrector_test.go | 120 ++++++++++++--- pkg/steering/test_data/omf-config.json | 11 ++ 8 files changed, 740 insertions(+), 99 deletions(-) create mode 100644 pkg/steering/test_data/omf-config.json diff --git a/cmd/rc-steering/rc-steering.go b/cmd/rc-steering/rc-steering.go index f11d84a..c6b44a8 100644 --- a/cmd/rc-steering/rc-steering.go +++ b/cmd/rc-steering/rc-steering.go @@ -16,6 +16,10 @@ const ( func main() { var mqttBroker, username, password, clientId string var steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string + var imgWidth, imgHeight int + var enableObjectsCorrection, enableObjectsCorrectionOnUserMode bool + var gridMapConfig, objectsMoveFactorsConfig string + var deltaMiddle float64 mqttQos := cli.InitIntFlag("MQTT_QOS", 0) _, mqttRetain := os.LookupEnv("MQTT_RETAIN") @@ -27,7 +31,13 @@ func main() { flag.StringVar(&tfSteeringTopic, "mqtt-topic-tf-steering", os.Getenv("MQTT_TOPIC_TF_STEERING"), "Mqtt topic that contains tenorflow steering value, use MQTT_TOPIC_TF_STEERING if args not set") flag.StringVar(&driveModeTopic, "mqtt-topic-drive-mode", os.Getenv("MQTT_TOPIC_DRIVE_MODE"), "Mqtt topic that contains DriveMode value, use MQTT_TOPIC_DRIVE_MODE if args not set") flag.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains Objects from object detection value, use MQTT_TOPIC_OBJECTS if args not set") - + flag.IntVar(&imgWidth, "image-width", 160, "Video pixels width") + flag.IntVar(&imgHeight, "image-height", 128, "Video pixels height") + flag.BoolVar(&enableObjectsCorrection, "enable-objects-correction", false, "Adjust steering to avoid objects") + flag.BoolVar(&enableObjectsCorrectionOnUserMode, "enable-objects-correction-user", false, "Adjust steering to avoid objects on user mode driving") + flag.StringVar(&gridMapConfig, "grid-map-config", "", "Json file path to configure grid object correction") + flag.StringVar(&objectsMoveFactorsConfig, "objects-move-factors-config", "", "Json file path to configure objects move corrections") + flag.Float64Var(&deltaMiddle, "delta-middle", 0.1, "Half Percent zone to interpret as straight") logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level") flag.Parse() @@ -50,13 +60,36 @@ func main() { }() zap.ReplaceGlobals(lgr) + zap.S().Infof("steering topic : %s", steeringTopic) + zap.S().Infof("rc topic : %s", rcSteeringTopic) + zap.S().Infof("tflite steering topic : %s", tfSteeringTopic) + zap.S().Infof("drive mode topic : %s", driveModeTopic) + zap.S().Infof("objects topic : %s", objectsTopic) + zap.S().Infof("objects correction enabled : %v", enableObjectsCorrection) + zap.S().Infof("objects correction on user mode : %v", enableObjectsCorrectionOnUserMode) + zap.S().Infof("grid map file config : %v", gridMapConfig) + zap.S().Infof("objects move factors grid config: %v", objectsMoveFactorsConfig) + zap.S().Infof("image width x height : %v x %v", imgWidth, imgHeight) + client, err := cli.Connect(mqttBroker, username, password, clientId) if err != nil { log.Fatalf("unable to connect to mqtt bus: %v", err) } defer client.Disconnect(50) - p := steering.NewController(client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic) + p := steering.NewController( + client, + steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic, + steering.WithCorrector( + steering.NewGridCorrector( + steering.WidthDeltaMiddle(deltaMiddle), + steering.WithGridMap(gridMapConfig), + steering.WithObjectMoveFactors(objectsMoveFactorsConfig), + steering.WithImageSize(imgWidth, imgHeight), + ), + ), + steering.WithObjectsCorrectionEnabled(enableObjectsCorrection, enableObjectsCorrectionOnUserMode), + ) defer p.Stop() cli.HandleExit(p) diff --git a/pkg/steering/bbox.go b/pkg/steering/bbox.go index 3a9db76..05dbb34 100644 --- a/pkg/steering/bbox.go +++ b/pkg/steering/bbox.go @@ -1,6 +1,7 @@ package steering import ( + "github.com/cyrilix/robocar-protobuf/go/events" "gocv.io/x/gocv" "image" ) @@ -14,3 +15,43 @@ func GroupBBoxes(bboxes []image.Rectangle) []image.Rectangle { } return gocv.GroupRectangles(bboxes, 1, 0.2) } +func GroupObjects(objects []*events.Object, imgWidth, imgHeight int) []*events.Object { + if len(objects) == 0 { + return []*events.Object{} + } + if len(objects) == 1 { + return []*events.Object{objects[0]} + } + + rectangles := make([]image.Rectangle, 0, len(objects)) + for _, o := range objects { + rectangles = append(rectangles, *objectToRect(o, imgWidth, imgHeight)) + } + grp := gocv.GroupRectangles(rectangles, 1, 0.2) + result := make([]*events.Object, 0, len(grp)) + for _, r := range grp { + result = append(result, rectToObject(&r, imgWidth, imgHeight)) + } + return result +} + +func objectToRect(object *events.Object, imgWidth, imgHeight int) *image.Rectangle { + r := image.Rect( + int(object.Left*float32(imgWidth)), + int(object.Top*float32(imgHeight)), + int(object.Right*float32(imgWidth)), + int(object.Bottom*float32(imgHeight)), + ) + return &r +} + +func rectToObject(r *image.Rectangle, imgWidth, imgHeight int) *events.Object { + return &events.Object{ + Type: events.TypeObject_ANY, + Left: float32(r.Min.X) / float32(imgWidth), + Top: float32(r.Min.Y) / float32(imgHeight), + Right: float32(r.Max.X) / float32(imgWidth), + Bottom: float32(r.Max.Y) / float32(imgHeight), + Confidence: -1, + } +} diff --git a/pkg/steering/bbox_test.go b/pkg/steering/bbox_test.go index b85e106..d760b19 100644 --- a/pkg/steering/bbox_test.go +++ b/pkg/steering/bbox_test.go @@ -3,6 +3,7 @@ package steering import ( "encoding/json" "fmt" + "github.com/cyrilix/robocar-protobuf/go/events" "go.uber.org/zap" "gocv.io/x/gocv" "image" @@ -24,14 +25,16 @@ type BBox struct { } var ( - dataBBoxes map[string][]image.Rectangle - dataImages map[string]*gocv.Mat + dataBBoxes map[string][]image.Rectangle + dataObjects map[string][]*events.Object + dataImages map[string]*gocv.Mat ) func init() { // TODO: empty img without bbox dataNames := []string{"01", "02", "03", "04"} dataBBoxes = make(map[string][]image.Rectangle, len(dataNames)) + dataObjects = make(map[string][]*events.Object, len(dataNames)) dataImages = make(map[string]*gocv.Mat, len(dataNames)) for _, dataName := range dataNames { @@ -40,6 +43,7 @@ func init() { zap.S().Panicf("unable to load data test: %v", err) } dataBBoxes[dataName] = bboxesToRectangles(bb, img.Cols(), img.Rows()) + dataObjects[dataName] = bboxesToObjects(bb) dataImages[dataName] = img } } @@ -52,6 +56,20 @@ func bboxesToRectangles(bboxes []BBox, imgWidth, imgHeiht int) []image.Rectangle return rects } +func bboxesToObjects(bboxes []BBox) []*events.Object { + objects := make([]*events.Object, 0, len(bboxes)) + for _, bb := range bboxes { + objects = append(objects, &events.Object{ + Type: events.TypeObject_ANY, + Left: bb.Left, + Top: bb.Top, + Right: bb.Right, + Bottom: bb.Bottom, + Confidence: bb.Confidence, + }) + } + return objects +} func (bb *BBox) toRect(imgWidth, imgHeight int) image.Rectangle { return image.Rect( int(bb.Left*float32(imgWidth)), @@ -228,3 +246,59 @@ func TestGroupBBoxes(t *testing.T) { }) } } +func TestGroupObjects(t *testing.T) { + type args struct { + dataName string + } + tests := []struct { + name string + args args + want []*events.Object + }{ + { + name: "groupbbox-01", + args: args{ + dataName: "01", + }, + want: []*events.Object{ + {Left: 0.26660156, Top: 0.1706543, Right: 0.5258789, Bottom: 0.47583008, Confidence: 0.4482422}, + }, + }, + { + name: "groupbbox-02", + args: args{ + dataName: "02", + }, + want: []*events.Object{ + {Left: 0.15625, Top: 0.108333334, Right: 0.6875, Bottom: 0.6666667, Confidence: -1}, + }, + }, + { + name: "groupbbox-03", + args: args{ + dataName: "03", + }, + want: []*events.Object{ + {Top: 0.14166667, Right: 0.21875, Bottom: 0.64166665, Confidence: -1}, + }, + }, + { + name: "groupbbox-04", + args: args{ + dataName: "04", + }, + want: []*events.Object{ + {Left: 0.80625, Top: 0.083333336, Right: 0.99375, Bottom: 0.53333336, Confidence: -1}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + img := dataImages[tt.args.dataName] + got := GroupObjects(dataObjects[tt.args.dataName], img.Cols(), img.Rows()) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GroupObjects() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/steering/controller.go b/pkg/steering/controller.go index 66e8acc..8a16c71 100644 --- a/pkg/steering/controller.go +++ b/pkg/steering/controller.go @@ -1,6 +1,7 @@ package steering import ( + "fmt" "github.com/cyrilix/robocar-base/service" "github.com/cyrilix/robocar-protobuf/go/events" mqtt "github.com/eclipse/paho.mqtt.golang" @@ -9,8 +10,48 @@ import ( "sync" ) -func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string) *Controller { - return &Controller{ +var ( + defaultGridMap = GridMap{ + DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.}, + SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.}, + Data: [][]float64{ + {0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0., 0.}, + {0., 0., 0.25, -0.25, 0., 0.}, + {0., 0.25, 0.5, -0.5, -0.25, 0.}, + {0.25, 0.5, 1, -1, -0.5, -0.25}, + }, + } + defaultObjectFactors = GridMap{ + DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.}, + SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.}, + Data: [][]float64{ + {0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0., 0.}, + {0., 0.25, 0, 0, -0.25, 0.}, + {0.5, 0.25, 0, 0, -0.5, -0.25}, + }, + } +) + +type Option func(c *Controller) + +func WithCorrector(c Corrector) Option { + return func(ctrl *Controller) { + ctrl.corrector = c + } +} + +func WithObjectsCorrectionEnabled(enabled, enabledOnUserDrive bool) Option { + return func(ctrl *Controller) { + ctrl.enableCorrection = enabled + ctrl.enableCorrectionOnUser = enabledOnUserDrive + } +} + +func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string, options ...Option) *Controller { + c := &Controller{ client: client, steeringTopic: steeringTopic, driveModeTopic: driveModeTopic, @@ -18,8 +59,12 @@ func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteering tfSteeringTopic: tfSteeringTopic, objectsTopic: objectsTopic, driveMode: events.DriveMode_USER, + corrector: NewGridCorrector(), } - + for _, o := range options { + o(c) + } + return c } type Controller struct { @@ -35,26 +80,28 @@ type Controller struct { muObjects sync.RWMutex objects []*events.Object - debug bool + corrector Corrector + enableCorrection bool + enableCorrectionOnUser bool } -func (p *Controller) Start() error { - if err := registerCallbacks(p); err != nil { - zap.S().Errorf("unable to rgeister callbacks: %v", err) +func (c *Controller) Start() error { + if err := registerCallbacks(c); err != nil { + zap.S().Errorf("unable to register callbacks: %v", err) return err } - p.cancel = make(chan interface{}) - <-p.cancel + c.cancel = make(chan interface{}) + <-c.cancel return nil } -func (p *Controller) Stop() { - close(p.cancel) - service.StopService("throttle", p.client, p.driveModeTopic, p.rcSteeringTopic, p.tfSteeringTopic) +func (c *Controller) Stop() { + close(c.cancel) + service.StopService("throttle", c.client, c.driveModeTopic, c.rcSteeringTopic, c.tfSteeringTopic) } -func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) { +func (c *Controller) onObjects(_ mqtt.Client, message mqtt.Message) { var msg events.ObjectsMessage err := proto.Unmarshal(message.Payload(), &msg) if err != nil { @@ -62,12 +109,13 @@ func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) { return } - p.muObjects.Lock() - defer p.muObjects.Unlock() - p.objects = msg.GetObjects() + c.muObjects.Lock() + defer c.muObjects.Unlock() + c.objects = msg.GetObjects() + zap.S().Debugf("%v object(s) received", len(c.objects)) } -func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) { +func (c *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) { var msg events.DriveModeMessage err := proto.Unmarshal(message.Payload(), &msg) if err != nil { @@ -75,46 +123,89 @@ func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) { return } - p.muDriveMode.Lock() - defer p.muDriveMode.Unlock() - p.driveMode = msg.GetDriveMode() + c.muDriveMode.Lock() + defer c.muDriveMode.Unlock() + c.driveMode = msg.GetDriveMode() } -func (p *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) { - p.muDriveMode.RLock() - defer p.muDriveMode.RUnlock() - if p.debug { - var evt events.SteeringMessage - err := proto.Unmarshal(message.Payload(), &evt) +func (c *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) { + c.muDriveMode.RLock() + defer c.muDriveMode.RUnlock() + + if c.driveMode != events.DriveMode_USER { + return + } + + payload := message.Payload() + evt := &events.SteeringMessage{} + err := proto.Unmarshal(payload, evt) + if err != nil { + zap.S().Debugf("unable to unmarshal rc event: %v", err) + } else { + zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering()) + } + + if c.enableCorrection && c.enableCorrectionOnUser { + payload, err = c.adjustSteering(evt) if err != nil { - zap.S().Debugf("unable to unmarshal rc event: %v", err) - } else { - zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering()) + zap.S().Errorf("unable to adjust steering, skip message: %v", err) + return } } - if p.driveMode == events.DriveMode_USER { - // Republish same content - payload := message.Payload() - publish(p.client, p.steeringTopic, &payload) - } + publish(c.client, c.steeringTopic, &payload) } -func (p *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) { - p.muDriveMode.RLock() - defer p.muDriveMode.RUnlock() - if p.debug { - var evt events.SteeringMessage - err := proto.Unmarshal(message.Payload(), &evt) + +func (c *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) { + c.muDriveMode.RLock() + defer c.muDriveMode.RUnlock() + if c.driveMode != events.DriveMode_PILOT { + // User mode, skip new message + return + } + + evt := &events.SteeringMessage{} + err := proto.Unmarshal(message.Payload(), evt) + if err != nil { + zap.S().Errorf("unable to unmarshal tensorflow event: %v", err) + return + } else { + zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering()) + } + + payload := message.Payload() + if c.enableCorrection { + payload, err = c.adjustSteering(evt) if err != nil { - zap.S().Debugf("unable to unmarshal tensorflow event: %v", err) - } else { - zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering()) + zap.S().Errorf("unable to adjust steering, skip message: %v", err) + return } } - if p.driveMode == events.DriveMode_PILOT { - // Republish same content - payload := message.Payload() - publish(p.client, p.steeringTopic, &payload) + + publish(c.client, c.steeringTopic, &payload) +} + +func (c *Controller) adjustSteering(evt *events.SteeringMessage) ([]byte, error) { + steering := float64(evt.GetSteering()) + steering = c.corrector.AdjustFromObjectPosition(steering, c.Objects()) + zap.S().Debugf("adjust steering to avoid objects: %v -> %v", evt.GetSteering(), steering) + evt.Steering = float32(steering) + // override payload content + payload, err := proto.Marshal(evt) + if err != nil { + return nil, fmt.Errorf("unable to marshal steering message with new value, skip message: %v", err) } + return payload, nil +} + +func (c *Controller) Objects() []*events.Object { + c.muObjects.RLock() + defer c.muObjects.RUnlock() + res := make([]*events.Object, 0, len(c.objects)) + for _, o := range c.objects { + res = append(res, o) + } + zap.S().Debugf("copy object from %v to %v", c.objects, res) + return res } var registerCallbacks = func(p *Controller) error { diff --git a/pkg/steering/controller_test.go b/pkg/steering/controller_test.go index a7c4aec..c61f78b 100644 --- a/pkg/steering/controller_test.go +++ b/pkg/steering/controller_test.go @@ -123,3 +123,198 @@ func TestDefaultSteering(t *testing.T) { } } } + +type StaticCorrector struct { + delta float64 +} + +func (s *StaticCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 { + return s.delta +} + +func TestController_Start(t *testing.T) { + oldRegister := registerCallbacks + oldPublish := publish + defer func() { + registerCallbacks = oldRegister + publish = oldPublish + }() + registerCallbacks = func(p *Controller) error { + return nil + } + + waitPublish := sync.WaitGroup{} + var muEventsPublished sync.Mutex + eventsPublished := make(map[string][]byte) + publish = func(client mqtt.Client, topic string, payload *[]byte) { + muEventsPublished.Lock() + defer muEventsPublished.Unlock() + eventsPublished[topic] = *payload + waitPublish.Done() + } + + steeringTopic := "topic/steering" + driveModeTopic := "topic/driveMode" + rcSteeringTopic := "topic/rcSteering" + tfSteeringTopic := "topic/tfSteering" + objectsTopic := "topic/objects" + + type fields struct { + driveMode events.DriveMode + enableCorrection bool + enableCorrectionOnUser bool + } + type msgEvents struct { + driveMode events.DriveModeMessage + rcSteering events.SteeringMessage + tfSteering events.SteeringMessage + expectedSteering events.SteeringMessage + objects events.ObjectsMessage + } + + tests := []struct { + name string + fields fields + msgEvents msgEvents + correctionOnObject float64 + want events.SteeringMessage + wantErr bool + }{ + { + name: "On user drive mode, none correction", + fields: fields{ + driveMode: events.DriveMode_USER, + enableCorrection: false, + enableCorrectionOnUser: false, + }, + msgEvents: msgEvents{ + driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER}, + rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0}, + objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}}, + }, + correctionOnObject: 0.5, + // Get rc value without correction + want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + }, + { + name: "On pilot drive mode, none correction", + fields: fields{ + driveMode: events.DriveMode_PILOT, + enableCorrection: false, + enableCorrectionOnUser: false, + }, + msgEvents: msgEvents{ + driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_PILOT}, + rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0}, + objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}}, + }, + correctionOnObject: 0.5, + // Get rc value without correction + want: events.SteeringMessage{Steering: 0.4, Confidence: 1.0}, + }, + { + name: "On pilot drive mode, correction enabled", + fields: fields{ + driveMode: events.DriveMode_PILOT, + enableCorrection: true, + enableCorrectionOnUser: false, + }, + msgEvents: msgEvents{ + driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_PILOT}, + rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0}, + objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}}, + }, + correctionOnObject: 0.5, + // Get rc value without correction + want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0}, + }, + { + name: "On pilot drive mode, all corrections enabled", + fields: fields{ + driveMode: events.DriveMode_PILOT, + enableCorrection: true, + enableCorrectionOnUser: true, + }, + msgEvents: msgEvents{ + driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_PILOT}, + rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0}, + objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}}, + }, + correctionOnObject: 0.5, + // Get rc value without correction + want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0}, + }, + { + name: "On user drive mode, only correction PILOT enabled", + fields: fields{ + driveMode: events.DriveMode_PILOT, + enableCorrection: true, + enableCorrectionOnUser: false, + }, + msgEvents: msgEvents{ + driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER}, + rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0}, + objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}}, + }, + correctionOnObject: 0.5, + // Get rc value without correction + want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + }, + { + name: "On user drive mode, all corrections enabled", + fields: fields{ + driveMode: events.DriveMode_USER, + enableCorrection: true, + enableCorrectionOnUser: true, + }, + msgEvents: msgEvents{ + driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER}, + rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0}, + objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}}, + }, + correctionOnObject: 0.5, + // Get rc value without correction + want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewController(nil, + steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic, + WithObjectsCorrectionEnabled(tt.fields.enableCorrection, tt.fields.enableCorrectionOnUser), + WithCorrector(&StaticCorrector{delta: tt.correctionOnObject}), + ) + go c.Start() + time.Sleep(1 * time.Millisecond) + + // Publish events and wait generation of new steering message + waitPublish.Add(1) + c.onDriveMode(nil, testtools.NewFakeMessageFromProtobuf(driveModeTopic, &tt.msgEvents.driveMode)) + c.onRCSteering(nil, testtools.NewFakeMessageFromProtobuf(rcSteeringTopic, &tt.msgEvents.rcSteering)) + c.onTFSteering(nil, testtools.NewFakeMessageFromProtobuf(tfSteeringTopic, &tt.msgEvents.tfSteering)) + c.onObjects(nil, testtools.NewFakeMessageFromProtobuf(objectsTopic, &tt.msgEvents.objects)) + waitPublish.Wait() + + var msg events.SteeringMessage + muEventsPublished.Lock() + err := proto.Unmarshal(eventsPublished[steeringTopic], &msg) + if err != nil { + t.Errorf("unable to unmarshall response: %v", err) + t.Fail() + } + muEventsPublished.Unlock() + + if msg.GetSteering() != tt.want.GetSteering() { + t.Errorf("bad msg value for mode %v: %v, wants %v", c.driveMode.String(), msg.GetSteering(), tt.want.GetSteering()) + } + + }) + } +} diff --git a/pkg/steering/corrector.go b/pkg/steering/corrector.go index f3362d1..a3508c2 100644 --- a/pkg/steering/corrector.go +++ b/pkg/steering/corrector.go @@ -8,8 +8,90 @@ import ( "os" ) -type Corrector struct { - gridMap *GridMap +type Corrector interface { + AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 +} +type OptionCorrector func(c *GridCorrector) + +func WithGridMap(configPath string) OptionCorrector { + var gm *GridMap + if configPath == "" { + zap.S().Warnf("no configuration defined for grid map, use default") + gm = &defaultGridMap + } else { + var err error + gm, err = loadConfig(configPath) + if err != nil { + zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err) + } + } + return func(c *GridCorrector) { + c.gridMap = gm + } +} + +func WithObjectMoveFactors(configPath string) OptionCorrector { + var omf *GridMap + if configPath == "" { + zap.S().Warnf("no configuration defined for objects move factors, use default") + omf = &defaultObjectFactors + } else { + var err error + omf, err = loadConfig(configPath) + if err != nil { + zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err) + } + } + return func(c *GridCorrector) { + c.objectMoveFactors = omf + } +} + +func loadConfig(configPath string) (*GridMap, error) { + content, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("unable to load grid-map config from file '%v': %w", configPath, err) + } + var gm GridMap + err = json.Unmarshal(content, &gm) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal json config '%s': %w", configPath, err) + } + return &gm, nil +} + +func WithImageSize(width, height int) OptionCorrector { + return func(c *GridCorrector) { + c.imgWidth = width + c.imgHeight = height + } +} + +func WidthDeltaMiddle(d float64) OptionCorrector { + return func(c *GridCorrector) { + c.deltaMiddle = d + } + +} +func NewGridCorrector(options ...OptionCorrector) *GridCorrector { + c := &GridCorrector{ + gridMap: &defaultGridMap, + objectMoveFactors: &defaultObjectFactors, + deltaMiddle: 0.1, + imgWidth: 160, + imgHeight: 120, + } + for _, o := range options { + o(c) + } + return c +} + +type GridCorrector struct { + gridMap *GridMap + objectMoveFactors *GridMap + deltaMiddle float64 + imgWidth, imgHeight int } /* @@ -39,45 +121,81 @@ AdjustFromObjectPosition modify steering value according object positions 3. If current steering != 0 (turn on left or right), shift right and left values proportionnaly to current steering and apply 2. -*/ -func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 { - // TODO, group rectangle + : -1 -0.66 -0.33 0 0.33 0.66 1 + 0% |-----|-----|-----|-----|-----|-----| + : | 0 | 0 | 0 | 0 | 0 | 0 | + 20% |-----|-----|-----|-----|-----|-----| + : | 0.2 | 0.1 | 0 | 0 |-0.1 |-0.2 | + 40% |-----|-----|-----|-----|-----|-----| + : | ... | ... | ... | ... | ... | ... | +*/ +func (c *GridCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 { + zap.S().Debugf("%v objects to avoid", len(objects)) if len(objects) == 0 { return currentSteering } + grpObjs := GroupObjects(objects, c.imgWidth, c.imgHeight) + // get nearest object - nearest, err := c.nearObject(objects) + nearest, err := c.nearObject(grpObjs) if err != nil { - zap.S().Warnf("unexpected error on nearest seach object, ignore objects: %v", err) + zap.S().Warnf("unexpected error on nearest search object, ignore objects: %v", err) return currentSteering } - if currentSteering > -0.1 && currentSteering < 0.1 { - - var delta float64 - - if nearest.Left < 0 && nearest.Right < 0 { - delta, err = c.gridMap.ValueOf(float64(nearest.Right)*2-1., float64(nearest.Bottom)) - } - if nearest.Left > 0 && nearest.Right > 0 { - delta, err = c.gridMap.ValueOf(float64(nearest.Left)*2-1., float64(nearest.Bottom)) - } else { - delta, err = c.gridMap.ValueOf(float64(float64(nearest.Left)+(float64(nearest.Right)-float64(nearest.Left))/2.)*2.-1., float64(nearest.Bottom)) - } + if currentSteering > -1*c.deltaMiddle && currentSteering < c.deltaMiddle { + // Straight + return currentSteering + c.computeDeviation(nearest) + } else { + // Turn to right or left, so search to avoid collision with objects on the right + // Apply factor to object to move it at middle. This factor is function of distance + factor, err := c.objectMoveFactors.ValueOf(float64(nearest.Right), float64(nearest.Bottom)) if err != nil { - zap.S().Warnf("unable to compute delta to apply to steering, skip correction: %v", err) - delta = 0 + zap.S().Warnf("unable to compute factor to apply to object: %v", err) + return currentSteering } - return currentSteering + delta + objMoved := events.Object{ + Type: nearest.Type, + Left: nearest.Left + float32(currentSteering*factor), + Top: nearest.Top, + Right: nearest.Right + float32(currentSteering*factor), + Bottom: nearest.Bottom, + Confidence: nearest.Confidence, + } + result := currentSteering + c.computeDeviation(&objMoved) + if result < -1. { + result = -1. + } + if result > 1. { + result = 1. + } + return result } - - // Search if current steering is near of Right or Left - - return currentSteering } -func (c *Corrector) nearObject(objects []*events.Object) (*events.Object, error) { +func (c *GridCorrector) computeDeviation(nearest *events.Object) float64 { + var delta float64 + var err error + + zap.S().Debugf("search delta value for bottom limit: %v", nearest.Bottom) + if nearest.Left < 0 && nearest.Right < 0 { + delta, err = c.gridMap.ValueOf(float64(nearest.Right)*2-1., float64(nearest.Bottom)) + } + if nearest.Left > 0 && nearest.Right > 0 { + delta, err = c.gridMap.ValueOf(float64(nearest.Left)*2-1., float64(nearest.Bottom)) + } else { + delta, err = c.gridMap.ValueOf(float64(float64(nearest.Left)+(float64(nearest.Right)-float64(nearest.Left))/2.)*2.-1., float64(nearest.Bottom)) + } + if err != nil { + zap.S().Warnf("unable to compute delta to apply to steering, skip correction: %v", err) + delta = 0 + } + zap.S().Debugf("new deviation computed: %v", delta) + return delta +} + +func (c *GridCorrector) nearObject(objects []*events.Object) (*events.Object, error) { if len(objects) == 0 { return nil, fmt.Errorf("list objects must contain at least one object") } diff --git a/pkg/steering/corrector_test.go b/pkg/steering/corrector_test.go index cde7969..a3cb522 100644 --- a/pkg/steering/corrector_test.go +++ b/pkg/steering/corrector_test.go @@ -39,19 +39,21 @@ var ( Bottom: 0.9, Confidence: 0.9, } -) - -var ( - defaultGridMap = GridMap{ - DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.}, - SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.}, - Data: [][]float64{ - {0., 0., 0., 0., 0., 0.}, - {0., 0., 0., 0., 0., 0.}, - {0., 0., 0.25, -0.25, 0., 0.}, - {0., 0.25, 0.5, -0.5, -0.25, 0.}, - {0.25, 0.5, 1, -1, -0.5, -0.25}, - }, + objectOnRightNear = events.Object{ + Type: events.TypeObject_ANY, + Left: 0.7, + Top: 0.8, + Right: 0.9, + Bottom: 0.9, + Confidence: 0.9, + } + objectOnLeftNear = events.Object{ + Type: events.TypeObject_ANY, + Left: 0.1, + Top: 0.8, + Right: 0.3, + Bottom: 0.9, + Confidence: 0.9, } ) @@ -128,7 +130,7 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) { currentSteering: -0.9, objects: []*events.Object{&objectOnMiddleNear}, }, - want: -0.4, + want: -1, }, { name: "run to right with 1 near object", @@ -136,16 +138,28 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) { currentSteering: 0.9, objects: []*events.Object{&objectOnMiddleNear}, }, - want: 0.4, + want: 1., + }, + { + name: "run to right with 1 near object on the right", + args: args{ + currentSteering: 0.9, + objects: []*events.Object{&objectOnRightNear}, + }, + want: 1., + }, + { + name: "run to left with 1 near object on the left", + args: args{ + currentSteering: -0.9, + objects: []*events.Object{&objectOnLeftNear}, + }, + want: -0.65, }, - - // Todo Object on left/right near/distant } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &Corrector{ - gridMap: &defaultGridMap, - } + c := NewGridCorrector() if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want { t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want) } @@ -190,7 +204,7 @@ func TestCorrector_nearObject(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &Corrector{} + c := &GridCorrector{} got, err := c.nearObject(tt.args.objects) if (err != nil) != tt.wantErr { t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr) @@ -399,3 +413,67 @@ func TestGridMap_ValueOf(t *testing.T) { }) } } + +func TestWithGridMap(t *testing.T) { + type args struct { + config string + } + tests := []struct { + name string + args args + want GridMap + }{ + { + name: "default value", + args: args{config: ""}, + want: defaultGridMap, + }, + { + name: "load config", + args: args{config: "test_data/config.json"}, + want: defaultGridMap, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := GridCorrector{} + got := WithGridMap(tt.args.config) + got(&c) + if !reflect.DeepEqual(*c.gridMap, tt.want) { + t.Errorf("WithGridMap() = %v, want %v", *c.gridMap, tt.want) + } + }) + } +} + +func TestWithObjectMoveFactors(t *testing.T) { + type args struct { + config string + } + tests := []struct { + name string + args args + want GridMap + }{ + { + name: "default value", + args: args{config: ""}, + want: defaultObjectFactors, + }, + { + name: "load config", + args: args{config: "test_data/omf-config.json"}, + want: defaultObjectFactors, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := GridCorrector{} + got := WithObjectMoveFactors(tt.args.config) + got(&c) + if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) { + t.Errorf("WithObjectMoveFactors() = %v, want %v", *c.objectMoveFactors, tt.want) + } + }) + } +} diff --git a/pkg/steering/test_data/omf-config.json b/pkg/steering/test_data/omf-config.json new file mode 100644 index 0000000..1e8e671 --- /dev/null +++ b/pkg/steering/test_data/omf-config.json @@ -0,0 +1,11 @@ +{ + "steering_steps":[-1, -0.66, -0.33, 0, 0.33, 0.66, 1], + "distance_steps": [0, 0.2, 0.4, 0.6, 0.8, 1], + "data": [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0.25, 0, 0, -0.25, 0], + [0.5, 0.25, 0, 0, -0.5, -0.25] + ] +} \ No newline at end of file