This commit is contained in:
Cyrille Nofficial 2022-08-28 23:33:51 +02:00
parent 86aef79e66
commit 25bea1aab3
5 changed files with 142 additions and 19 deletions

View File

@ -81,7 +81,7 @@ func main() {
client, client,
steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
steering.WithCorrector( steering.WithCorrector(
steering.NewCorrector( steering.NewGridCorrector(
steering.WidthDeltaMiddle(deltaMiddle), steering.WidthDeltaMiddle(deltaMiddle),
steering.WithGridMap(gridMapConfig), steering.WithGridMap(gridMapConfig),
steering.WithObjectMoveFactors(objectsMoveFactorsConfig), steering.WithObjectMoveFactors(objectsMoveFactorsConfig),

View File

@ -37,7 +37,7 @@ var (
type Option func(c *Controller) type Option func(c *Controller)
func WithCorrector(c *Corrector) Option { func WithCorrector(c Corrector) Option {
return func(ctrl *Controller) { return func(ctrl *Controller) {
ctrl.corrector = c ctrl.corrector = c
} }
@ -59,7 +59,7 @@ func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteering
tfSteeringTopic: tfSteeringTopic, tfSteeringTopic: tfSteeringTopic,
objectsTopic: objectsTopic, objectsTopic: objectsTopic,
driveMode: events.DriveMode_USER, driveMode: events.DriveMode_USER,
corrector: NewCorrector(), corrector: NewGridCorrector(),
} }
for _, o := range options { for _, o := range options {
o(c) o(c)
@ -80,7 +80,7 @@ type Controller struct {
muObjects sync.RWMutex muObjects sync.RWMutex
objects []*events.Object objects []*events.Object
corrector *Corrector corrector Corrector
enableCorrection bool enableCorrection bool
enableCorrectionOnUser bool enableCorrectionOnUser bool
} }

View File

@ -123,3 +123,123 @@ func TestDefaultSteering(t *testing.T) {
} }
} }
} }
type StaticCorrector struct {
delta float64
}
func (s *StaticCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
return s.delta
}
func TestController_Start(t *testing.T) {
oldRegister := registerCallbacks
oldPublish := publish
defer func() {
registerCallbacks = oldRegister
publish = oldPublish
}()
registerCallbacks = func(p *Controller) error {
return nil
}
waitPublish := sync.WaitGroup{}
var muEventsPublished sync.Mutex
eventsPublished := make(map[string][]byte)
publish = func(client mqtt.Client, topic string, payload *[]byte) {
muEventsPublished.Lock()
defer muEventsPublished.Unlock()
eventsPublished[topic] = *payload
waitPublish.Done()
}
steeringTopic := "topic/steering"
driveModeTopic := "topic/driveMode"
rcSteeringTopic := "topic/rcSteering"
tfSteeringTopic := "topic/tfSteering"
objectsTopic := "topic/objects"
type fields struct {
client mqtt.Client
steeringTopic string
muDriveMode sync.RWMutex
driveMode events.DriveMode
cancel chan interface{}
driveModeTopic string
rcSteeringTopic string
tfSteeringTopic string
objectsTopic string
muObjects sync.RWMutex
objects []*events.Object
corrector *GridCorrector
enableCorrection bool
enableCorrectionOnUser bool
}
type msgEvents struct {
driveMode events.DriveModeMessage
rcSteering events.SteeringMessage
tfSteering events.SteeringMessage
expectedSteering events.SteeringMessage
objects events.ObjectsMessage
}
tests := []struct {
name string
fields fields
msgEvents msgEvents
correctionOnObject float64
want events.SteeringMessage
wantErr bool
}{
{
name: "On user drive mode, none correction",
fields: fields{
driveMode: events.DriveMode_USER,
},
msgEvents: msgEvents{
driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER},
rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
},
correctionOnObject: 0.5,
// Get rc value without correction
want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewController(nil,
steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
WithObjectsCorrectionEnabled(tt.fields.enableCorrection, tt.fields.enableCorrectionOnUser),
WithCorrector(&StaticCorrector{delta: tt.correctionOnObject}),
)
go c.Start()
defer c.Stop()
// Publish events and wait generation of new steering message
waitPublish.Add(1)
c.onDriveMode(nil, testtools.NewFakeMessageFromProtobuf(driveModeTopic, &tt.msgEvents.driveMode))
c.onRCSteering(nil, testtools.NewFakeMessageFromProtobuf(rcSteeringTopic, &tt.msgEvents.rcSteering))
c.onTFSteering(nil, testtools.NewFakeMessageFromProtobuf(tfSteeringTopic, &tt.msgEvents.tfSteering))
c.onObjects(nil, testtools.NewFakeMessageFromProtobuf(objectsTopic, &tt.msgEvents.objects))
waitPublish.Wait()
var msg events.SteeringMessage
muEventsPublished.Lock()
err := proto.Unmarshal(eventsPublished[steeringTopic], &msg)
if err != nil {
t.Errorf("unable to unmarshall response: %v", err)
t.Fail()
}
muEventsPublished.Unlock()
if msg.GetSteering() != tt.want.GetSteering() {
t.Errorf("bad msg value for mode %v: %v, wants %v", c.driveMode.String(), msg.GetSteering(), tt.want.GetSteering())
}
})
}
}

