diff --git a/cmd/rc-steering/rc-steering.go b/cmd/rc-steering/rc-steering.go index 97d0e4b..c6b44a8 100644 --- a/cmd/rc-steering/rc-steering.go +++ b/cmd/rc-steering/rc-steering.go @@ -81,7 +81,7 @@ func main() { client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic, steering.WithCorrector( - steering.NewCorrector( + steering.NewGridCorrector( steering.WidthDeltaMiddle(deltaMiddle), steering.WithGridMap(gridMapConfig), steering.WithObjectMoveFactors(objectsMoveFactorsConfig), diff --git a/pkg/steering/controller.go b/pkg/steering/controller.go index 2bfe15e..6cec875 100644 --- a/pkg/steering/controller.go +++ b/pkg/steering/controller.go @@ -37,7 +37,7 @@ var ( type Option func(c *Controller) -func WithCorrector(c *Corrector) Option { +func WithCorrector(c Corrector) Option { return func(ctrl *Controller) { ctrl.corrector = c } @@ -59,7 +59,7 @@ func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteering tfSteeringTopic: tfSteeringTopic, objectsTopic: objectsTopic, driveMode: events.DriveMode_USER, - corrector: NewCorrector(), + corrector: NewGridCorrector(), } for _, o := range options { o(c) @@ -80,7 +80,7 @@ type Controller struct { muObjects sync.RWMutex objects []*events.Object - corrector *Corrector + corrector Corrector enableCorrection bool enableCorrectionOnUser bool } diff --git a/pkg/steering/controller_test.go b/pkg/steering/controller_test.go index a7c4aec..b2559c0 100644 --- a/pkg/steering/controller_test.go +++ b/pkg/steering/controller_test.go @@ -123,3 +123,123 @@ func TestDefaultSteering(t *testing.T) { } } } + +type StaticCorrector struct { + delta float64 +} + +func (s *StaticCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 { + return s.delta +} + +func TestController_Start(t *testing.T) { + oldRegister := registerCallbacks + oldPublish := publish + defer func() { + registerCallbacks = oldRegister + publish = oldPublish + }() + registerCallbacks = func(p *Controller) error { + return nil + } + + waitPublish := sync.WaitGroup{} + var muEventsPublished sync.Mutex + eventsPublished := make(map[string][]byte) + publish = func(client mqtt.Client, topic string, payload *[]byte) { + muEventsPublished.Lock() + defer muEventsPublished.Unlock() + eventsPublished[topic] = *payload + waitPublish.Done() + } + + steeringTopic := "topic/steering" + driveModeTopic := "topic/driveMode" + rcSteeringTopic := "topic/rcSteering" + tfSteeringTopic := "topic/tfSteering" + objectsTopic := "topic/objects" + + type fields struct { + client mqtt.Client + steeringTopic string + muDriveMode sync.RWMutex + driveMode events.DriveMode + cancel chan interface{} + driveModeTopic string + rcSteeringTopic string + tfSteeringTopic string + objectsTopic string + muObjects sync.RWMutex + objects []*events.Object + corrector *GridCorrector + enableCorrection bool + enableCorrectionOnUser bool + } + type msgEvents struct { + driveMode events.DriveModeMessage + rcSteering events.SteeringMessage + tfSteering events.SteeringMessage + expectedSteering events.SteeringMessage + objects events.ObjectsMessage + } + + tests := []struct { + name string + fields fields + msgEvents msgEvents + correctionOnObject float64 + want events.SteeringMessage + wantErr bool + }{ + { + name: "On user drive mode, none correction", + + fields: fields{ + driveMode: events.DriveMode_USER, + }, + msgEvents: msgEvents{ + driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER}, + rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0}, + objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}}, + }, + correctionOnObject: 0.5, + // Get rc value without correction + want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewController(nil, + steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic, + WithObjectsCorrectionEnabled(tt.fields.enableCorrection, tt.fields.enableCorrectionOnUser), + WithCorrector(&StaticCorrector{delta: tt.correctionOnObject}), + ) + go c.Start() + defer c.Stop() + + // Publish events and wait generation of new steering message + waitPublish.Add(1) + c.onDriveMode(nil, testtools.NewFakeMessageFromProtobuf(driveModeTopic, &tt.msgEvents.driveMode)) + c.onRCSteering(nil, testtools.NewFakeMessageFromProtobuf(rcSteeringTopic, &tt.msgEvents.rcSteering)) + c.onTFSteering(nil, testtools.NewFakeMessageFromProtobuf(tfSteeringTopic, &tt.msgEvents.tfSteering)) + c.onObjects(nil, testtools.NewFakeMessageFromProtobuf(objectsTopic, &tt.msgEvents.objects)) + waitPublish.Wait() + + var msg events.SteeringMessage + muEventsPublished.Lock() + err := proto.Unmarshal(eventsPublished[steeringTopic], &msg) + if err != nil { + t.Errorf("unable to unmarshall response: %v", err) + t.Fail() + } + muEventsPublished.Unlock() + + if msg.GetSteering() != tt.want.GetSteering() { + t.Errorf("bad msg value for mode %v: %v, wants %v", c.driveMode.String(), msg.GetSteering(), tt.want.GetSteering()) + } + + }) + } +} diff --git a/pkg/steering/corrector.go b/pkg/steering/corrector.go index 36aff05..3eff2bf 100644 --- a/pkg/steering/corrector.go +++ b/pkg/steering/corrector.go @@ -8,7 +8,10 @@ import ( "os" ) -type OptionCorrector func(c *Corrector) +type Corrector interface { + AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 +} +type OptionCorrector func(c *GridCorrector) func WithGridMap(configPath string) OptionCorrector { var gm *GridMap @@ -22,7 +25,7 @@ func WithGridMap(configPath string) OptionCorrector { zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err) } } - return func(c *Corrector) { + return func(c *GridCorrector) { c.gridMap = gm } } @@ -39,7 +42,7 @@ func WithObjectMoveFactors(configPath string) OptionCorrector { zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err) } } - return func(c *Corrector) { + return func(c *GridCorrector) { c.objectMoveFactors = omf } } @@ -58,20 +61,20 @@ func loadConfig(configPath string) (*GridMap, error) { } func WithImageSize(width, height int) OptionCorrector { - return func(c *Corrector) { + return func(c *GridCorrector) { c.imgWidth = width c.imgHeight = height } } func WidthDeltaMiddle(d float64) OptionCorrector { - return func(c *Corrector) { + return func(c *GridCorrector) { c.deltaMiddle = d } } -func NewCorrector(options ...OptionCorrector) *Corrector { - c := &Corrector{ +func NewGridCorrector(options ...OptionCorrector) *GridCorrector { + c := &GridCorrector{ gridMap: &defaultGridMap, objectMoveFactors: &defaultObjectFactors, deltaMiddle: 0.1, @@ -84,7 +87,7 @@ func NewCorrector(options ...OptionCorrector) *Corrector { return c } -type Corrector struct { +type GridCorrector struct { gridMap *GridMap objectMoveFactors *GridMap deltaMiddle float64 @@ -127,7 +130,7 @@ AdjustFromObjectPosition modify steering value according object positions 40% |-----|-----|-----|-----|-----|-----| : | ... | ... | ... | ... | ... | ... | */ -func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 { +func (c *GridCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 { if len(objects) == 0 { return currentSteering } @@ -170,7 +173,7 @@ func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects [] } } -func (c *Corrector) computeDeviation(nearest *events.Object) float64 { +func (c *GridCorrector) computeDeviation(nearest *events.Object) float64 { var delta float64 var err error @@ -189,7 +192,7 @@ func (c *Corrector) computeDeviation(nearest *events.Object) float64 { return delta } -func (c *Corrector) nearObject(objects []*events.Object) (*events.Object, error) { +func (c *GridCorrector) nearObject(objects []*events.Object) (*events.Object, error) { if len(objects) == 0 { return nil, fmt.Errorf("list objects must contain at least one object") } diff --git a/pkg/steering/corrector_test.go b/pkg/steering/corrector_test.go index f571a8b..a3cb522 100644 --- a/pkg/steering/corrector_test.go +++ b/pkg/steering/corrector_test.go @@ -159,7 +159,7 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := NewCorrector() + c := NewGridCorrector() if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want { t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want) } @@ -204,7 +204,7 @@ func TestCorrector_nearObject(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &Corrector{} + c := &GridCorrector{} got, err := c.nearObject(tt.args.objects) if (err != nil) != tt.wantErr { t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr) @@ -436,7 +436,7 @@ func TestWithGridMap(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := Corrector{} + c := GridCorrector{} got := WithGridMap(tt.args.config) got(&c) if !reflect.DeepEqual(*c.gridMap, tt.want) { @@ -468,7 +468,7 @@ func TestWithObjectMoveFactors(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := Corrector{} + c := GridCorrector{} got := WithObjectMoveFactors(tt.args.config) got(&c) if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) {