feat: add display record command and refactor training command

This commit is contained in:
Cyrille Nofficial 2022-06-09 12:19:54 +02:00
parent 02db9e241e
commit 8b8d53af58
8 changed files with 255 additions and 47 deletions

View File

@ -8,6 +8,7 @@ import (
"github.com/cyrilix/robocar-tools/dkimpt" "github.com/cyrilix/robocar-tools/dkimpt"
"github.com/cyrilix/robocar-tools/part" "github.com/cyrilix/robocar-tools/part"
"github.com/cyrilix/robocar-tools/pkg/data" "github.com/cyrilix/robocar-tools/pkg/data"
"github.com/cyrilix/robocar-tools/pkg/display"
"github.com/cyrilix/robocar-tools/pkg/models" "github.com/cyrilix/robocar-tools/pkg/models"
"github.com/cyrilix/robocar-tools/pkg/train" "github.com/cyrilix/robocar-tools/pkg/train"
"github.com/cyrilix/robocar-tools/record" "github.com/cyrilix/robocar-tools/record"
@ -58,16 +59,28 @@ func main() {
flag.BoolVar(&debug, "debug", false, "Display debug logs") flag.BoolVar(&debug, "debug", false, "Display debug logs")
displayFlags := flag.NewFlagSet("display", flag.ExitOnError) displayFlags := flag.NewFlagSet("display", flag.ExitOnError)
cli.InitMqttFlagSet(displayFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
displayFlags.StringVar(&frameTopic, "mqtt-topic-frame", os.Getenv("MQTT_TOPIC_FRAME"), "Mqtt topic that contains frame to display, use MQTT_TOPIC_FRAME if args not set")
displayFlags.StringVar(&framePath, "frame-path", "", "Directory path where to read jpeg frame to inject in frame topic")
displayFlags.IntVar(&fps, "frame-per-second", 25, "Video frame per second of frame to publish")
displayFlags.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains detected objects, use MQTT_TOPIC_OBJECTS if args not set") displayFlags.Usage = func(){
displayFlags.BoolVar(&withObjects, "with-objects", false, "Display detected objects") fmt.Printf("Usage of %s %s:\n", os.Args[0], displayFlags.Name())
fmt.Printf(" camera\n \tLive from car camera\n")
fmt.Printf(" record\n \tLive from published records\n")
}
displayFlags.StringVar(&roadTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_ROAD"), "Mqtt topic that contains road description, use MQTT_TOPIC_ROAD if args not set") displayRecordFlags := flag.NewFlagSet("record", flag.ExitOnError)
displayFlags.BoolVar(&withRoad, "with-road", false, "Display detected road") cli.InitMqttFlagSet(displayRecordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
displayRecordFlags.StringVar(&recordTopic, "mqtt-topic-records", os.Getenv("MQTT_TOPIC_RECORDS"), "Mqtt topic that contains record data for training, use MQTT_TOPIC_RECORDS if args not set")
displayCameraFlags := flag.NewFlagSet("camera", flag.ExitOnError)
cli.InitMqttFlagSet(displayCameraFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
displayCameraFlags.StringVar(&frameTopic, "mqtt-topic-frame", os.Getenv("MQTT_TOPIC_FRAME"), "Mqtt topic that contains frame to display, use MQTT_TOPIC_FRAME if args not set")
displayCameraFlags.StringVar(&framePath, "frame-path", "", "Directory path where to read jpeg frame to inject in frame topic")
displayCameraFlags.IntVar(&fps, "frame-per-second", 25, "Video frame per second of frame to publish")
displayCameraFlags.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains detected objects, use MQTT_TOPIC_OBJECTS if args not set")
displayCameraFlags.BoolVar(&withObjects, "with-objects", false, "Display detected objects")
displayCameraFlags.StringVar(&roadTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_ROAD"), "Mqtt topic that contains road description, use MQTT_TOPIC_ROAD if args not set")
displayCameraFlags.BoolVar(&withRoad, "with-road", false, "Display detected road")
recordFlags := flag.NewFlagSet("record", flag.ExitOnError) recordFlags := flag.NewFlagSet("record", flag.ExitOnError)
cli.InitMqttFlagSet(recordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain) cli.InitMqttFlagSet(recordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
@ -90,6 +103,7 @@ func main() {
} }
var modelPath, roleArn, trainJobName string var modelPath, roleArn, trainJobName string
var horizon int
var withFlipImage bool var withFlipImage bool
var trainImageHeight, trainImageWidth int var trainImageHeight, trainImageWidth int
var enableSpotTraining bool var enableSpotTraining bool
@ -105,6 +119,7 @@ func main() {
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height") trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width") trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training") trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError) trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
@ -113,6 +128,9 @@ func main() {
trainArchiveFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Path where records files are stored, use RECORD_PATH if args not set") trainArchiveFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Path where records files are stored, use RECORD_PATH if args not set")
trainArchiveFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set") trainArchiveFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set")
trainArchiveFlags.IntVar(&trainSliceSize, "slice-size", trainSliceSize, "Number of record to shift with image, use TRAIN_SLICE_SIZE if args not set") trainArchiveFlags.IntVar(&trainSliceSize, "slice-size", trainSliceSize, "Number of record to shift with image, use TRAIN_SLICE_SIZE if args not set")
trainArchiveFlags.IntVar(&trainImageWidth, "image-width", 0, "Resize image width")
trainArchiveFlags.IntVar(&trainImageHeight, "image-height", 0, "Resize image height")
trainArchiveFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
trainArchiveFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive") trainArchiveFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive")
@ -160,12 +178,32 @@ func main() {
displayFlags.PrintDefaults() displayFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
} }
switch displayFlags.Arg(0) {
case displayRecordFlags.Name():
if err := displayRecordFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
displayRecordFlags.PrintDefaults()
os.Exit(0)
}
client, err := cli.Connect(mqttBroker, username, password, clientId)
if err != nil {
zap.S().Fatalf("unable to connect to mqtt bus: %v", err)
}
runDisplayRecord(client, recordTopic)
case displayCameraFlags.Name():
if err := displayCameraFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
displayCameraFlags.PrintDefaults()
os.Exit(0)
}
client, err := cli.Connect(mqttBroker, username, password, clientId) client, err := cli.Connect(mqttBroker, username, password, clientId)
if err != nil { if err != nil {
zap.S().Fatalf("unable to connect to mqtt bus: %v", err) zap.S().Fatalf("unable to connect to mqtt bus: %v", err)
} }
defer client.Disconnect(50) defer client.Disconnect(50)
runDisplay(client, framePath, frameTopic, fps, objectsTopic, roadTopic, withObjects, withRoad) runDisplay(client, framePath, frameTopic, fps, objectsTopic, roadTopic, withObjects, withRoad)
default:
displayFlags.PrintDefaults()
os.Exit(0)
}
case recordFlags.Name(): case recordFlags.Name():
if err := recordFlags.Parse(os.Args[2:]); err == flag.ErrHelp { if err := recordFlags.Parse(os.Args[2:]); err == flag.ErrHelp {
recordFlags.PrintDefaults() recordFlags.PrintDefaults()
@ -200,14 +238,13 @@ func main() {
trainingRunFlags.PrintDefaults() trainingRunFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
} }
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, withFlipImage, modelPath, runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
trainImageHeight, trainImageWidth, enableSpotTraining)
case trainArchiveFlags.Name(): case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp { if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults() trainArchiveFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
} }
runTrainArchive(recordsPath, trainArchiveName, trainSliceSize, withFlipImage) runTrainArchive(recordsPath, trainArchiveName, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage)
default: default:
trainingFlags.PrintDefaults() trainingFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
@ -266,9 +303,9 @@ func runRecord(client mqtt.Client, recordsDir, recordTopic string) {
} }
} }
func runTrainArchive(basedir, archiveName string, sliceSize int, withFlipImage bool) { func runTrainArchive(basedir, archiveName string, sliceSize int, imgWidth, imgHeight int, horizon int, withFlipImage bool) {
err := data.WriteArchive(basedir, archiveName, sliceSize, withFlipImage) err := data.WriteArchive(basedir, archiveName, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
if err != nil { if err != nil {
zap.S().Fatalf("unable to build archive file %v: %v", archiveName, err) zap.S().Fatalf("unable to build archive file %v: %v", archiveName, err)
} }
@ -283,7 +320,16 @@ func runImportDonkeyRecords(basedir, destdir string) {
zap.S().Fatalf("unable to import files from %v to %v: %v", basedir, destdir, err) zap.S().Fatalf("unable to import files from %v to %v: %v", basedir, destdir, err)
} }
} }
func runDisplayRecord(client mqtt.Client, recordTopic string){
r := display.NewRecordDisplay(client, recordTopic)
defer r.Stop()
cli.HandleExit(r)
err := r.Start()
if err != nil {
zap.S().Fatalf("unable to start service: %v", err)
}
}
func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic string, roadTopic string, withObjects bool, withRoad bool) { func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic string, roadTopic string, withObjects bool, withRoad bool) {
if framePath != "" { if framePath != "" {
@ -310,8 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
} }
} }
func runTraining(bucketName string, ociImage string, roleArn string, jobName, dataDir string, sliceSize int, withFlipImage bool, func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
outputModel string, imgHeight int, imgWidth int, enableSpotTraining bool) {
l := zap.S() l := zap.S()
if bucketName == "" { if bucketName == "" {
@ -335,8 +380,7 @@ func runTraining(bucketName string, ociImage string, roleArn string, jobName, da
} }
training := train.New(bucketName, ociImage, roleArn) training := train.New(bucketName, ociImage, roleArn)
err := training.TrainDir(context.Background(), jobName, dataDir, imgHeight, imgWidth, sliceSize, withFlipImage, err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
outputModel, enableSpotTraining)
if err != nil { if err != nil {
l.Fatalf("unable to run training: %v", err) l.Fatalf("unable to run training: %v", err)

View File

@ -50,7 +50,7 @@ func ImportDonkeyRecords(basedir string, destDir string) error {
return fmt.Errorf("unable to find index in cam image name %v: %v", img.Name(), err) return fmt.Errorf("unable to find index in cam image name %v: %v", img.Name(), err)
} }
zap.S().Debugf("found image with index %v", idx) zap.S().Debugf("found image with index %v", idx)
records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.RecorNameFormat, idx))) records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.FileNameFormat, idx)))
imgCams = append(imgCams, path.Join(basedir, dirItem.Name(), camSubDir, img.Name())) imgCams = append(imgCams, path.Join(basedir, dirItem.Name(), camSubDir, img.Name()))
} }

