feat(on-objects): Steering corrector implementation
This commit is contained in:
		@@ -16,6 +16,10 @@ const (
 | 
				
			|||||||
func main() {
 | 
					func main() {
 | 
				
			||||||
	var mqttBroker, username, password, clientId string
 | 
						var mqttBroker, username, password, clientId string
 | 
				
			||||||
	var steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string
 | 
						var steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string
 | 
				
			||||||
 | 
						var imgWidth, imgHeight int
 | 
				
			||||||
 | 
						var enableObjectsCorrection, enableObjectsCorrectionOnUserMode bool
 | 
				
			||||||
 | 
						var gridMapConfig, objectsMoveFactorsConfig string
 | 
				
			||||||
 | 
						var deltaMiddle float64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	mqttQos := cli.InitIntFlag("MQTT_QOS", 0)
 | 
						mqttQos := cli.InitIntFlag("MQTT_QOS", 0)
 | 
				
			||||||
	_, mqttRetain := os.LookupEnv("MQTT_RETAIN")
 | 
						_, mqttRetain := os.LookupEnv("MQTT_RETAIN")
 | 
				
			||||||
@@ -27,7 +31,13 @@ func main() {
 | 
				
			|||||||
	flag.StringVar(&tfSteeringTopic, "mqtt-topic-tf-steering", os.Getenv("MQTT_TOPIC_TF_STEERING"), "Mqtt topic that contains tenorflow steering value, use MQTT_TOPIC_TF_STEERING if args not set")
 | 
						flag.StringVar(&tfSteeringTopic, "mqtt-topic-tf-steering", os.Getenv("MQTT_TOPIC_TF_STEERING"), "Mqtt topic that contains tenorflow steering value, use MQTT_TOPIC_TF_STEERING if args not set")
 | 
				
			||||||
	flag.StringVar(&driveModeTopic, "mqtt-topic-drive-mode", os.Getenv("MQTT_TOPIC_DRIVE_MODE"), "Mqtt topic that contains DriveMode value, use MQTT_TOPIC_DRIVE_MODE if args not set")
 | 
						flag.StringVar(&driveModeTopic, "mqtt-topic-drive-mode", os.Getenv("MQTT_TOPIC_DRIVE_MODE"), "Mqtt topic that contains DriveMode value, use MQTT_TOPIC_DRIVE_MODE if args not set")
 | 
				
			||||||
	flag.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains Objects from object detection value, use MQTT_TOPIC_OBJECTS if args not set")
 | 
						flag.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains Objects from object detection value, use MQTT_TOPIC_OBJECTS if args not set")
 | 
				
			||||||
 | 
						flag.IntVar(&imgWidth, "image-width", 160, "Video pixels width")
 | 
				
			||||||
 | 
						flag.IntVar(&imgHeight, "image-height", 128, "Video pixels height")
 | 
				
			||||||
 | 
						flag.BoolVar(&enableObjectsCorrection, "enable-objects-correction", false, "Adjust steering to avoid objects")
 | 
				
			||||||
 | 
						flag.BoolVar(&enableObjectsCorrectionOnUserMode, "enable-objects-correction-user", false, "Adjust steering to avoid objects on user mode driving")
 | 
				
			||||||
 | 
						flag.StringVar(&gridMapConfig, "grid-map-config", "", "Json file path to configure grid object correction")
 | 
				
			||||||
 | 
						flag.StringVar(&objectsMoveFactorsConfig, "objects-move-factors-config", "", "Json file path to configure objects move corrections")
 | 
				
			||||||
 | 
						flag.Float64Var(&deltaMiddle, "delta-middle", 0.1, "Half Percent zone to interpret as straight")
 | 
				
			||||||
	logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
 | 
						logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	flag.Parse()
 | 
						flag.Parse()
 | 
				
			||||||
@@ -50,13 +60,36 @@ func main() {
 | 
				
			|||||||
	}()
 | 
						}()
 | 
				
			||||||
	zap.ReplaceGlobals(lgr)
 | 
						zap.ReplaceGlobals(lgr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						zap.S().Infof("steering topic                  : %s", steeringTopic)
 | 
				
			||||||
 | 
						zap.S().Infof("rc topic                        : %s", rcSteeringTopic)
 | 
				
			||||||
 | 
						zap.S().Infof("tflite steering topic           : %s", tfSteeringTopic)
 | 
				
			||||||
 | 
						zap.S().Infof("drive mode topic                : %s", driveModeTopic)
 | 
				
			||||||
 | 
						zap.S().Infof("objects topic                   : %s", objectsTopic)
 | 
				
			||||||
 | 
						zap.S().Infof("objects correction enabled      : %v", enableObjectsCorrection)
 | 
				
			||||||
 | 
						zap.S().Infof("objects correction on user mode : %v", enableObjectsCorrectionOnUserMode)
 | 
				
			||||||
 | 
						zap.S().Infof("grid map file config            : %v", gridMapConfig)
 | 
				
			||||||
 | 
						zap.S().Infof("objects move factors grid config: %v", objectsMoveFactorsConfig)
 | 
				
			||||||
 | 
						zap.S().Infof("image width x height            : %v x %v", imgWidth, imgHeight)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	client, err := cli.Connect(mqttBroker, username, password, clientId)
 | 
						client, err := cli.Connect(mqttBroker, username, password, clientId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Fatalf("unable to connect to mqtt bus: %v", err)
 | 
							log.Fatalf("unable to connect to mqtt bus: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer client.Disconnect(50)
 | 
						defer client.Disconnect(50)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p := steering.NewController(client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic)
 | 
						p := steering.NewController(
 | 
				
			||||||
 | 
							client,
 | 
				
			||||||
 | 
							steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
 | 
				
			||||||
 | 
							steering.WithCorrector(
 | 
				
			||||||
 | 
								steering.NewGridCorrector(
 | 
				
			||||||
 | 
									steering.WidthDeltaMiddle(deltaMiddle),
 | 
				
			||||||
 | 
									steering.WithGridMap(gridMapConfig),
 | 
				
			||||||
 | 
									steering.WithObjectMoveFactors(objectsMoveFactorsConfig),
 | 
				
			||||||
 | 
									steering.WithImageSize(imgWidth, imgHeight),
 | 
				
			||||||
 | 
								),
 | 
				
			||||||
 | 
							),
 | 
				
			||||||
 | 
							steering.WithObjectsCorrectionEnabled(enableObjectsCorrection, enableObjectsCorrectionOnUserMode),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
	defer p.Stop()
 | 
						defer p.Stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	cli.HandleExit(p)
 | 
						cli.HandleExit(p)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
package steering
 | 
					package steering
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/cyrilix/robocar-protobuf/go/events"
 | 
				
			||||||
	"gocv.io/x/gocv"
 | 
						"gocv.io/x/gocv"
 | 
				
			||||||
	"image"
 | 
						"image"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -14,3 +15,43 @@ func GroupBBoxes(bboxes []image.Rectangle) []image.Rectangle {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return gocv.GroupRectangles(bboxes, 1, 0.2)
 | 
						return gocv.GroupRectangles(bboxes, 1, 0.2)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					func GroupObjects(objects []*events.Object, imgWidth, imgHeight int) []*events.Object {
 | 
				
			||||||
 | 
						if len(objects) == 0 {
 | 
				
			||||||
 | 
							return []*events.Object{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(objects) == 1 {
 | 
				
			||||||
 | 
							return []*events.Object{objects[0]}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rectangles := make([]image.Rectangle, 0, len(objects))
 | 
				
			||||||
 | 
						for _, o := range objects {
 | 
				
			||||||
 | 
							rectangles = append(rectangles, *objectToRect(o, imgWidth, imgHeight))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						grp := gocv.GroupRectangles(rectangles, 1, 0.2)
 | 
				
			||||||
 | 
						result := make([]*events.Object, 0, len(grp))
 | 
				
			||||||
 | 
						for _, r := range grp {
 | 
				
			||||||
 | 
							result = append(result, rectToObject(&r, imgWidth, imgHeight))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return result
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func objectToRect(object *events.Object, imgWidth, imgHeight int) *image.Rectangle {
 | 
				
			||||||
 | 
						r := image.Rect(
 | 
				
			||||||
 | 
							int(object.Left*float32(imgWidth)),
 | 
				
			||||||
 | 
							int(object.Top*float32(imgHeight)),
 | 
				
			||||||
 | 
							int(object.Right*float32(imgWidth)),
 | 
				
			||||||
 | 
							int(object.Bottom*float32(imgHeight)),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						return &r
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func rectToObject(r *image.Rectangle, imgWidth, imgHeight int) *events.Object {
 | 
				
			||||||
 | 
						return &events.Object{
 | 
				
			||||||
 | 
							Type:       events.TypeObject_ANY,
 | 
				
			||||||
 | 
							Left:       float32(r.Min.X) / float32(imgWidth),
 | 
				
			||||||
 | 
							Top:        float32(r.Min.Y) / float32(imgHeight),
 | 
				
			||||||
 | 
							Right:      float32(r.Max.X) / float32(imgWidth),
 | 
				
			||||||
 | 
							Bottom:     float32(r.Max.Y) / float32(imgHeight),
 | 
				
			||||||
 | 
							Confidence: -1,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@ package steering
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/cyrilix/robocar-protobuf/go/events"
 | 
				
			||||||
	"go.uber.org/zap"
 | 
						"go.uber.org/zap"
 | 
				
			||||||
	"gocv.io/x/gocv"
 | 
						"gocv.io/x/gocv"
 | 
				
			||||||
	"image"
 | 
						"image"
 | 
				
			||||||
@@ -25,6 +26,7 @@ type BBox struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	dataBBoxes  map[string][]image.Rectangle
 | 
						dataBBoxes  map[string][]image.Rectangle
 | 
				
			||||||
 | 
						dataObjects map[string][]*events.Object
 | 
				
			||||||
	dataImages  map[string]*gocv.Mat
 | 
						dataImages  map[string]*gocv.Mat
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -32,6 +34,7 @@ func init() {
 | 
				
			|||||||
	// TODO: empty img without bbox
 | 
						// TODO: empty img without bbox
 | 
				
			||||||
	dataNames := []string{"01", "02", "03", "04"}
 | 
						dataNames := []string{"01", "02", "03", "04"}
 | 
				
			||||||
	dataBBoxes = make(map[string][]image.Rectangle, len(dataNames))
 | 
						dataBBoxes = make(map[string][]image.Rectangle, len(dataNames))
 | 
				
			||||||
 | 
						dataObjects = make(map[string][]*events.Object, len(dataNames))
 | 
				
			||||||
	dataImages = make(map[string]*gocv.Mat, len(dataNames))
 | 
						dataImages = make(map[string]*gocv.Mat, len(dataNames))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, dataName := range dataNames {
 | 
						for _, dataName := range dataNames {
 | 
				
			||||||
@@ -40,6 +43,7 @@ func init() {
 | 
				
			|||||||
			zap.S().Panicf("unable to load data test: %v", err)
 | 
								zap.S().Panicf("unable to load data test: %v", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		dataBBoxes[dataName] = bboxesToRectangles(bb, img.Cols(), img.Rows())
 | 
							dataBBoxes[dataName] = bboxesToRectangles(bb, img.Cols(), img.Rows())
 | 
				
			||||||
 | 
							dataObjects[dataName] = bboxesToObjects(bb)
 | 
				
			||||||
		dataImages[dataName] = img
 | 
							dataImages[dataName] = img
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -52,6 +56,20 @@ func bboxesToRectangles(bboxes []BBox, imgWidth, imgHeiht int) []image.Rectangle
 | 
				
			|||||||
	return rects
 | 
						return rects
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func bboxesToObjects(bboxes []BBox) []*events.Object {
 | 
				
			||||||
 | 
						objects := make([]*events.Object, 0, len(bboxes))
 | 
				
			||||||
 | 
						for _, bb := range bboxes {
 | 
				
			||||||
 | 
							objects = append(objects, &events.Object{
 | 
				
			||||||
 | 
								Type:       events.TypeObject_ANY,
 | 
				
			||||||
 | 
								Left:       bb.Left,
 | 
				
			||||||
 | 
								Top:        bb.Top,
 | 
				
			||||||
 | 
								Right:      bb.Right,
 | 
				
			||||||
 | 
								Bottom:     bb.Bottom,
 | 
				
			||||||
 | 
								Confidence: bb.Confidence,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return objects
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
func (bb *BBox) toRect(imgWidth, imgHeight int) image.Rectangle {
 | 
					func (bb *BBox) toRect(imgWidth, imgHeight int) image.Rectangle {
 | 
				
			||||||
	return image.Rect(
 | 
						return image.Rect(
 | 
				
			||||||
		int(bb.Left*float32(imgWidth)),
 | 
							int(bb.Left*float32(imgWidth)),
 | 
				
			||||||
@@ -228,3 +246,59 @@ func TestGroupBBoxes(t *testing.T) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					func TestGroupObjects(t *testing.T) {
 | 
				
			||||||
 | 
						type args struct {
 | 
				
			||||||
 | 
							dataName string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name string
 | 
				
			||||||
 | 
							args args
 | 
				
			||||||
 | 
							want []*events.Object
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "groupbbox-01",
 | 
				
			||||||
 | 
								args: args{
 | 
				
			||||||
 | 
									dataName: "01",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []*events.Object{
 | 
				
			||||||
 | 
									{Left: 0.26660156, Top: 0.1706543, Right: 0.5258789, Bottom: 0.47583008, Confidence: 0.4482422},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "groupbbox-02",
 | 
				
			||||||
 | 
								args: args{
 | 
				
			||||||
 | 
									dataName: "02",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []*events.Object{
 | 
				
			||||||
 | 
									{Left: 0.15625, Top: 0.108333334, Right: 0.6875, Bottom: 0.6666667, Confidence: -1},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "groupbbox-03",
 | 
				
			||||||
 | 
								args: args{
 | 
				
			||||||
 | 
									dataName: "03",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []*events.Object{
 | 
				
			||||||
 | 
									{Top: 0.14166667, Right: 0.21875, Bottom: 0.64166665, Confidence: -1},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "groupbbox-04",
 | 
				
			||||||
 | 
								args: args{
 | 
				
			||||||
 | 
									dataName: "04",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []*events.Object{
 | 
				
			||||||
 | 
									{Left: 0.80625, Top: 0.083333336, Right: 0.99375, Bottom: 0.53333336, Confidence: -1},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								img := dataImages[tt.args.dataName]
 | 
				
			||||||
 | 
								got := GroupObjects(dataObjects[tt.args.dataName], img.Cols(), img.Rows())
 | 
				
			||||||
 | 
								if !reflect.DeepEqual(got, tt.want) {
 | 
				
			||||||
 | 
									t.Errorf("GroupObjects() = %v, want %v", got, tt.want)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
package steering
 | 
					package steering
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"github.com/cyrilix/robocar-base/service"
 | 
						"github.com/cyrilix/robocar-base/service"
 | 
				
			||||||
	"github.com/cyrilix/robocar-protobuf/go/events"
 | 
						"github.com/cyrilix/robocar-protobuf/go/events"
 | 
				
			||||||
	mqtt "github.com/eclipse/paho.mqtt.golang"
 | 
						mqtt "github.com/eclipse/paho.mqtt.golang"
 | 
				
			||||||
@@ -9,8 +10,48 @@ import (
 | 
				
			|||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string) *Controller {
 | 
					var (
 | 
				
			||||||
	return &Controller{
 | 
						defaultGridMap = GridMap{
 | 
				
			||||||
 | 
							DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
 | 
				
			||||||
 | 
							SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
 | 
				
			||||||
 | 
							Data: [][]float64{
 | 
				
			||||||
 | 
								{0., 0., 0., 0., 0., 0.},
 | 
				
			||||||
 | 
								{0., 0., 0., 0., 0., 0.},
 | 
				
			||||||
 | 
								{0., 0., 0.25, -0.25, 0., 0.},
 | 
				
			||||||
 | 
								{0., 0.25, 0.5, -0.5, -0.25, 0.},
 | 
				
			||||||
 | 
								{0.25, 0.5, 1, -1, -0.5, -0.25},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defaultObjectFactors = GridMap{
 | 
				
			||||||
 | 
							DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
 | 
				
			||||||
 | 
							SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
 | 
				
			||||||
 | 
							Data: [][]float64{
 | 
				
			||||||
 | 
								{0., 0., 0., 0., 0., 0.},
 | 
				
			||||||
 | 
								{0., 0., 0., 0., 0., 0.},
 | 
				
			||||||
 | 
								{0., 0., 0., 0., 0., 0.},
 | 
				
			||||||
 | 
								{0., 0.25, 0, 0, -0.25, 0.},
 | 
				
			||||||
 | 
								{0.5, 0.25, 0, 0, -0.5, -0.25},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Option func(c *Controller)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithCorrector(c Corrector) Option {
 | 
				
			||||||
 | 
						return func(ctrl *Controller) {
 | 
				
			||||||
 | 
							ctrl.corrector = c
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithObjectsCorrectionEnabled(enabled, enabledOnUserDrive bool) Option {
 | 
				
			||||||
 | 
						return func(ctrl *Controller) {
 | 
				
			||||||
 | 
							ctrl.enableCorrection = enabled
 | 
				
			||||||
 | 
							ctrl.enableCorrectionOnUser = enabledOnUserDrive
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic string, options ...Option) *Controller {
 | 
				
			||||||
 | 
						c := &Controller{
 | 
				
			||||||
		client:          client,
 | 
							client:          client,
 | 
				
			||||||
		steeringTopic:   steeringTopic,
 | 
							steeringTopic:   steeringTopic,
 | 
				
			||||||
		driveModeTopic:  driveModeTopic,
 | 
							driveModeTopic:  driveModeTopic,
 | 
				
			||||||
@@ -18,8 +59,12 @@ func NewController(client mqtt.Client, steeringTopic, driveModeTopic, rcSteering
 | 
				
			|||||||
		tfSteeringTopic: tfSteeringTopic,
 | 
							tfSteeringTopic: tfSteeringTopic,
 | 
				
			||||||
		objectsTopic:    objectsTopic,
 | 
							objectsTopic:    objectsTopic,
 | 
				
			||||||
		driveMode:       events.DriveMode_USER,
 | 
							driveMode:       events.DriveMode_USER,
 | 
				
			||||||
 | 
							corrector:       NewGridCorrector(),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						for _, o := range options {
 | 
				
			||||||
 | 
							o(c)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return c
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Controller struct {
 | 
					type Controller struct {
 | 
				
			||||||
@@ -35,26 +80,28 @@ type Controller struct {
 | 
				
			|||||||
	muObjects sync.RWMutex
 | 
						muObjects sync.RWMutex
 | 
				
			||||||
	objects   []*events.Object
 | 
						objects   []*events.Object
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	debug bool
 | 
						corrector              Corrector
 | 
				
			||||||
 | 
						enableCorrection       bool
 | 
				
			||||||
 | 
						enableCorrectionOnUser bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *Controller) Start() error {
 | 
					func (c *Controller) Start() error {
 | 
				
			||||||
	if err := registerCallbacks(p); err != nil {
 | 
						if err := registerCallbacks(c); err != nil {
 | 
				
			||||||
		zap.S().Errorf("unable to rgeister callbacks: %v", err)
 | 
							zap.S().Errorf("unable to register callbacks: %v", err)
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p.cancel = make(chan interface{})
 | 
						c.cancel = make(chan interface{})
 | 
				
			||||||
	<-p.cancel
 | 
						<-c.cancel
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *Controller) Stop() {
 | 
					func (c *Controller) Stop() {
 | 
				
			||||||
	close(p.cancel)
 | 
						close(c.cancel)
 | 
				
			||||||
	service.StopService("throttle", p.client, p.driveModeTopic, p.rcSteeringTopic, p.tfSteeringTopic)
 | 
						service.StopService("throttle", c.client, c.driveModeTopic, c.rcSteeringTopic, c.tfSteeringTopic)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) {
 | 
					func (c *Controller) onObjects(_ mqtt.Client, message mqtt.Message) {
 | 
				
			||||||
	var msg events.ObjectsMessage
 | 
						var msg events.ObjectsMessage
 | 
				
			||||||
	err := proto.Unmarshal(message.Payload(), &msg)
 | 
						err := proto.Unmarshal(message.Payload(), &msg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -62,12 +109,13 @@ func (p *Controller) onObjects(_ mqtt.Client, message mqtt.Message) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p.muObjects.Lock()
 | 
						c.muObjects.Lock()
 | 
				
			||||||
	defer p.muObjects.Unlock()
 | 
						defer c.muObjects.Unlock()
 | 
				
			||||||
	p.objects = msg.GetObjects()
 | 
						c.objects = msg.GetObjects()
 | 
				
			||||||
 | 
						zap.S().Debugf("%v object(s) received", len(c.objects))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) {
 | 
					func (c *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) {
 | 
				
			||||||
	var msg events.DriveModeMessage
 | 
						var msg events.DriveModeMessage
 | 
				
			||||||
	err := proto.Unmarshal(message.Payload(), &msg)
 | 
						err := proto.Unmarshal(message.Payload(), &msg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -75,46 +123,89 @@ func (p *Controller) onDriveMode(_ mqtt.Client, message mqtt.Message) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p.muDriveMode.Lock()
 | 
						c.muDriveMode.Lock()
 | 
				
			||||||
	defer p.muDriveMode.Unlock()
 | 
						defer c.muDriveMode.Unlock()
 | 
				
			||||||
	p.driveMode = msg.GetDriveMode()
 | 
						c.driveMode = msg.GetDriveMode()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) {
 | 
					func (c *Controller) onRCSteering(_ mqtt.Client, message mqtt.Message) {
 | 
				
			||||||
	p.muDriveMode.RLock()
 | 
						c.muDriveMode.RLock()
 | 
				
			||||||
	defer p.muDriveMode.RUnlock()
 | 
						defer c.muDriveMode.RUnlock()
 | 
				
			||||||
	if p.debug {
 | 
					
 | 
				
			||||||
		var evt events.SteeringMessage
 | 
						if c.driveMode != events.DriveMode_USER {
 | 
				
			||||||
		err := proto.Unmarshal(message.Payload(), &evt)
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						payload := message.Payload()
 | 
				
			||||||
 | 
						evt := &events.SteeringMessage{}
 | 
				
			||||||
 | 
						err := proto.Unmarshal(payload, evt)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		zap.S().Debugf("unable to unmarshal rc event: %v", err)
 | 
							zap.S().Debugf("unable to unmarshal rc event: %v", err)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering())
 | 
							zap.S().Debugf("receive steering message from radio command: %0.00f", evt.GetSteering())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	}
 | 
					
 | 
				
			||||||
	if p.driveMode == events.DriveMode_USER {
 | 
						if c.enableCorrection && c.enableCorrectionOnUser {
 | 
				
			||||||
		// Republish same content
 | 
							payload, err = c.adjustSteering(evt)
 | 
				
			||||||
		payload := message.Payload()
 | 
					 | 
				
			||||||
		publish(p.client, p.steeringTopic, &payload)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
func (p *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) {
 | 
					 | 
				
			||||||
	p.muDriveMode.RLock()
 | 
					 | 
				
			||||||
	defer p.muDriveMode.RUnlock()
 | 
					 | 
				
			||||||
	if p.debug {
 | 
					 | 
				
			||||||
		var evt events.SteeringMessage
 | 
					 | 
				
			||||||
		err := proto.Unmarshal(message.Payload(), &evt)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			zap.S().Debugf("unable to unmarshal tensorflow event: %v", err)
 | 
								zap.S().Errorf("unable to adjust steering, skip message: %v", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						publish(c.client, c.steeringTopic, &payload)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *Controller) onTFSteering(_ mqtt.Client, message mqtt.Message) {
 | 
				
			||||||
 | 
						c.muDriveMode.RLock()
 | 
				
			||||||
 | 
						defer c.muDriveMode.RUnlock()
 | 
				
			||||||
 | 
						if c.driveMode != events.DriveMode_PILOT {
 | 
				
			||||||
 | 
							// User mode, skip new message
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						evt := &events.SteeringMessage{}
 | 
				
			||||||
 | 
						err := proto.Unmarshal(message.Payload(), evt)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							zap.S().Errorf("unable to unmarshal tensorflow event: %v", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering())
 | 
							zap.S().Debugf("receive steering message from tensorflow: %0.00f", evt.GetSteering())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	}
 | 
					
 | 
				
			||||||
	if p.driveMode == events.DriveMode_PILOT {
 | 
					 | 
				
			||||||
		// Republish same content
 | 
					 | 
				
			||||||
	payload := message.Payload()
 | 
						payload := message.Payload()
 | 
				
			||||||
		publish(p.client, p.steeringTopic, &payload)
 | 
						if c.enableCorrection {
 | 
				
			||||||
 | 
							payload, err = c.adjustSteering(evt)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								zap.S().Errorf("unable to adjust steering, skip message: %v", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						publish(c.client, c.steeringTopic, &payload)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *Controller) adjustSteering(evt *events.SteeringMessage) ([]byte, error) {
 | 
				
			||||||
 | 
						steering := float64(evt.GetSteering())
 | 
				
			||||||
 | 
						steering = c.corrector.AdjustFromObjectPosition(steering, c.Objects())
 | 
				
			||||||
 | 
						zap.S().Debugf("adjust steering to avoid objects: %v -> %v", evt.GetSteering(), steering)
 | 
				
			||||||
 | 
						evt.Steering = float32(steering)
 | 
				
			||||||
 | 
						// override payload content
 | 
				
			||||||
 | 
						payload, err := proto.Marshal(evt)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("unable to marshal steering message with new value, skip message: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return payload, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *Controller) Objects() []*events.Object {
 | 
				
			||||||
 | 
						c.muObjects.RLock()
 | 
				
			||||||
 | 
						defer c.muObjects.RUnlock()
 | 
				
			||||||
 | 
						res := make([]*events.Object, 0, len(c.objects))
 | 
				
			||||||
 | 
						for _, o := range c.objects {
 | 
				
			||||||
 | 
							res = append(res, o)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						zap.S().Debugf("copy object from %v to %v", c.objects, res)
 | 
				
			||||||
 | 
						return res
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var registerCallbacks = func(p *Controller) error {
 | 
					var registerCallbacks = func(p *Controller) error {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -123,3 +123,198 @@ func TestDefaultSteering(t *testing.T) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type StaticCorrector struct {
 | 
				
			||||||
 | 
						delta float64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *StaticCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
 | 
				
			||||||
 | 
						return s.delta
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestController_Start(t *testing.T) {
 | 
				
			||||||
 | 
						oldRegister := registerCallbacks
 | 
				
			||||||
 | 
						oldPublish := publish
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							registerCallbacks = oldRegister
 | 
				
			||||||
 | 
							publish = oldPublish
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						registerCallbacks = func(p *Controller) error {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						waitPublish := sync.WaitGroup{}
 | 
				
			||||||
 | 
						var muEventsPublished sync.Mutex
 | 
				
			||||||
 | 
						eventsPublished := make(map[string][]byte)
 | 
				
			||||||
 | 
						publish = func(client mqtt.Client, topic string, payload *[]byte) {
 | 
				
			||||||
 | 
							muEventsPublished.Lock()
 | 
				
			||||||
 | 
							defer muEventsPublished.Unlock()
 | 
				
			||||||
 | 
							eventsPublished[topic] = *payload
 | 
				
			||||||
 | 
							waitPublish.Done()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						steeringTopic := "topic/steering"
 | 
				
			||||||
 | 
						driveModeTopic := "topic/driveMode"
 | 
				
			||||||
 | 
						rcSteeringTopic := "topic/rcSteering"
 | 
				
			||||||
 | 
						tfSteeringTopic := "topic/tfSteering"
 | 
				
			||||||
 | 
						objectsTopic := "topic/objects"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type fields struct {
 | 
				
			||||||
 | 
							driveMode              events.DriveMode
 | 
				
			||||||
 | 
							enableCorrection       bool
 | 
				
			||||||
 | 
							enableCorrectionOnUser bool
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						type msgEvents struct {
 | 
				
			||||||
 | 
							driveMode        events.DriveModeMessage
 | 
				
			||||||
 | 
							rcSteering       events.SteeringMessage
 | 
				
			||||||
 | 
							tfSteering       events.SteeringMessage
 | 
				
			||||||
 | 
							expectedSteering events.SteeringMessage
 | 
				
			||||||
 | 
							objects          events.ObjectsMessage
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name               string
 | 
				
			||||||
 | 
							fields             fields
 | 
				
			||||||
 | 
							msgEvents          msgEvents
 | 
				
			||||||
 | 
							correctionOnObject float64
 | 
				
			||||||
 | 
							want               events.SteeringMessage
 | 
				
			||||||
 | 
							wantErr            bool
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "On user drive mode, none correction",
 | 
				
			||||||
 | 
								fields: fields{
 | 
				
			||||||
 | 
									driveMode:              events.DriveMode_USER,
 | 
				
			||||||
 | 
									enableCorrection:       false,
 | 
				
			||||||
 | 
									enableCorrectionOnUser: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								msgEvents: msgEvents{
 | 
				
			||||||
 | 
									driveMode:  events.DriveModeMessage{DriveMode: events.DriveMode_USER},
 | 
				
			||||||
 | 
									rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
				
			||||||
 | 
									tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
 | 
				
			||||||
 | 
									objects:    events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								correctionOnObject: 0.5,
 | 
				
			||||||
 | 
								// Get rc value without correction
 | 
				
			||||||
 | 
								want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "On pilot drive mode, none correction",
 | 
				
			||||||
 | 
								fields: fields{
 | 
				
			||||||
 | 
									driveMode:              events.DriveMode_PILOT,
 | 
				
			||||||
 | 
									enableCorrection:       false,
 | 
				
			||||||
 | 
									enableCorrectionOnUser: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								msgEvents: msgEvents{
 | 
				
			||||||
 | 
									driveMode:  events.DriveModeMessage{DriveMode: events.DriveMode_PILOT},
 | 
				
			||||||
 | 
									rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
				
			||||||
 | 
									tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
 | 
				
			||||||
 | 
									objects:    events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								correctionOnObject: 0.5,
 | 
				
			||||||
 | 
								// Get rc value without correction
 | 
				
			||||||
 | 
								want: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "On pilot drive mode, correction enabled",
 | 
				
			||||||
 | 
								fields: fields{
 | 
				
			||||||
 | 
									driveMode:              events.DriveMode_PILOT,
 | 
				
			||||||
 | 
									enableCorrection:       true,
 | 
				
			||||||
 | 
									enableCorrectionOnUser: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								msgEvents: msgEvents{
 | 
				
			||||||
 | 
									driveMode:  events.DriveModeMessage{DriveMode: events.DriveMode_PILOT},
 | 
				
			||||||
 | 
									rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
				
			||||||
 | 
									tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
 | 
				
			||||||
 | 
									objects:    events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								correctionOnObject: 0.5,
 | 
				
			||||||
 | 
								// Get rc value without correction
 | 
				
			||||||
 | 
								want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "On pilot drive mode, all corrections enabled",
 | 
				
			||||||
 | 
								fields: fields{
 | 
				
			||||||
 | 
									driveMode:              events.DriveMode_PILOT,
 | 
				
			||||||
 | 
									enableCorrection:       true,
 | 
				
			||||||
 | 
									enableCorrectionOnUser: true,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								msgEvents: msgEvents{
 | 
				
			||||||
 | 
									driveMode:  events.DriveModeMessage{DriveMode: events.DriveMode_PILOT},
 | 
				
			||||||
 | 
									rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
				
			||||||
 | 
									tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
 | 
				
			||||||
 | 
									objects:    events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								correctionOnObject: 0.5,
 | 
				
			||||||
 | 
								// Get rc value without correction
 | 
				
			||||||
 | 
								want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "On user drive mode, only correction PILOT enabled",
 | 
				
			||||||
 | 
								fields: fields{
 | 
				
			||||||
 | 
									driveMode:              events.DriveMode_PILOT,
 | 
				
			||||||
 | 
									enableCorrection:       true,
 | 
				
			||||||
 | 
									enableCorrectionOnUser: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								msgEvents: msgEvents{
 | 
				
			||||||
 | 
									driveMode:  events.DriveModeMessage{DriveMode: events.DriveMode_USER},
 | 
				
			||||||
 | 
									rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
				
			||||||
 | 
									tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
 | 
				
			||||||
 | 
									objects:    events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								correctionOnObject: 0.5,
 | 
				
			||||||
 | 
								// Get rc value without correction
 | 
				
			||||||
 | 
								want: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "On user drive mode, all corrections enabled",
 | 
				
			||||||
 | 
								fields: fields{
 | 
				
			||||||
 | 
									driveMode:              events.DriveMode_USER,
 | 
				
			||||||
 | 
									enableCorrection:       true,
 | 
				
			||||||
 | 
									enableCorrectionOnUser: true,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								msgEvents: msgEvents{
 | 
				
			||||||
 | 
									driveMode:  events.DriveModeMessage{DriveMode: events.DriveMode_USER},
 | 
				
			||||||
 | 
									rcSteering: events.SteeringMessage{Steering: 0.3, Confidence: 1.0},
 | 
				
			||||||
 | 
									tfSteering: events.SteeringMessage{Steering: 0.4, Confidence: 1.0},
 | 
				
			||||||
 | 
									objects:    events.ObjectsMessage{Objects: []*events.Object{&objectOnMiddleNear}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								correctionOnObject: 0.5,
 | 
				
			||||||
 | 
								// Get rc value without correction
 | 
				
			||||||
 | 
								want: events.SteeringMessage{Steering: 0.5, Confidence: 1.0},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								c := NewController(nil,
 | 
				
			||||||
 | 
									steeringTopic, driveModeTopic, rcSteeringTopic, tfSteeringTopic, objectsTopic,
 | 
				
			||||||
 | 
									WithObjectsCorrectionEnabled(tt.fields.enableCorrection, tt.fields.enableCorrectionOnUser),
 | 
				
			||||||
 | 
									WithCorrector(&StaticCorrector{delta: tt.correctionOnObject}),
 | 
				
			||||||
 | 
								)
 | 
				
			||||||
 | 
								go c.Start()
 | 
				
			||||||
 | 
								time.Sleep(1 * time.Millisecond)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Publish events and wait generation of new steering message
 | 
				
			||||||
 | 
								waitPublish.Add(1)
 | 
				
			||||||
 | 
								c.onDriveMode(nil, testtools.NewFakeMessageFromProtobuf(driveModeTopic, &tt.msgEvents.driveMode))
 | 
				
			||||||
 | 
								c.onRCSteering(nil, testtools.NewFakeMessageFromProtobuf(rcSteeringTopic, &tt.msgEvents.rcSteering))
 | 
				
			||||||
 | 
								c.onTFSteering(nil, testtools.NewFakeMessageFromProtobuf(tfSteeringTopic, &tt.msgEvents.tfSteering))
 | 
				
			||||||
 | 
								c.onObjects(nil, testtools.NewFakeMessageFromProtobuf(objectsTopic, &tt.msgEvents.objects))
 | 
				
			||||||
 | 
								waitPublish.Wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								var msg events.SteeringMessage
 | 
				
			||||||
 | 
								muEventsPublished.Lock()
 | 
				
			||||||
 | 
								err := proto.Unmarshal(eventsPublished[steeringTopic], &msg)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Errorf("unable to unmarshall response: %v", err)
 | 
				
			||||||
 | 
									t.Fail()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								muEventsPublished.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if msg.GetSteering() != tt.want.GetSteering() {
 | 
				
			||||||
 | 
									t.Errorf("bad msg value for mode %v: %v, wants %v", c.driveMode.String(), msg.GetSteering(), tt.want.GetSteering())
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,8 +8,90 @@ import (
 | 
				
			|||||||
	"os"
 | 
						"os"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Corrector struct {
 | 
					type Corrector interface {
 | 
				
			||||||
 | 
						AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					type OptionCorrector func(c *GridCorrector)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithGridMap(configPath string) OptionCorrector {
 | 
				
			||||||
 | 
						var gm *GridMap
 | 
				
			||||||
 | 
						if configPath == "" {
 | 
				
			||||||
 | 
							zap.S().Warnf("no configuration defined for grid map, use default")
 | 
				
			||||||
 | 
							gm = &defaultGridMap
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							var err error
 | 
				
			||||||
 | 
							gm, err = loadConfig(configPath)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								zap.S().Panicf("unable to load grid-map config from file '%v': %w", configPath, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return func(c *GridCorrector) {
 | 
				
			||||||
 | 
							c.gridMap = gm
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithObjectMoveFactors(configPath string) OptionCorrector {
 | 
				
			||||||
 | 
						var omf *GridMap
 | 
				
			||||||
 | 
						if configPath == "" {
 | 
				
			||||||
 | 
							zap.S().Warnf("no configuration defined for objects move factors, use default")
 | 
				
			||||||
 | 
							omf = &defaultObjectFactors
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							var err error
 | 
				
			||||||
 | 
							omf, err = loadConfig(configPath)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								zap.S().Panicf("unable to load objects move factors config from file '%v': %w", configPath, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return func(c *GridCorrector) {
 | 
				
			||||||
 | 
							c.objectMoveFactors = omf
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func loadConfig(configPath string) (*GridMap, error) {
 | 
				
			||||||
 | 
						content, err := os.ReadFile(configPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("unable to load grid-map config from file '%v': %w", configPath, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var gm GridMap
 | 
				
			||||||
 | 
						err = json.Unmarshal(content, &gm)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("unable to unmarshal json config '%s': %w", configPath, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &gm, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithImageSize(width, height int) OptionCorrector {
 | 
				
			||||||
 | 
						return func(c *GridCorrector) {
 | 
				
			||||||
 | 
							c.imgWidth = width
 | 
				
			||||||
 | 
							c.imgHeight = height
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WidthDeltaMiddle(d float64) OptionCorrector {
 | 
				
			||||||
 | 
						return func(c *GridCorrector) {
 | 
				
			||||||
 | 
							c.deltaMiddle = d
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func NewGridCorrector(options ...OptionCorrector) *GridCorrector {
 | 
				
			||||||
 | 
						c := &GridCorrector{
 | 
				
			||||||
 | 
							gridMap:           &defaultGridMap,
 | 
				
			||||||
 | 
							objectMoveFactors: &defaultObjectFactors,
 | 
				
			||||||
 | 
							deltaMiddle:       0.1,
 | 
				
			||||||
 | 
							imgWidth:          160,
 | 
				
			||||||
 | 
							imgHeight:         120,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, o := range options {
 | 
				
			||||||
 | 
							o(c)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return c
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GridCorrector struct {
 | 
				
			||||||
	gridMap             *GridMap
 | 
						gridMap             *GridMap
 | 
				
			||||||
 | 
						objectMoveFactors   *GridMap
 | 
				
			||||||
 | 
						deltaMiddle         float64
 | 
				
			||||||
 | 
						imgWidth, imgHeight int
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/*
 | 
					/*
 | 
				
			||||||
@@ -39,24 +121,64 @@ AdjustFromObjectPosition modify steering value according object positions
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 3. If current steering != 0 (turn on left or right), shift right and left values proportionnaly to current steering and
 | 
					 3. If current steering != 0 (turn on left or right), shift right and left values proportionnaly to current steering and
 | 
				
			||||||
    apply 2.
 | 
					    apply 2.
 | 
				
			||||||
*/
 | 
					 | 
				
			||||||
func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
 | 
					 | 
				
			||||||
	// TODO, group rectangle
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    :  -1   -0.66 -0.33   0    0.33  0.66   1
 | 
				
			||||||
 | 
					    0%  |-----|-----|-----|-----|-----|-----|
 | 
				
			||||||
 | 
					    :   |  0  |  0  |  0  |  0  |  0  |  0  |
 | 
				
			||||||
 | 
					    20% |-----|-----|-----|-----|-----|-----|
 | 
				
			||||||
 | 
					    :   | 0.2 | 0.1 |  0  |  0  |-0.1 |-0.2 |
 | 
				
			||||||
 | 
					    40% |-----|-----|-----|-----|-----|-----|
 | 
				
			||||||
 | 
					    :   | ... | ... | ... | ... | ... | ... |
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					func (c *GridCorrector) AdjustFromObjectPosition(currentSteering float64, objects []*events.Object) float64 {
 | 
				
			||||||
 | 
						zap.S().Debugf("%v objects to avoid", len(objects))
 | 
				
			||||||
	if len(objects) == 0 {
 | 
						if len(objects) == 0 {
 | 
				
			||||||
		return currentSteering
 | 
							return currentSteering
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						grpObjs := GroupObjects(objects, c.imgWidth, c.imgHeight)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// get nearest object
 | 
						// get nearest object
 | 
				
			||||||
	nearest, err := c.nearObject(objects)
 | 
						nearest, err := c.nearObject(grpObjs)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		zap.S().Warnf("unexpected error on nearest seach object, ignore objects: %v", err)
 | 
							zap.S().Warnf("unexpected error on nearest search object, ignore objects: %v", err)
 | 
				
			||||||
		return currentSteering
 | 
							return currentSteering
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if currentSteering > -0.1 && currentSteering < 0.1 {
 | 
						if currentSteering > -1*c.deltaMiddle && currentSteering < c.deltaMiddle {
 | 
				
			||||||
 | 
							// Straight
 | 
				
			||||||
 | 
							return currentSteering + c.computeDeviation(nearest)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							// Turn to right or left, so search to avoid collision with objects on the right
 | 
				
			||||||
 | 
							// Apply factor to object to move it at middle. This factor is function of distance
 | 
				
			||||||
 | 
							factor, err := c.objectMoveFactors.ValueOf(float64(nearest.Right), float64(nearest.Bottom))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								zap.S().Warnf("unable to compute factor to apply to object: %v", err)
 | 
				
			||||||
 | 
								return currentSteering
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							objMoved := events.Object{
 | 
				
			||||||
 | 
								Type:       nearest.Type,
 | 
				
			||||||
 | 
								Left:       nearest.Left + float32(currentSteering*factor),
 | 
				
			||||||
 | 
								Top:        nearest.Top,
 | 
				
			||||||
 | 
								Right:      nearest.Right + float32(currentSteering*factor),
 | 
				
			||||||
 | 
								Bottom:     nearest.Bottom,
 | 
				
			||||||
 | 
								Confidence: nearest.Confidence,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							result := currentSteering + c.computeDeviation(&objMoved)
 | 
				
			||||||
 | 
							if result < -1. {
 | 
				
			||||||
 | 
								result = -1.
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if result > 1. {
 | 
				
			||||||
 | 
								result = 1.
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return result
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *GridCorrector) computeDeviation(nearest *events.Object) float64 {
 | 
				
			||||||
	var delta float64
 | 
						var delta float64
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						zap.S().Debugf("search delta value for bottom limit: %v", nearest.Bottom)
 | 
				
			||||||
	if nearest.Left < 0 && nearest.Right < 0 {
 | 
						if nearest.Left < 0 && nearest.Right < 0 {
 | 
				
			||||||
		delta, err = c.gridMap.ValueOf(float64(nearest.Right)*2-1., float64(nearest.Bottom))
 | 
							delta, err = c.gridMap.ValueOf(float64(nearest.Right)*2-1., float64(nearest.Bottom))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -69,15 +191,11 @@ func (c *Corrector) AdjustFromObjectPosition(currentSteering float64, objects []
 | 
				
			|||||||
		zap.S().Warnf("unable to compute delta to apply to steering, skip correction: %v", err)
 | 
							zap.S().Warnf("unable to compute delta to apply to steering, skip correction: %v", err)
 | 
				
			||||||
		delta = 0
 | 
							delta = 0
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
		return currentSteering + delta
 | 
						zap.S().Debugf("new deviation computed: %v", delta)
 | 
				
			||||||
	}
 | 
						return delta
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Search if current steering is near of Right or Left
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return currentSteering
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Corrector) nearObject(objects []*events.Object) (*events.Object, error) {
 | 
					func (c *GridCorrector) nearObject(objects []*events.Object) (*events.Object, error) {
 | 
				
			||||||
	if len(objects) == 0 {
 | 
						if len(objects) == 0 {
 | 
				
			||||||
		return nil, fmt.Errorf("list objects must contain at least one object")
 | 
							return nil, fmt.Errorf("list objects must contain at least one object")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -39,19 +39,21 @@ var (
 | 
				
			|||||||
		Bottom:     0.9,
 | 
							Bottom:     0.9,
 | 
				
			||||||
		Confidence: 0.9,
 | 
							Confidence: 0.9,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
)
 | 
						objectOnRightNear = events.Object{
 | 
				
			||||||
 | 
							Type:       events.TypeObject_ANY,
 | 
				
			||||||
var (
 | 
							Left:       0.7,
 | 
				
			||||||
	defaultGridMap = GridMap{
 | 
							Top:        0.8,
 | 
				
			||||||
		DistanceSteps: []float64{0., 0.2, 0.4, 0.6, 0.8, 1.},
 | 
							Right:      0.9,
 | 
				
			||||||
		SteeringSteps: []float64{-1., -0.66, -0.33, 0., 0.33, 0.66, 1.},
 | 
							Bottom:     0.9,
 | 
				
			||||||
		Data: [][]float64{
 | 
							Confidence: 0.9,
 | 
				
			||||||
			{0., 0., 0., 0., 0., 0.},
 | 
						}
 | 
				
			||||||
			{0., 0., 0., 0., 0., 0.},
 | 
						objectOnLeftNear = events.Object{
 | 
				
			||||||
			{0., 0., 0.25, -0.25, 0., 0.},
 | 
							Type:       events.TypeObject_ANY,
 | 
				
			||||||
			{0., 0.25, 0.5, -0.5, -0.25, 0.},
 | 
							Left:       0.1,
 | 
				
			||||||
			{0.25, 0.5, 1, -1, -0.5, -0.25},
 | 
							Top:        0.8,
 | 
				
			||||||
		},
 | 
							Right:      0.3,
 | 
				
			||||||
 | 
							Bottom:     0.9,
 | 
				
			||||||
 | 
							Confidence: 0.9,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -128,7 +130,7 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) {
 | 
				
			|||||||
				currentSteering: -0.9,
 | 
									currentSteering: -0.9,
 | 
				
			||||||
				objects:         []*events.Object{&objectOnMiddleNear},
 | 
									objects:         []*events.Object{&objectOnMiddleNear},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want: -0.4,
 | 
								want: -1,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name: "run to right with 1 near object",
 | 
								name: "run to right with 1 near object",
 | 
				
			||||||
@@ -136,16 +138,28 @@ func TestCorrector_AdjustFromObjectPosition(t *testing.T) {
 | 
				
			|||||||
				currentSteering: 0.9,
 | 
									currentSteering: 0.9,
 | 
				
			||||||
				objects:         []*events.Object{&objectOnMiddleNear},
 | 
									objects:         []*events.Object{&objectOnMiddleNear},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want: 0.4,
 | 
								want: 1.,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "run to right with 1 near object on the right",
 | 
				
			||||||
 | 
								args: args{
 | 
				
			||||||
 | 
									currentSteering: 0.9,
 | 
				
			||||||
 | 
									objects:         []*events.Object{&objectOnRightNear},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: 1.,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "run to left with 1 near object on the left",
 | 
				
			||||||
 | 
								args: args{
 | 
				
			||||||
 | 
									currentSteering: -0.9,
 | 
				
			||||||
 | 
									objects:         []*events.Object{&objectOnLeftNear},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: -0.65,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Todo Object on left/right near/distant
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			c := &Corrector{
 | 
								c := NewGridCorrector()
 | 
				
			||||||
				gridMap: &defaultGridMap,
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want {
 | 
								if got := c.AdjustFromObjectPosition(tt.args.currentSteering, tt.args.objects); got != tt.want {
 | 
				
			||||||
				t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want)
 | 
									t.Errorf("AdjustFromObjectPosition() = %v, want %v", got, tt.want)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -190,7 +204,7 @@ func TestCorrector_nearObject(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			c := &Corrector{}
 | 
								c := &GridCorrector{}
 | 
				
			||||||
			got, err := c.nearObject(tt.args.objects)
 | 
								got, err := c.nearObject(tt.args.objects)
 | 
				
			||||||
			if (err != nil) != tt.wantErr {
 | 
								if (err != nil) != tt.wantErr {
 | 
				
			||||||
				t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr)
 | 
									t.Errorf("nearObject() error = %v, wantErr %v", err, tt.wantErr)
 | 
				
			||||||
@@ -399,3 +413,67 @@ func TestGridMap_ValueOf(t *testing.T) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestWithGridMap(t *testing.T) {
 | 
				
			||||||
 | 
						type args struct {
 | 
				
			||||||
 | 
							config string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name string
 | 
				
			||||||
 | 
							args args
 | 
				
			||||||
 | 
							want GridMap
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "default value",
 | 
				
			||||||
 | 
								args: args{config: ""},
 | 
				
			||||||
 | 
								want: defaultGridMap,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "load config",
 | 
				
			||||||
 | 
								args: args{config: "test_data/config.json"},
 | 
				
			||||||
 | 
								want: defaultGridMap,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								c := GridCorrector{}
 | 
				
			||||||
 | 
								got := WithGridMap(tt.args.config)
 | 
				
			||||||
 | 
								got(&c)
 | 
				
			||||||
 | 
								if !reflect.DeepEqual(*c.gridMap, tt.want) {
 | 
				
			||||||
 | 
									t.Errorf("WithGridMap() = %v, want %v", *c.gridMap, tt.want)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestWithObjectMoveFactors(t *testing.T) {
 | 
				
			||||||
 | 
						type args struct {
 | 
				
			||||||
 | 
							config string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name string
 | 
				
			||||||
 | 
							args args
 | 
				
			||||||
 | 
							want GridMap
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "default value",
 | 
				
			||||||
 | 
								args: args{config: ""},
 | 
				
			||||||
 | 
								want: defaultObjectFactors,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "load config",
 | 
				
			||||||
 | 
								args: args{config: "test_data/omf-config.json"},
 | 
				
			||||||
 | 
								want: defaultObjectFactors,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								c := GridCorrector{}
 | 
				
			||||||
 | 
								got := WithObjectMoveFactors(tt.args.config)
 | 
				
			||||||
 | 
								got(&c)
 | 
				
			||||||
 | 
								if !reflect.DeepEqual(*c.objectMoveFactors, tt.want) {
 | 
				
			||||||
 | 
									t.Errorf("WithObjectMoveFactors() = %v, want %v", *c.objectMoveFactors, tt.want)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										11
									
								
								pkg/steering/test_data/omf-config.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								pkg/steering/test_data/omf-config.json
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					{
 | 
				
			||||||
 | 
					  "steering_steps":[-1, -0.66, -0.33, 0, 0.33, 0.66, 1],
 | 
				
			||||||
 | 
					  "distance_steps": [0, 0.2, 0.4, 0.6, 0.8, 1],
 | 
				
			||||||
 | 
					  "data": [
 | 
				
			||||||
 | 
						[0, 0, 0, 0, 0, 0],
 | 
				
			||||||
 | 
						[0, 0, 0, 0, 0, 0],
 | 
				
			||||||
 | 
						[0, 0, 0, 0, 0, 0],
 | 
				
			||||||
 | 
						[0, 0.25, 0, 0, -0.25, 0],
 | 
				
			||||||
 | 
						[0.5, 0.25, 0, 0, -0.5, -0.25]
 | 
				
			||||||
 | 
					  ]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user