diff --git a/cmd/rc-steering/rc-steering.go b/cmd/rc-steering/rc-steering.go index f11d84a..97d0e4b 100644 --- a/cmd/rc-steering/rc-steering.go +++ b/cmd/rc-steering/rc-steering.go @@ -16,6 +16,10 @@ const ( func main() { var mqttBroker, username, password, clientId string var steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string + var imgWidth, imgHeight int + var enableObjectsCorrection, enableObjectsCorrectionOnUserMode bool + var gridMapConfig, objectsMoveFactorsConfig string + var deltaMiddle float64 mqttQos := cli.InitIntFlag("MQTT_QOS", 0) _, mqttRetain := os.LookupEnv("MQTT_RETAIN") @@ -27,7 +31,13 @@ func main() { flag.StringVar(&tfSteeringTopic, "mqtt-topic-tf-steering", os.Getenv("MQTT_TOPIC_TF_STEERING"), "Mqtt topic that contains tenorflow steering value, use MQTT_TOPIC_TF_STEERING if args not set") flag.StringVar(&driveModeTopic, "mqtt-topic-drive-mode", os.Getenv("MQTT_TOPIC_DRIVE_MODE"), "Mqtt topic that contains DriveMode value, use MQTT_TOPIC_DRIVE_MODE if args not set") flag.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains Objects from object detection value, use MQTT_TOPIC_OBJECTS if args not set") - + flag.IntVar(&imgWidth, "image-width", 160, "Video pixels width") + flag.IntVar(&imgHeight, "image-height", 128, "Video pixels height") + flag.BoolVar(&enableObjectsCorrection, "enable-objects-correction", false, "Adjust steering to avoid objects") + flag.BoolVar(&enableObjectsCorrectionOnUserMode, "enable-objects-correction-user", false, "Adjust steering to avoid objects on user mode driving") + flag.StringVar(&gridMapConfig, "grid-map-config", "", "Json file path to configure grid object correction") + flag.StringVar(&objectsMoveFactorsConfig, "objects-move-factors-config", "", "Json file path to configure objects move corrections") + flag.Float64Var(&deltaMiddle, "delta-middle", 0.1, "Half Percent zone to interpret as straight") logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level") flag.Parse() @@ -50,13 +60,36 @@ func main() { }() zap.ReplaceGlobals(lgr) + zap.S().Infof("steering topic : %s", steeringTopic) + zap.S().Infof("rc topic : %s", rcSteeringTopic) + zap.S().Infof("tflite steering topic : %s", tfSteeringTopic) + zap.S().Infof("drive mode topic : %s", driveModeTopic) + zap.S().Infof("objects topic : %s", objectsTopic) + zap.S().Infof("objects correction enabled : %v", enableObjectsCorrection) + zap.S().Infof("objects correction on user mode : %v", enableObjectsCorrectionOnUserMode) + zap.S().Infof("grid map file config : %v", gridMapConfig) + zap.S().Infof("objects move factors grid config: %v", objectsMoveFactorsConfig) + zap.S().Infof("image width x height : %v x %v", imgWidth, imgHeight) + client, err := cli.Connect(mqttBroker, username, password, clientId) if err != nil { log.Fatalf("unable to connect to mqtt bus: %v", err) } defer client.Disconnect(50) - p := steering.NewController(client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic) + p := steering.NewController( + client, + steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic, + steering.WithCorrector( + steering.NewCorrector( + steering.WidthDeltaMiddle(deltaMiddle), + steering.WithGridMap(gridMapConfig), + steering.WithObjectMoveFactors(objectsMoveFactorsConfig), + steering.WithImageSize(imgWidth, imgHeight), + ), + ), + steering.WithObjectsCorrectionEnabled(enableObjectsCorrection, enableObjectsCorrectionOnUserMode), + ) defer p.Stop() cli.HandleExit(p) diff --git a/pkg/steering/controller.go b/pkg/steering/controller.go index 66e8acc..2bfe15e 100644 --- a/pkg/steering/controller.go +++ b/pkg/steering/controller.go @@ -1,6 +1,7 @@ package steering import ( + "fmt" "github.com/cyrilix/robocar-base/service" "github.com/cyrilix/robocar-protobuf/go/events" mqtt "github.com/eclipse/paho.mqtt.golang" @@ -9,8 +10,48 @@ import ( "sync" ) -func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string) *Controller { - return &Controller{ +var ( + defaultGridMap = GridMap{ + DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.}, + SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.}, + Data: [][]float64{ + {0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0., 0.}, + {0., 0., 0.25, -0.25, 0., 0.}, + {0., 0.25, 0.5, -0.5, -0.25, 0.}, + {0.25, 0.5, 1, -1, -0.5, -0.25}, + }, + } + defaultObjectFactors = GridMap{ + DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.}, + SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.}, + Data: [][]float64{ + {0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0., 0.}, + {0., 0.25, 0, 0, -0.25, 0.}, + {0.5, 0.25, 0, 0, -0.5, -0.25}, + }, + } +) + +type Option func(c *Controller) + +func WithCorrector(c *Corrector) Option { + return func(ctrl *Controller) { + ctrl.corrector = c + } +} + +func WithObjectsCorrectionEnabled(enabled, enabledOnUserDrive bool) Option { + return func(ctrl *Controller) { + ctrl.enableCorrection = enabled + ctrl.enableCorrectionOnUser = enabledOnUserDrive + } +} + +func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string, options ...Option) *Controller { + c := &Controller{ client: client, steeringTopic: steeringTopic, driveModeTopic: driveModeTopic, @@ -18,8 +59,12 @@ func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteering tfSteeringTopic: tfSteeringTopic, objectsTopic: objectsTopic, driveMode: events.DriveMode_USER, + corrector: NewCorrector(), } - + for _, o := range options { + o(c) + } + return c } type Controller struct { @@ -35,26 +80,28 @@ type Controller struct { muObjects sync.RWMutex objects []*events.Object - debug bool + corrector *Corrector + enableCorrection bool + enableCorrectionOnUser bool } -func (p *Controller) Start() error { - if err := registerCallbacks(p); err != nil { - zap.S().Errorf("unable to rgeister callbacks: %v", err) +func (c *Controller) Start() error { + if err := registerCallbacks(c); err != nil { + zap.S().Errorf("unable to register callbacks: %v", err) return err } - p.cancel = make(chan interface{}) - <-p.cancel + c.cancel = make(chan interface{}) + <-c.cancel return nil } -func (p *Controller) Stop() { - close(p.cancel) - service.StopService("throttle", p.client, p.driveModeTopic, p.rcSteeringTopic, p.tfSteeringTopic) +func (c *Controller) Stop() { + close(c.cancel) + service.StopService("throttle", c.client, c.driveModeTopic, c.rcSteeringTopic, c.tfSteeringTopic) } -func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) { +func (c *Controller) onObjects(_ mqtt.Client, message mqtt.Message) { var msg events.ObjectsMessage err := proto.Unmarshal(message.Payload(), &msg) if err != nil { @@ -62,12 +109,12 @@ func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) { return } - p.muObjects.Lock() - defer p.muObjects.Unlock() - p.objects = msg.GetObjects() + c.muObjects.Lock() + defer c.muObjects.Unlock() + c.objects = msg.GetObjects() } -func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) { +func (c *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) { var msg events.DriveModeMessage err := proto.Unmarshal(message.Payload(), &msg) if err != nil { @@ -75,46 +122,85 @@ func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) { return } - p.muDriveMode.Lock() - defer p.muDriveMode.Unlock() - p.driveMode = msg.GetDriveMode() + c.muDriveMode.Lock() + defer c.muDriveMode.Unlock() + c.driveMode = msg.GetDriveMode() } -func (p *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) { - p.muDriveMode.RLock() - defer p.muDriveMode.RUnlock() - if p.debug { - var evt events.SteeringMessage - err := proto.Unmarshal(message.Payload(), &evt) +func (c *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) { + c.muDriveMode.RLock() + defer c.muDriveMode.RUnlock() + + if c.driveMode != events.DriveMode_USER { + return + } + + payload := message.Payload() + evt := &events.SteeringMessage{} + err := proto.Unmarshal(payload, evt) + if err != nil { + zap.S().Debugf("unable to unmarshal rc event: %v", err) + } else { + zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering()) + } + + if c.enableCorrection && c.enableCorrectionOnUser { + payload, err = c.adjustSteering(evt) if err != nil { - zap.S().Debugf("unable to unmarshal rc event: %v", err) - } else { - zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering()) + zap.S().Errorf("unable to adjust steering, skip message: %v", err) + return } } - if p.driveMode == events.DriveMode_USER { - // Republish same content - payload := message.Payload() - publish(p.client, p.steeringTopic, &payload) - } + publish(c.client, c.steeringTopic, &payload) } -func (p *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) { - p.muDriveMode.RLock() - defer p.muDriveMode.RUnlock() - if p.debug { - var evt events.SteeringMessage - err := proto.Unmarshal(message.Payload(), &evt) + +func (c *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) { + c.muDriveMode.RLock() + defer c.muDriveMode.RUnlock() + if c.driveMode != events.DriveMode_PILOT { + // User mode, skip new message + return + } + + evt := &events.SteeringMessage{} + err := proto.Unmarshal(message.Payload(), evt) + if err != nil { + zap.S().Errorf("unable to unmarshal tensorflow event: %v", err) + return + } else { + zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering()) + } + + payload := message.Payload() + if c.enableCorrection { + payload, err = c.adjustSteering(evt) if err != nil { - zap.S().Debugf("unable to unmarshal tensorflow event: %v", err) - } else { - zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering()) + zap.S().Errorf("unable to adjust steering, skip message: %v", err) + return } } - if p.driveMode == events.DriveMode_PILOT { - // Republish same content - payload := message.Payload() - publish(p.client, p.steeringTopic, &payload) + + publish(c.client, c.steeringTopic, &payload) +} + +func (c *Controller) adjustSteering(evt *events.SteeringMessage) ([]byte, error) { + steering := float64(evt.GetSteering()) + steering = c.corrector.AdjustFromObjectPosition(steering, c.Objects()) + evt.Steering = float32(steering) + // override payload content + payload, err := proto.Marshal(evt) + if err != nil { + return nil, fmt.Errorf("unable to marshal steering message with new value, skip message: %v", err) } + return payload, nil +} + +func (c *Controller) Objects() []*events.Object { + c.muObjects.RLock() + defer c.muObjects.RUnlock() + res := make([]*events.Object, 0, len(c.objects)) + copy(res, c.objects) + return res } var registerCallbacks = func(p *Controller) error { diff --git a/pkg/steering/corrector.go b/pkg/steering/corrector.go index fead644..36aff05 100644 --- a/pkg/steering/corrector.go +++ b/pkg/steering/corrector.go @@ -8,14 +8,80 @@ import ( "os" ) -func NewCorrector(gridMap *GridMap, objectMoveFactors *GridMap) *Corrector { - return &Corrector{ - gridMap: gridMap, - objectMoveFactors: objectMoveFactors, +type OptionCorrector func(c *Corrector) + +func WithGridMap(configPath string) OptionCorrector { + var gm *GridMap + if configPath == "" { + zap.S().Warnf("no configuration defined for grid map, use default") + gm = &defaultGridMap + } else { + var err error + gm, err = loadConfig(configPath) + if err != nil { + zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err) + } + } + return func(c *Corrector) { + c.gridMap = gm + } +} + +func WithObjectMoveFactors(configPath string) OptionCorrector { + var omf *GridMap + if configPath == "" { + zap.S().Warnf("no configuration defined for objects move factors, use default") + omf = &defaultObjectFactors + } else { + var err error + omf, err = loadConfig(configPath) + if err != nil { + zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err) + } + } + return func(c *Corrector) { + c.objectMoveFactors = omf + } +} + +func loadConfig(configPath string) (*GridMap, error) { + content, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("unable to load grid-map config from file '%v': %w", configPath, err) + } + var gm GridMap + err = json.Unmarshal(content, &gm) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal json config '%s': %w", configPath, err) + } + return &gm, nil +} + +func WithImageSize(width, height int) OptionCorrector { + return func(c *Corrector) { + c.imgWidth = width + c.imgHeight = height + } +} + +func WidthDeltaMiddle(d float64) OptionCorrector { + return func(c *Corrector) { + c.deltaMiddle = d + } + +} +func NewCorrector(options ...OptionCorrector) *Corrector { + c := &Corrector{ + gridMap: &defaultGridMap, + objectMoveFactors: &defaultObjectFactors, deltaMiddle: 0.1, imgWidth: 160, imgHeight: 120, } + for _, o := range options { + o(c) + } + return c } type Corrector struct { diff --git a/pkg/steering/corrector_test.go b/pkg/steering/corrector_test.go index 19a9b8e..f571a8b 100644 --- a/pkg/steering/corrector_test.go +++ b/pkg/steering/corrector_test.go @@ -57,31 +57,6 @@ var ( } ) -var ( - defaultGridMap = GridMap{ - DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.}, - SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.}, - Data: [][]float64{ - {0., 0., 0., 0., 0., 0.}, - {0., 0., 0., 0., 0., 0.}, - {0., 0., 0.25, -0.25, 0., 0.}, - {0., 0.25, 0.5, -0.5, -0.25, 0.}, - {0.25, 0.5, 1, -1, -0.5, -0.25}, - }, - } - defaultObjectFactors = GridMap{ - DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.}, - SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.}, - Data: [][]float64{ - {0., 0., 0., 0., 0., 0.}, - {0., 0., 0., 0., 0., 0.}, - {0., 0., 0., 0., 0., 0.}, - {0., 0.25, 0, 0, -0.25, 0.}, - {0.5, 0.25, 0, 0, -0.5, -0.25}, - }, - } -) - func TestCorrector_AdjustFromObjectPosition(t *testing.T) { type args struct { currentSteering float64 @@ -184,7 +159,7 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := NewCorrector(&defaultGridMap, &defaultObjectFactors) + c := NewCorrector() if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want { t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want) } @@ -438,3 +413,67 @@ func TestGridMap_ValueOf(t *testing.T) { }) } } + +func TestWithGridMap(t *testing.T) { + type args struct { + config string + } + tests := []struct { + name string + args args + want GridMap + }{ + { + name: "default value", + args: args{config: ""}, + want: defaultGridMap, + }, + { + name: "load config", + args: args{config: "test_data/config.json"}, + want: defaultGridMap, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Corrector{} + got := WithGridMap(tt.args.config) + got(&c) + if !reflect.DeepEqual(*c.gridMap, tt.want) { + t.Errorf("WithGridMap() = %v, want %v", *c.gridMap, tt.want) + } + }) + } +} + +func TestWithObjectMoveFactors(t *testing.T) { + type args struct { + config string + } + tests := []struct { + name string + args args + want GridMap + }{ + { + name: "default value", + args: args{config: ""}, + want: defaultObjectFactors, + }, + { + name: "load config", + args: args{config: "test_data/omf-config.json"}, + want: defaultObjectFactors, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Corrector{} + got := WithObjectMoveFactors(tt.args.config) + got(&c) + if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) { + t.Errorf("WithObjectMoveFactors() = %v, want %v", *c.objectMoveFactors, tt.want) + } + }) + } +} diff --git a/pkg/steering/test_data/omf-config.json b/pkg/steering/test_data/omf-config.json new file mode 100644 index 0000000..1e8e671 --- /dev/null +++ b/pkg/steering/test_data/omf-config.json @@ -0,0 +1,11 @@ +{ + "steering_steps":[-1, -0.66, -0.33, 0, 0.33, 0.66, 1], + "distance_steps": [0, 0.2, 0.4, 0.6, 0.8, 1], + "data": [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0.25, 0, 0, -0.25, 0], + [0.5, 0.25, 0, 0, -0.5, -0.25] + ] +} \ No newline at end of file