Compare commits

...

2 Commits

Author SHA1 Message Date
25bea1aab3 WIP 2022-08-28 23:33:51 +02:00
86aef79e66 WIP 2022-08-28 12:06:43 +02:00
6 changed files with 442 additions and 84 deletions

View File

@ -16,6 +16,10 @@ const (
func main() { func main() {
var mqttBroker, username, password, clientId string var mqttBroker, username, password, clientId string
var steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string var steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string
var imgWidth, imgHeight int
var enableObjectsCorrection, enableObjectsCorrectionOnUserMode bool
var gridMapConfig, objectsMoveFactorsConfig string
var deltaMiddle float64
mqttQos := cli.InitIntFlag("MQTT_QOS", 0) mqttQos := cli.InitIntFlag("MQTT_QOS", 0)
_, mqttRetain := os.LookupEnv("MQTT_RETAIN") _, mqttRetain := os.LookupEnv("MQTT_RETAIN")
@ -27,7 +31,13 @@ func main() {
flag.StringVar(&tfSteeringTopic, "mqtt-topic-tf-steering", os.Getenv("MQTT_TOPIC_TF_STEERING"), "Mqtt topic that contains tenorflow steering value, use MQTT_TOPIC_TF_STEERING if args not set") flag.StringVar(&tfSteeringTopic, "mqtt-topic-tf-steering", os.Getenv("MQTT_TOPIC_TF_STEERING"), "Mqtt topic that contains tenorflow steering value, use MQTT_TOPIC_TF_STEERING if args not set")
flag.StringVar(&driveModeTopic, "mqtt-topic-drive-mode", os.Getenv("MQTT_TOPIC_DRIVE_MODE"), "Mqtt topic that contains DriveMode value, use MQTT_TOPIC_DRIVE_MODE if args not set") flag.StringVar(&driveModeTopic, "mqtt-topic-drive-mode", os.Getenv("MQTT_TOPIC_DRIVE_MODE"), "Mqtt topic that contains DriveMode value, use MQTT_TOPIC_DRIVE_MODE if args not set")
flag.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains Objects from object detection value, use MQTT_TOPIC_OBJECTS if args not set") flag.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains Objects from object detection value, use MQTT_TOPIC_OBJECTS if args not set")
flag.IntVar(&imgWidth, "image-width", 160, "Video pixels width")
flag.IntVar(&imgHeight, "image-height", 128, "Video pixels height")
flag.BoolVar(&enableObjectsCorrection, "enable-objects-correction", false, "Adjust steering to avoid objects")
flag.BoolVar(&enableObjectsCorrectionOnUserMode, "enable-objects-correction-user", false, "Adjust steering to avoid objects on user mode driving")
flag.StringVar(&gridMapConfig, "grid-map-config", "", "Json file path to configure grid object correction")
flag.StringVar(&objectsMoveFactorsConfig, "objects-move-factors-config", "", "Json file path to configure objects move corrections")
flag.Float64Var(&deltaMiddle, "delta-middle", 0.1, "Half Percent zone to interpret as straight")
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level") logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
flag.Parse() flag.Parse()
@ -50,13 +60,36 @@ func main() {
}() }()
zap.ReplaceGlobals(lgr) zap.ReplaceGlobals(lgr)
zap.S().Infof("steering topic : %s", steeringTopic)
zap.S().Infof("rc topic : %s", rcSteeringTopic)
zap.S().Infof("tflite steering topic : %s", tfSteeringTopic)
zap.S().Infof("drive mode topic : %s", driveModeTopic)
zap.S().Infof("objects topic : %s", objectsTopic)
zap.S().Infof("objects correction enabled : %v", enableObjectsCorrection)
zap.S().Infof("objects correction on user mode : %v", enableObjectsCorrectionOnUserMode)
zap.S().Infof("grid map file config : %v", gridMapConfig)
zap.S().Infof("objects move factors grid config: %v", objectsMoveFactorsConfig)
zap.S().Infof("image width x height : %v x %v", imgWidth, imgHeight)
client, err := cli.Connect(mqttBroker, username, password, clientId) client, err := cli.Connect(mqttBroker, username, password, clientId)
if err != nil { if err != nil {
log.Fatalf("unable to connect to mqtt bus: %v", err) log.Fatalf("unable to connect to mqtt bus: %v", err)
} }
defer client.Disconnect(50) defer client.Disconnect(50)
p := steering.NewController(client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic) p := steering.NewController(
client,
steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
steering.WithCorrector(
steering.NewGridCorrector(
steering.WidthDeltaMiddle(deltaMiddle),
steering.WithGridMap(gridMapConfig),
steering.WithObjectMoveFactors(objectsMoveFactorsConfig),
steering.WithImageSize(imgWidth, imgHeight),
),
),
steering.WithObjectsCorrectionEnabled(enableObjectsCorrection, enableObjectsCorrectionOnUserMode),
)
defer p.Stop() defer p.Stop()
cli.HandleExit(p) cli.HandleExit(p)

View File

@ -1,6 +1,7 @@
package steering package steering
import ( import (
"fmt"
"github.com/cyrilix/robocar-base/service" "github.com/cyrilix/robocar-base/service"
"github.com/cyrilix/robocar-protobuf/go/events" "github.com/cyrilix/robocar-protobuf/go/events"
mqtt "github.com/eclipse/paho.mqtt.golang" mqtt "github.com/eclipse/paho.mqtt.golang"
@ -9,8 +10,48 @@ import (
"sync" "sync"
) )
func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string) *Controller { var (
return &Controller{ defaultGridMap = GridMap{
DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
Data: [][]float64{
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0., 0.25, -0.25, 0., 0.},
{0., 0.25, 0.5, -0.5, -0.25, 0.},
{0.25, 0.5, 1, -1, -0.5, -0.25},
},
}
defaultObjectFactors = GridMap{
DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
Data: [][]float64{
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0.25, 0, 0, -0.25, 0.},
{0.5, 0.25, 0, 0, -0.5, -0.25},
},
}
)
type Option func(c *Controller)
func WithCorrector(c Corrector) Option {
return func(ctrl *Controller) {
ctrl.corrector = c
}
}
func WithObjectsCorrectionEnabled(enabled, enabledOnUserDrive bool) Option {
return func(ctrl *Controller) {
ctrl.enableCorrection = enabled
ctrl.enableCorrectionOnUser = enabledOnUserDrive
}
}
func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string, options ...Option) *Controller {
c := &Controller{
client: client, client: client,
steeringTopic: steeringTopic, steeringTopic: steeringTopic,
driveModeTopic: driveModeTopic, driveModeTopic: driveModeTopic,
@ -18,8 +59,12 @@ func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteering
tfSteeringTopic: tfSteeringTopic, tfSteeringTopic: tfSteeringTopic,
objectsTopic: objectsTopic, objectsTopic: objectsTopic,
driveMode: events.DriveMode_USER, driveMode: events.DriveMode_USER,
corrector: NewGridCorrector(),
} }
for _, o := range options {
o(c)
}
return c
} }
type Controller struct { type Controller struct {
@ -35,26 +80,28 @@ type Controller struct {
muObjects sync.RWMutex muObjects sync.RWMutex
objects []*events.Object objects []*events.Object
debug bool corrector Corrector
enableCorrection bool
enableCorrectionOnUser bool
} }
func (p *Controller) Start() error { func (c *Controller) Start() error {
if err := registerCallbacks(p); err != nil { if err := registerCallbacks(c); err != nil {
zap.S().Errorf("unable to rgeister callbacks: %v", err) zap.S().Errorf("unable to register callbacks: %v", err)
return err return err
} }
p.cancel = make(chan interface{}) c.cancel = make(chan interface{})
<-p.cancel <-c.cancel
return nil return nil
} }
func (p *Controller) Stop() { func (c *Controller) Stop() {
close(p.cancel) close(c.cancel)
service.StopService("throttle", p.client, p.driveModeTopic, p.rcSteeringTopic, p.tfSteeringTopic) service.StopService("throttle", c.client, c.driveModeTopic, c.rcSteeringTopic, c.tfSteeringTopic)
} }
func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) { func (c *Controller) onObjects(_ mqtt.Client, message mqtt.Message) {
var msg events.ObjectsMessage var msg events.ObjectsMessage
err := proto.Unmarshal(message.Payload(), &msg) err := proto.Unmarshal(message.Payload(), &msg)
if err != nil { if err != nil {
@ -62,12 +109,12 @@ func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) {
return return
} }
p.muObjects.Lock() c.muObjects.Lock()
defer p.muObjects.Unlock() defer c.muObjects.Unlock()
p.objects = msg.GetObjects() c.objects = msg.GetObjects()
} }
func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) { func (c *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) {
var msg events.DriveModeMessage var msg events.DriveModeMessage
err := proto.Unmarshal(message.Payload(), &msg) err := proto.Unmarshal(message.Payload(), &msg)
if err != nil { if err != nil {
@ -75,46 +122,85 @@ func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) {
return return
} }
p.muDriveMode.Lock() c.muDriveMode.Lock()
defer p.muDriveMode.Unlock() defer c.muDriveMode.Unlock()
p.driveMode = msg.GetDriveMode() c.driveMode = msg.GetDriveMode()
} }
func (p *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) { func (c *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) {
p.muDriveMode.RLock() c.muDriveMode.RLock()
defer p.muDriveMode.RUnlock() defer c.muDriveMode.RUnlock()
if p.debug {
var evt events.SteeringMessage if c.driveMode != events.DriveMode_USER {
err := proto.Unmarshal(message.Payload(), &evt) return
}
payload := message.Payload()
evt := &events.SteeringMessage{}
err := proto.Unmarshal(payload, evt)
if err != nil {
zap.S().Debugf("unable to unmarshal rc event: %v", err)
} else {
zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering())
}
if c.enableCorrection && c.enableCorrectionOnUser {
payload, err = c.adjustSteering(evt)
if err != nil { if err != nil {
zap.S().Debugf("unable to unmarshal rc event: %v", err) zap.S().Errorf("unable to adjust steering, skip message: %v", err)
} else { return
zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering())
} }
} }
if p.driveMode == events.DriveMode_USER { publish(c.client, c.steeringTopic, &payload)
// Republish same content
payload := message.Payload()
publish(p.client, p.steeringTopic, &payload)
}
} }
func (p *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) {
p.muDriveMode.RLock() func (c *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) {
defer p.muDriveMode.RUnlock() c.muDriveMode.RLock()
if p.debug { defer c.muDriveMode.RUnlock()
var evt events.SteeringMessage if c.driveMode != events.DriveMode_PILOT {
err := proto.Unmarshal(message.Payload(), &evt) // User mode, skip new message
return
}
evt := &events.SteeringMessage{}
err := proto.Unmarshal(message.Payload(), evt)
if err != nil {
zap.S().Errorf("unable to unmarshal tensorflow event: %v", err)
return
} else {
zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering())
}
payload := message.Payload()
if c.enableCorrection {
payload, err = c.adjustSteering(evt)
if err != nil { if err != nil {
zap.S().Debugf("unable to unmarshal tensorflow event: %v", err) zap.S().Errorf("unable to adjust steering, skip message: %v", err)
} else { return
zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering())
} }
} }
if p.driveMode == events.DriveMode_PILOT {
// Republish same content publish(c.client, c.steeringTopic, &payload)
payload := message.Payload() }
publish(p.client, p.steeringTopic, &payload)
func (c *Controller) adjustSteering(evt *events.SteeringMessage) ([]byte, error) {
steering := float64(evt.GetSteering())
steering = c.corrector.AdjustFromObjectPosition(steering, c.Objects())
evt.Steering = float32(steering)
// override payload content
payload, err := proto.Marshal(evt)
if err != nil {
return nil, fmt.Errorf("unable to marshal steering message with new value, skip message: %v", err)
} }
return payload, nil
}
func (c *Controller) Objects() []*events.Object {
c.muObjects.RLock()
defer c.muObjects.RUnlock()
res := make([]*events.Object, 0, len(c.objects))
copy(res, c.objects)
return res
} }
var registerCallbacks = func(p *Controller) error { var registerCallbacks = func(p *Controller) error {

View File

@ -123,3 +123,123 @@ func TestDefaultSteering(t *testing.T) {
} }
} }
} }
type StaticCorrector struct {
delta float64
}
func (s *StaticCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
return s.delta
}
func TestController_Start(t *testing.T) {
oldRegister := registerCallbacks
oldPublish := publish
defer func() {
registerCallbacks = oldRegister
publish = oldPublish
}()
registerCallbacks = func(p *Controller) error {
return nil
}
waitPublish := sync.WaitGroup{}
var muEventsPublished sync.Mutex
eventsPublished := make(map[string][]byte)
publish = func(client mqtt.Client, topic string, payload *[]byte) {
muEventsPublished.Lock()
defer muEventsPublished.Unlock()
eventsPublished[topic] = *payload
waitPublish.Done()
}
steeringTopic := "topic/steering"
driveModeTopic := "topic/driveMode"
rcSteeringTopic := "topic/rcSteering"
tfSteeringTopic := "topic/tfSteering"
objectsTopic := "topic/objects"
type fields struct {
client mqtt.Client
steeringTopic string
muDriveMode sync.RWMutex
driveMode events.DriveMode
cancel chan interface{}
driveModeTopic string
rcSteeringTopic string
tfSteeringTopic string
objectsTopic string
muObjects sync.RWMutex
objects []*events.Object
corrector *GridCorrector
enableCorrection bool
enableCorrectionOnUser bool
}
type msgEvents struct {
driveMode events.DriveModeMessage
rcSteering events.SteeringMessage
tfSteering events.SteeringMessage
expectedSteering events.SteeringMessage
objects events.ObjectsMessage
}
tests := []struct {
name string
fields fields
msgEvents msgEvents
correctionOnObject float64
want events.SteeringMessage
wantErr bool
}{
{
name: "On user drive mode, none correction",
fields: fields{
driveMode: events.DriveMode_USER,
},
msgEvents: msgEvents{
driveMode: events.DriveModeMessage{DriveMode: events.DriveMode_USER},
rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
objects: events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
},
correctionOnObject: 0.5,
// Get rc value without correction
want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewController(nil,
steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
WithObjectsCorrectionEnabled(tt.fields.enableCorrection, tt.fields.enableCorrectionOnUser),
WithCorrector(&StaticCorrector{delta: tt.correctionOnObject}),
)
go c.Start()
defer c.Stop()
// Publish events and wait generation of new steering message
waitPublish.Add(1)
c.onDriveMode(nil, testtools.NewFakeMessageFromProtobuf(driveModeTopic, &tt.msgEvents.driveMode))
c.onRCSteering(nil, testtools.NewFakeMessageFromProtobuf(rcSteeringTopic, &tt.msgEvents.rcSteering))
c.onTFSteering(nil, testtools.NewFakeMessageFromProtobuf(tfSteeringTopic, &tt.msgEvents.tfSteering))
c.onObjects(nil, testtools.NewFakeMessageFromProtobuf(objectsTopic, &tt.msgEvents.objects))
waitPublish.Wait()
var msg events.SteeringMessage
muEventsPublished.Lock()
err := proto.Unmarshal(eventsPublished[steeringTopic], &msg)
if err != nil {
t.Errorf("unable to unmarshall response: %v", err)
t.Fail()
}
muEventsPublished.Unlock()
if msg.GetSteering() != tt.want.GetSteering() {
t.Errorf("bad msg value for mode %v: %v, wants %v", c.driveMode.String(), msg.GetSteering(), tt.want.GetSteering())
}
})
}
}