View File

@ -181,6 +181,28 @@ func (p *FramePart) drawRoad(img *gocv.Mat, road *events.RoadMessage) {
color.RGBA{R: 255, G: 0, B: 0, A: 128}, color.RGBA{R: 255, G: 0, B: 0, A: 128},
-1) -1)
p.drawRoadText(img, road)
}
func (p *FramePart) drawRoadText(img *gocv.Mat, road *events.RoadMessage) {
gocv.PutText(
img,
fmt.Sprintf("Confidence: %.3f", road.Ellipse.Confidence),
image.Point{X: 20, Y: 20},
gocv.FontHersheyPlain,
1.,
color.RGBA{R: 0, G: 255, B: 0, A: 255},
1,
)
gocv.PutText(
img,
fmt.Sprintf("Angle ellipse: %.3f", road.Ellipse.Angle),
image.Point{X: 20, Y: 40},
gocv.FontHersheyPlain,
1.,
color.RGBA{R: 0, G: 255, B: 0, A: 255},
1,
)
} }
func StopService(name string, client mqtt.Client, topics ...string) { func StopService(name string, client mqtt.Client, topics ...string) {

View File

@ -19,8 +19,8 @@ import (
var camSubDir = "cam" var camSubDir = "cam"
func WriteArchive(basedir string, archiveName string, sliceSize int, flipImages bool) error { func WriteArchive(basedir string, archiveName string, sliceSize int, imgWidth, imgHeight int, horizon int, flipImages bool) error {
content, err := BuildArchive(basedir, sliceSize, flipImages) content, err := BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, flipImages)
if err != nil { if err != nil {
return fmt.Errorf("unable to build archive: %w", err) return fmt.Errorf("unable to build archive: %w", err)
} }
@ -34,7 +34,7 @@ func WriteArchive(basedir string, archiveName string, sliceSize int, flipImages
return nil return nil
} }
func BuildArchive(basedir string, sliceSize int, flipImages bool) ([]byte, error) { func BuildArchive(basedir string, sliceSize int, imgWidth, imgHeight int, horizon int, flipImages bool) ([]byte, error) {
l := zap.S() l := zap.S()
l.Infof("build zip archive from %s\n", basedir) l.Infof("build zip archive from %s\n", basedir)
dirItems, err := ioutil.ReadDir(basedir) dirItems, err := ioutil.ReadDir(basedir)
@ -59,7 +59,7 @@ func BuildArchive(basedir string, sliceSize int, flipImages bool) ([]byte, error
return nil, fmt.Errorf("unable to find index in cam image name %v: %w", img.Name(), err) return nil, fmt.Errorf("unable to find index in cam image name %v: %w", img.Name(), err)
} }
l.Debugf("found image with index %v", idx) l.Debugf("found image with index %v", idx)
records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.RecorNameFormat, idx))) records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.FileNameFormat, idx)))
imgCams = append(imgCams, path.Join(basedir, dirItem.Name(), camSubDir, img.Name())) imgCams = append(imgCams, path.Join(basedir, dirItem.Name(), camSubDir, img.Name()))
} }
} }
@ -73,12 +73,12 @@ func BuildArchive(basedir string, sliceSize int, flipImages bool) ([]byte, error
// Create a new zip archive. // Create a new zip archive.
w := zip.NewWriter(buf) w := zip.NewWriter(buf)
err = buildArchiveContent(w, imgCams, records, false) err = buildArchiveContent(w, imgCams, records, imgWidth, imgHeight, horizon, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to build archive: %w", err) return nil, fmt.Errorf("unable to build archive: %w", err)
} }
if flipImages { if flipImages {
err = buildArchiveContent(w, imgCams, records, true) err = buildArchiveContent(w, imgCams, records, imgWidth, imgHeight, horizon, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to build archive: %w", err) return nil, fmt.Errorf("unable to build archive: %w", err)
} }
@ -132,13 +132,13 @@ func findNamedMatches(regex *regexp.Regexp, str string) map[string]string {
return results return results
} }
func buildArchiveContent(w *zip.Writer, imgFiles []string, recordFiles []string, withFlipImages bool) error { func buildArchiveContent(w *zip.Writer, imgFiles []string, recordFiles []string, imgWidth, imgHeight int, horizon int, withFlipImages bool) error {
err := addJsonFiles(recordFiles, imgFiles, withFlipImages, w) err := addJsonFiles(recordFiles, imgFiles, withFlipImages, w)
if err != nil { if err != nil {
return fmt.Errorf("unable to write json files in zip archive: %w", err) return fmt.Errorf("unable to write json files in zip archive: %w", err)
} }
err = addCamImages(imgFiles, withFlipImages, w) err = addCamImages(imgFiles, withFlipImages, w, imgWidth, imgHeight, horizon)
if err != nil { if err != nil {
return fmt.Errorf("unable to cam files in zip archive: %w", err) return fmt.Errorf("unable to cam files in zip archive: %w", err)
} }
@ -146,7 +146,7 @@ func buildArchiveContent(w *zip.Writer, imgFiles []string, recordFiles []string,
return err return err
} }
func addCamImages(imgFiles []string, flipImage bool, w *zip.Writer) error { func addCamImages(imgFiles []string, flipImage bool, w *zip.Writer, imgWidth, imgHeight int, horizon int) error {
for _, im := range imgFiles { for _, im := range imgFiles {
imgContent, err := ioutil.ReadFile(im) imgContent, err := ioutil.ReadFile(im)
if err != nil { if err != nil {
@ -154,18 +154,30 @@ func addCamImages(imgFiles []string, flipImage bool, w *zip.Writer) error {
} }
_, imgName := path.Split(im) _, imgName := path.Split(im)
if flipImage { if flipImage || imgWidth > 0 && imgHeight > 0 || horizon > 0 {
img, _, err := image.Decode(bytes.NewReader(imgContent)) img, _, err := image.Decode(bytes.NewReader(imgContent))
if err != nil { if err != nil {
zap.S().Fatalf("unable to decode peg image: %v", err) zap.S().Fatalf("unable to decode jpeg image: %v", err)
} }
imgFlip := imaging.FlipH(img)
var bytesBuff bytes.Buffer
err = jpeg.Encode(&bytesBuff, imgFlip, nil)
imgContent = bytesBuff.Bytes() if imgWidth > 0 && imgHeight > 0 {
bounds := img.Bounds()
if bounds.Dx() != imgWidth || bounds.Dy() != imgWidth {
zap.S().Debugf("resize image %v from %dx%d to %dx%d", im, bounds.Dx(), bounds.Dy(), imgWidth, imgHeight)
img = imaging.Resize(img, imgWidth, imgHeight, imaging.NearestNeighbor)
}
}
if flipImage {
img = imaging.FlipH(img)
imgName = fmt.Sprintf("flip_%s", imgName) imgName = fmt.Sprintf("flip_%s", imgName)
} }
if horizon > 0 {
img = imaging.Crop(img, image.Rect(0, horizon, img.Bounds().Dx(), img.Bounds().Dy()))
}
var bytesBuff bytes.Buffer
err = jpeg.Encode(&bytesBuff, img, nil)
imgContent = bytesBuff.Bytes()
}
err = addToArchive(w, imgName, imgContent) err = addToArchive(w, imgName, imgContent)
if err != nil { if err != nil {

View File

@ -29,7 +29,7 @@ func TestBuildArchive(t *testing.T) {
expectedRecordFiles, expectedImgFiles := expectedFiles() expectedRecordFiles, expectedImgFiles := expectedFiles()
err = WriteArchive("testdata", archive, 0) err = WriteArchive("testdata", archive, 0, 160, 120, 0, false)
if err != nil { if err != nil {
t.Errorf("unable to build archive: %v", err) t.Errorf("unable to build archive: %v", err)
} }

130
pkg/display/record.go Normal file
View File

@ -0,0 +1,130 @@
package display
import (
"fmt"
"github.com/cyrilix/robocar-protobuf/go/events"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"gocv.io/x/gocv"
"image"
"image/color"
)
func NewRecordDisplay(client mqtt.Client, recordTopic string) *Record {
return &Record{
client: client,
recordTopic: recordTopic,
window: gocv.NewWindow("recordTopic"),
recordChan: make(chan *events.RecordMessage),
cancel: make(chan interface{}),
}
}
type Record struct {
client mqtt.Client
recordTopic string
window *gocv.Window
recordChan chan *events.RecordMessage
cancel chan interface{}
}
func (r *Record) Start() error {
if err := r.registerCallbacks(); err != nil {
return fmt.Errorf("unable to start service: %v", err)
}
var rec *events.RecordMessage
var objectsMsg events.ObjectsMessage
var roadMsg events.RoadMessage
for {
select {
case newRecord := <-r.recordChan:
rec = newRecord
case <-r.cancel:
return nil
}
go r.drawRecord(rec, &objectsMsg, &roadMsg)
}
}
func (r *Record) Stop() {
defer r.window.Close()
close(r.cancel)
StopService("record-display", r.client, r.recordTopic)
}
func (r *Record) onRecord(_ mqtt.Client, message mqtt.Message) {
var msg events.RecordMessage
err := proto.Unmarshal(message.Payload(), &msg)
if err != nil {
zap.S().Errorf("unable to unmarshal protobuf FrameMessage: %v", err)
return
}
message.Ack()
r.recordChan <- &msg
}
func (r *Record) registerCallbacks() error {
err := RegisterCallback(r.client, r.recordTopic, r.onRecord)
if err != nil {
return err
}
return nil
}
func (r *Record) drawRecord(rec *events.RecordMessage, objects *events.ObjectsMessage, road *events.RoadMessage) {
img, err := gocv.IMDecode(rec.GetFrame().GetFrame(), gocv.IMReadUnchanged)
if err != nil {
zap.S().Errorf("unable to decode image: %v", err)
return
}
defer img.Close()
steering := rec.GetSteering().GetSteering()
r.drawSteering(&img, steering)
r.window.IMShow(img)
r.window.WaitKey(1)
}
func (r *Record) drawSteering(img *gocv.Mat, steering float32) {
gocv.PutText(
img,
fmt.Sprintf("Steering: %.3f", steering),
image.Point{X: 20, Y: 20},
gocv.FontHersheyPlain,
1.,
color.RGBA{R: 0, G: 255, B: 0, A: 255},
1,
)
}
func StopService(name string, client mqtt.Client, topics ...string) {
zap.S().Infof("Stop %s service", name)
token := client.Unsubscribe(topics...)
token.Wait()
if token.Error() != nil {
zap.S().Errorf("unable to unsubscribe service: %v", token.Error())
}
client.Disconnect(50)
}
func RegisterCallback(client mqtt.Client, topic string, callback mqtt.MessageHandler) error {
zap.S().Infof("Register callback on topic %v", topic)
token := client.Subscribe(topic, 0, callback)
token.Wait()
if token.Error() != nil {
return fmt.Errorf("unable to register callback on topic %s: %v", topic, token.Error())
}
return nil
}

View File

@ -35,10 +35,10 @@ type Training struct {
outputBucket string outputBucket string
} }
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgHeight, imgWidth int, sliceSize int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
l := zap.S() l := zap.S()
l.Infof("run training with data from %s", basedir) l.Infof("run training with data from %s", basedir)
archive, err := data.BuildArchive(basedir, sliceSize, withFlipImage) archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
if err != nil { if err != nil {
return fmt.Errorf("unable to build data archive: %w", err) return fmt.Errorf("unable to build data archive: %w", err)
} }

View File

@ -33,7 +33,7 @@ type Recorder struct {
cancel chan interface{} cancel chan interface{}
} }
var RecorNameFormat = "record_%s.json" var FileNameFormat = "record_%s.json"
func (r *Recorder) Start() error { func (r *Recorder) Start() error {
err := service.RegisterCallback(r.client, r.recordTopic, r.onRecordMsg) err := service.RegisterCallback(r.client, r.recordTopic, r.onRecordMsg)
@ -75,7 +75,7 @@ func (r *Recorder) onRecordMsg(_ mqtt.Client, message mqtt.Message) {
} }
jsonDir := fmt.Sprintf("%s/", recordDir) jsonDir := fmt.Sprintf("%s/", recordDir)
recordName := fmt.Sprintf("%s/%s", jsonDir, fmt.Sprintf(RecorNameFormat, msg.GetFrame().GetId().GetId())) recordName := fmt.Sprintf("%s/%s", jsonDir, fmt.Sprintf(FileNameFormat, msg.GetFrame().GetId().GetId()))
err = os.MkdirAll(jsonDir, os.FileMode(0755)) err = os.MkdirAll(jsonDir, os.FileMode(0755))
if err != nil { if err != nil {
l.Errorf("unable to create %v directory: %v", jsonDir, err) l.Errorf("unable to create %v directory: %v", jsonDir, err)