WIP
This commit is contained in:
		@@ -81,7 +81,7 @@ func main() {
 | 
			
		||||
		client,
 | 
			
		||||
		steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
 | 
			
		||||
		steering.WithCorrector(
 | 
			
		||||
			steering.NewCorrector(
 | 
			
		||||
			steering.NewGridCorrector(
 | 
			
		||||
				steering.WidthDeltaMiddle(deltaMiddle),
 | 
			
		||||
				steering.WithGridMap(gridMapConfig),
 | 
			
		||||
				steering.WithObjectMoveFactors(objectsMoveFactorsConfig),
 | 
			
		||||
 
 | 
			
		||||
@@ -37,7 +37,7 @@ var (
 | 
			
		||||
 | 
			
		||||
type Option func(c *Controller)
 | 
			
		||||
 | 
			
		||||
func WithCorrector(c *Corrector) Option {
 | 
			
		||||
func WithCorrector(c Corrector) Option {
 | 
			
		||||
	return func(ctrl *Controller) {
 | 
			
		||||
		ctrl.corrector = c
 | 
			
		||||
	}
 | 
			
		||||
@@ -59,7 +59,7 @@ func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteering
 | 
			
		||||
		tfSteeringTopic: tfSteeringTopic,
 | 
			
		||||
		objectsTopic:    objectsTopic,
 | 
			
		||||
		driveMode:       events.DriveMode_USER,
 | 
			
		||||
		corrector:       NewCorrector(),
 | 
			
		||||
		corrector:       NewGridCorrector(),
 | 
			
		||||
	}
 | 
			
		||||
	for _, o := range options {
 | 
			
		||||
		o(c)
 | 
			
		||||
@@ -80,7 +80,7 @@ type Controller struct {
 | 
			
		||||
	muObjects sync.RWMutex
 | 
			
		||||
	objects   []*events.Object
 | 
			
		||||
 | 
			
		||||
	corrector              *Corrector
 | 
			
		||||
	corrector              Corrector
 | 
			
		||||
	enableCorrection       bool
 | 
			
		||||
	enableCorrectionOnUser bool
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -123,3 +123,123 @@ func TestDefaultSteering(t *testing.T) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StaticCorrector struct {
 | 
			
		||||
	delta float64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *StaticCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
 | 
			
		||||
	return s.delta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestController_Start(t *testing.T) {
 | 
			
		||||
	oldRegister := registerCallbacks
 | 
			
		||||
	oldPublish := publish
 | 
			
		||||
	defer func() {
 | 
			
		||||
		registerCallbacks = oldRegister
 | 
			
		||||
		publish = oldPublish
 | 
			
		||||
	}()
 | 
			
		||||
	registerCallbacks = func(p *Controller) error {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	waitPublish := sync.WaitGroup{}
 | 
			
		||||
	var muEventsPublished sync.Mutex
 | 
			
		||||
	eventsPublished := make(map[string][]byte)
 | 
			
		||||
	publish = func(client mqtt.Client, topic string, payload *[]byte) {
 | 
			
		||||
		muEventsPublished.Lock()
 | 
			
		||||
		defer muEventsPublished.Unlock()
 | 
			
		||||
		eventsPublished[topic] = *payload
 | 
			
		||||
		waitPublish.Done()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	steeringTopic := "topic/steering"
 | 
			
		||||
	driveModeTopic := "topic/driveMode"
 | 
			
		||||
	rcSteeringTopic := "topic/rcSteering"
 | 
			
		||||
	tfSteeringTopic := "topic/tfSteering"
 | 
			
		||||
	objectsTopic := "topic/objects"
 | 
			
		||||
 | 
			
		||||
	type fields struct {
 | 
			
		||||
		client                 mqtt.Client
 | 
			
		||||
		steeringTopic          string
 | 
			
		||||
		muDriveMode            sync.RWMutex
 | 
			
		||||
		driveMode              events.DriveMode
 | 
			
		||||
		cancel                 chan interface{}
 | 
			
		||||
		driveModeTopic         string
 | 
			
		||||
		rcSteeringTopic        string
 | 
			
		||||
		tfSteeringTopic        string
 | 
			
		||||
		objectsTopic           string
 | 
			
		||||
		muObjects              sync.RWMutex
 | 
			
		||||
		objects                []*events.Object
 | 
			
		||||
		corrector              *GridCorrector
 | 
			
		||||
		enableCorrection       bool
 | 
			
		||||
		enableCorrectionOnUser bool
 | 
			
		||||
	}
 | 
			
		||||
	type msgEvents struct {
 | 
			
		||||
		driveMode        events.DriveModeMessage
 | 
			
		||||
		rcSteering       events.SteeringMessage
 | 
			
		||||
		tfSteering       events.SteeringMessage
 | 
			
		||||
		expectedSteering events.SteeringMessage
 | 
			
		||||
		objects          events.ObjectsMessage
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name               string
 | 
			
		||||
		fields             fields
 | 
			
		||||
		msgEvents          msgEvents
 | 
			
		||||
		correctionOnObject float64
 | 
			
		||||
		want               events.SteeringMessage
 | 
			
		||||
		wantErr            bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "On user drive mode, none correction",
 | 
			
		||||
 | 
			
		||||
			fields: fields{
 | 
			
		||||
				driveMode: events.DriveMode_USER,
 | 
			
		||||
			},
 | 
			
		||||
			msgEvents: msgEvents{
 | 
			
		||||
				driveMode:  events.DriveModeMessage{DriveMode: events.DriveMode_USER},
 | 
			
		||||
				rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
			
		||||
				tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
 | 
			
		||||
				objects:    events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
 | 
			
		||||
			},
 | 
			
		||||
			correctionOnObject: 0.5,
 | 
			
		||||
			// Get rc value without correction
 | 
			
		||||
			want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			c := NewController(nil,
 | 
			
		||||
				steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
 | 
			
		||||
				WithObjectsCorrectionEnabled(tt.fields.enableCorrection, tt.fields.enableCorrectionOnUser),
 | 
			
		||||
				WithCorrector(&StaticCorrector{delta: tt.correctionOnObject}),
 | 
			
		||||
			)
 | 
			
		||||
			go c.Start()
 | 
			
		||||
			defer c.Stop()
 | 
			
		||||
 | 
			
		||||
			// Publish events and wait generation of new steering message
 | 
			
		||||
			waitPublish.Add(1)
 | 
			
		||||
			c.onDriveMode(nil, testtools.NewFakeMessageFromProtobuf(driveModeTopic, &tt.msgEvents.driveMode))
 | 
			
		||||
			c.onRCSteering(nil, testtools.NewFakeMessageFromProtobuf(rcSteeringTopic, &tt.msgEvents.rcSteering))
 | 
			
		||||
			c.onTFSteering(nil, testtools.NewFakeMessageFromProtobuf(tfSteeringTopic, &tt.msgEvents.tfSteering))
 | 
			
		||||
			c.onObjects(nil, testtools.NewFakeMessageFromProtobuf(objectsTopic, &tt.msgEvents.objects))
 | 
			
		||||
			waitPublish.Wait()
 | 
			
		||||
 | 
			
		||||
			var msg events.SteeringMessage
 | 
			
		||||
			muEventsPublished.Lock()
 | 
			
		||||
			err := proto.Unmarshal(eventsPublished[steeringTopic], &msg)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Errorf("unable to unmarshall response: %v", err)
 | 
			
		||||
				t.Fail()
 | 
			
		||||
			}
 | 
			
		||||
			muEventsPublished.Unlock()
 | 
			
		||||
 | 
			
		||||
			if msg.GetSteering() != tt.want.GetSteering() {
 | 
			
		||||
				t.Errorf("bad msg value for mode %v: %v, wants %v", c.driveMode.String(), msg.GetSteering(), tt.want.GetSteering())
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,10 @@ import (
 | 
			
		||||
	"os"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type OptionCorrector func(c *Corrector)
 | 
			
		||||
type Corrector interface {
 | 
			
		||||
	AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64
 | 
			
		||||
}
 | 
			
		||||
type OptionCorrector func(c *GridCorrector)
 | 
			
		||||
 | 
			
		||||
func WithGridMap(configPath string) OptionCorrector {
 | 
			
		||||
	var gm *GridMap
 | 
			
		||||
@@ -22,7 +25,7 @@ func WithGridMap(configPath string) OptionCorrector {
 | 
			
		||||
			zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return func(c *Corrector) {
 | 
			
		||||
	return func(c *GridCorrector) {
 | 
			
		||||
		c.gridMap = gm
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -39,7 +42,7 @@ func WithObjectMoveFactors(configPath string) OptionCorrector {
 | 
			
		||||
			zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return func(c *Corrector) {
 | 
			
		||||
	return func(c *GridCorrector) {
 | 
			
		||||
		c.objectMoveFactors = omf
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -58,20 +61,20 @@ func loadConfig(configPath string) (*GridMap, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WithImageSize(width, height int) OptionCorrector {
 | 
			
		||||
	return func(c *Corrector) {
 | 
			
		||||
	return func(c *GridCorrector) {
 | 
			
		||||
		c.imgWidth = width
 | 
			
		||||
		c.imgHeight = height
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WidthDeltaMiddle(d float64) OptionCorrector {
 | 
			
		||||
	return func(c *Corrector) {
 | 
			
		||||
	return func(c *GridCorrector) {
 | 
			
		||||
		c.deltaMiddle = d
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
func NewCorrector(options ...OptionCorrector) *Corrector {
 | 
			
		||||
	c := &Corrector{
 | 
			
		||||
func NewGridCorrector(options ...OptionCorrector) *GridCorrector {
 | 
			
		||||
	c := &GridCorrector{
 | 
			
		||||
		gridMap:           &defaultGridMap,
 | 
			
		||||
		objectMoveFactors: &defaultObjectFactors,
 | 
			
		||||
		deltaMiddle:       0.1,
 | 
			
		||||
@@ -84,7 +87,7 @@ func NewCorrector(options ...OptionCorrector) *Corrector {
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Corrector struct {
 | 
			
		||||
type GridCorrector struct {
 | 
			
		||||
	gridMap             *GridMap
 | 
			
		||||
	objectMoveFactors   *GridMap
 | 
			
		||||
	deltaMiddle         float64
 | 
			
		||||
@@ -127,7 +130,7 @@ AdjustFromObjectPosition modify steering value according object positions
 | 
			
		||||
    40% |-----|-----|-----|-----|-----|-----|
 | 
			
		||||
    :   | ... | ... | ... | ... | ... | ... |
 | 
			
		||||
*/
 | 
			
		||||
func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
 | 
			
		||||
func (c *GridCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
 | 
			
		||||
	if len(objects) == 0 {
 | 
			
		||||
		return currentSteering
 | 
			
		||||
	}
 | 
			
		||||
@@ -170,7 +173,7 @@ func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Corrector) computeDeviation(nearest *events.Object) float64 {
 | 
			
		||||
func (c *GridCorrector) computeDeviation(nearest *events.Object) float64 {
 | 
			
		||||
	var delta float64
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
@@ -189,7 +192,7 @@ func (c *Corrector) computeDeviation(nearest *events.Object) float64 {
 | 
			
		||||
	return delta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Corrector) nearObject(objects []*events.Object) (*events.Object, error) {
 | 
			
		||||
func (c *GridCorrector) nearObject(objects []*events.Object) (*events.Object, error) {
 | 
			
		||||
	if len(objects) == 0 {
 | 
			
		||||
		return nil, fmt.Errorf("list objects must contain at least one object")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -159,7 +159,7 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			c := NewCorrector()
 | 
			
		||||
			c := NewGridCorrector()
 | 
			
		||||
			if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want {
 | 
			
		||||
				t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want)
 | 
			
		||||
			}
 | 
			
		||||
@@ -204,7 +204,7 @@ func TestCorrector_nearObject(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			c := &Corrector{}
 | 
			
		||||
			c := &GridCorrector{}
 | 
			
		||||
			got, err := c.nearObject(tt.args.objects)
 | 
			
		||||
			if (err != nil) != tt.wantErr {
 | 
			
		||||
				t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr)
 | 
			
		||||
@@ -436,7 +436,7 @@ func TestWithGridMap(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			c := Corrector{}
 | 
			
		||||
			c := GridCorrector{}
 | 
			
		||||
			got := WithGridMap(tt.args.config)
 | 
			
		||||
			got(&c)
 | 
			
		||||
			if !reflect.DeepEqual(*c.gridMap, tt.want) {
 | 
			
		||||
@@ -468,7 +468,7 @@ func TestWithObjectMoveFactors(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			c := Corrector{}
 | 
			
		||||
			c := GridCorrector{}
 | 
			
		||||
			got := WithObjectMoveFactors(tt.args.config)
 | 
			
		||||
			got(&c)
 | 
			
		||||
			if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user