View File

@ -8,17 +8,86 @@ import (
"os" "os"
) )
func NewCorrector(gridMap *GridMap, objectMoveFactors *GridMap) *Corrector { type Corrector interface {
return &Corrector{ AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64
gridMap: gridMap, }
objectMoveFactors: objectMoveFactors, type OptionCorrector func(c *GridCorrector)
func WithGridMap(configPath string) OptionCorrector {
var gm *GridMap
if configPath == "" {
zap.S().Warnf("no configuration defined for grid map, use default")
gm = &defaultGridMap
} else {
var err error
gm, err = loadConfig(configPath)
if err != nil {
zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err)
}
}
return func(c *GridCorrector) {
c.gridMap = gm
}
}
func WithObjectMoveFactors(configPath string) OptionCorrector {
var omf *GridMap
if configPath == "" {
zap.S().Warnf("no configuration defined for objects move factors, use default")
omf = &defaultObjectFactors
} else {
var err error
omf, err = loadConfig(configPath)
if err != nil {
zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err)
}
}
return func(c *GridCorrector) {
c.objectMoveFactors = omf
}
}
func loadConfig(configPath string) (*GridMap, error) {
content, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("unable to load grid-map config from file '%v': %w", configPath, err)
}
var gm GridMap
err = json.Unmarshal(content, &gm)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal json config '%s': %w", configPath, err)
}
return &gm, nil
}
func WithImageSize(width, height int) OptionCorrector {
return func(c *GridCorrector) {
c.imgWidth = width
c.imgHeight = height
}
}
func WidthDeltaMiddle(d float64) OptionCorrector {
return func(c *GridCorrector) {
c.deltaMiddle = d
}
}
func NewGridCorrector(options ...OptionCorrector) *GridCorrector {
c := &GridCorrector{
gridMap: &defaultGridMap,
objectMoveFactors: &defaultObjectFactors,
deltaMiddle: 0.1, deltaMiddle: 0.1,
imgWidth: 160, imgWidth: 160,
imgHeight: 120, imgHeight: 120,
} }
for _, o := range options {
o(c)
}
return c
} }
type Corrector struct { type GridCorrector struct {
gridMap *GridMap gridMap *GridMap
objectMoveFactors *GridMap objectMoveFactors *GridMap
deltaMiddle float64 deltaMiddle float64
@ -61,7 +130,7 @@ AdjustFromObjectPosition modify steering value according object positions
40% |-----|-----|-----|-----|-----|-----| 40% |-----|-----|-----|-----|-----|-----|
: | ... | ... | ... | ... | ... | ... | : | ... | ... | ... | ... | ... | ... |
*/ */
func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 { func (c *GridCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
if len(objects) == 0 { if len(objects) == 0 {
return currentSteering return currentSteering
} }
@ -104,7 +173,7 @@ func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []
} }
} }
func (c *Corrector) computeDeviation(nearest *events.Object) float64 { func (c *GridCorrector) computeDeviation(nearest *events.Object) float64 {
var delta float64 var delta float64
var err error var err error
@ -123,7 +192,7 @@ func (c *Corrector) computeDeviation(nearest *events.Object) float64 {
return delta return delta
} }
func (c *Corrector) nearObject(objects []*events.Object) (*events.Object, error) { func (c *GridCorrector) nearObject(objects []*events.Object) (*events.Object, error) {
if len(objects) == 0 { if len(objects) == 0 {
return nil, fmt.Errorf("list objects must contain at least one object") return nil, fmt.Errorf("list objects must contain at least one object")
} }

View File

@ -57,31 +57,6 @@ var (
} }
) )
var (
defaultGridMap = GridMap{
DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
Data: [][]float64{
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0., 0.25, -0.25, 0., 0.},
{0., 0.25, 0.5, -0.5, -0.25, 0.},
{0.25, 0.5, 1, -1, -0.5, -0.25},
},
}
defaultObjectFactors = GridMap{
DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
Data: [][]float64{
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0., 0.},
{0., 0.25, 0, 0, -0.25, 0.},
{0.5, 0.25, 0, 0, -0.5, -0.25},
},
}
)
func TestCorrector_AdjustFromObjectPosition(t *testing.T) { func TestCorrector_AdjustFromObjectPosition(t *testing.T) {
type args struct { type args struct {
currentSteering float64 currentSteering float64
@ -184,7 +159,7 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := NewCorrector(&defaultGridMap, &defaultObjectFactors) c := NewGridCorrector()
if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want { if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want {
t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want) t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want)
} }
@ -229,7 +204,7 @@ func TestCorrector_nearObject(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &Corrector{} c := &GridCorrector{}
got, err := c.nearObject(tt.args.objects) got, err := c.nearObject(tt.args.objects)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr)
@ -438,3 +413,67 @@ func TestGridMap_ValueOf(t *testing.T) {
}) })
} }
} }
func TestWithGridMap(t *testing.T) {
type args struct {
config string
}
tests := []struct {
name string
args args
want GridMap
}{
{
name: "default value",
args: args{config: ""},
want: defaultGridMap,
},
{
name: "load config",
args: args{config: "test_data/config.json"},
want: defaultGridMap,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := GridCorrector{}
got := WithGridMap(tt.args.config)
got(&c)
if !reflect.DeepEqual(*c.gridMap, tt.want) {
t.Errorf("WithGridMap() = %v, want %v", *c.gridMap, tt.want)
}
})
}
}
func TestWithObjectMoveFactors(t *testing.T) {
type args struct {
config string
}
tests := []struct {
name string
args args
want GridMap
}{
{
name: "default value",
args: args{config: ""},
want: defaultObjectFactors,
},
{
name: "load config",
args: args{config: "test_data/omf-config.json"},
want: defaultObjectFactors,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := GridCorrector{}
got := WithObjectMoveFactors(tt.args.config)
got(&c)
if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) {
t.Errorf("WithObjectMoveFactors() = %v, want %v", *c.objectMoveFactors, tt.want)
}
})
}
}

View File

@ -0,0 +1,11 @@
{
"steering_steps":[-1, -0.66, -0.33, 0, 0.33, 0.66, 1],
"distance_steps": [0, 0.2, 0.4, 0.6, 0.8, 1],
"data": [
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0.25, 0, 0, -0.25, 0],
[0.5, 0.25, 0, 0, -0.5, -0.25]
]
}