View File

@ -8,7 +8,10 @@ import (
"os" "os"
) )
type OptionCorrector func(c *Corrector) type Corrector interface {
AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64
}
type OptionCorrector func(c *GridCorrector)
func WithGridMap(configPath string) OptionCorrector { func WithGridMap(configPath string) OptionCorrector {
var gm *GridMap var gm *GridMap
@ -22,7 +25,7 @@ func WithGridMap(configPath string) OptionCorrector {
zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err) zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err)
} }
} }
return func(c *Corrector) { return func(c *GridCorrector) {
c.gridMap = gm c.gridMap = gm
} }
} }
@ -39,7 +42,7 @@ func WithObjectMoveFactors(configPath string) OptionCorrector {
zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err) zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err)
} }
} }
return func(c *Corrector) { return func(c *GridCorrector) {
c.objectMoveFactors = omf c.objectMoveFactors = omf
} }
} }
@ -58,20 +61,20 @@ func loadConfig(configPath string) (*GridMap, error) {
} }
func WithImageSize(width, height int) OptionCorrector { func WithImageSize(width, height int) OptionCorrector {
return func(c *Corrector) { return func(c *GridCorrector) {
c.imgWidth = width c.imgWidth = width
c.imgHeight = height c.imgHeight = height
} }
} }
func WidthDeltaMiddle(d float64) OptionCorrector { func WidthDeltaMiddle(d float64) OptionCorrector {
return func(c *Corrector) { return func(c *GridCorrector) {
c.deltaMiddle = d c.deltaMiddle = d
} }
} }
func NewCorrector(options ...OptionCorrector) *Corrector { func NewGridCorrector(options ...OptionCorrector) *GridCorrector {
c := &Corrector{ c := &GridCorrector{
gridMap: &defaultGridMap, gridMap: &defaultGridMap,
objectMoveFactors: &defaultObjectFactors, objectMoveFactors: &defaultObjectFactors,
deltaMiddle: 0.1, deltaMiddle: 0.1,
@ -84,7 +87,7 @@ func NewCorrector(options ...OptionCorrector) *Corrector {
return c return c
} }
type Corrector struct { type GridCorrector struct {
gridMap *GridMap gridMap *GridMap
objectMoveFactors *GridMap objectMoveFactors *GridMap
deltaMiddle float64 deltaMiddle float64
@ -127,7 +130,7 @@ AdjustFromObjectPosition modify steering value according object positions
40% |-----|-----|-----|-----|-----|-----| 40% |-----|-----|-----|-----|-----|-----|
: | ... | ... | ... | ... | ... | ... | : | ... | ... | ... | ... | ... | ... |
*/ */
func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 { func (c *GridCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
if len(objects) == 0 { if len(objects) == 0 {
return currentSteering return currentSteering
} }
@ -170,7 +173,7 @@ func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []
} }
} }
func (c *Corrector) computeDeviation(nearest *events.Object) float64 { func (c *GridCorrector) computeDeviation(nearest *events.Object) float64 {
var delta float64 var delta float64
var err error var err error
@ -189,7 +192,7 @@ func (c *Corrector) computeDeviation(nearest *events.Object) float64 {
return delta return delta
} }
func (c *Corrector) nearObject(objects []*events.Object) (*events.Object, error) { func (c *GridCorrector) nearObject(objects []*events.Object) (*events.Object, error) {
if len(objects) == 0 { if len(objects) == 0 {
return nil, fmt.Errorf("list objects must contain at least one object") return nil, fmt.Errorf("list objects must contain at least one object")
} }

View File

@ -159,7 +159,7 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := NewCorrector() c := NewGridCorrector()
if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want { if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want {
t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want) t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want)
} }
@ -204,7 +204,7 @@ func TestCorrector_nearObject(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &Corrector{} c := &GridCorrector{}
got, err := c.nearObject(tt.args.objects) got, err := c.nearObject(tt.args.objects)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr)
@ -436,7 +436,7 @@ func TestWithGridMap(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := Corrector{} c := GridCorrector{}
got := WithGridMap(tt.args.config) got := WithGridMap(tt.args.config)
got(&c) got(&c)
if !reflect.DeepEqual(*c.gridMap, tt.want) { if !reflect.DeepEqual(*c.gridMap, tt.want) {
@ -468,7 +468,7 @@ func TestWithObjectMoveFactors(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := Corrector{} c := GridCorrector{}
got := WithObjectMoveFactors(tt.args.config) got := WithObjectMoveFactors(tt.args.config)
got(&c) got(&c)
if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) { if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) {