From 8b8d53af5880fa0136cd20f544b13b27fa089306 Mon Sep 17 00:00:00 2001 From: Cyrille Nofficial Date: Thu, 9 Jun 2022 12:19:54 +0200 Subject: [PATCH] feat: add display record command and refactor training command --- cmd/rc-tools/rc-tools.go | 88 +++++++++++++++++++------- dkimpt/import.go | 2 +- part/part.go | 22 +++++++ pkg/data/data.go | 42 ++++++++----- pkg/data/data_test.go | 2 +- pkg/display/record.go | 130 +++++++++++++++++++++++++++++++++++++++ pkg/train/train.go | 12 ++-- record/record.go | 4 +- 8 files changed, 255 insertions(+), 47 deletions(-) create mode 100644 pkg/display/record.go diff --git a/cmd/rc-tools/rc-tools.go b/cmd/rc-tools/rc-tools.go index 3c07417..7dca447 100644 --- a/cmd/rc-tools/rc-tools.go +++ b/cmd/rc-tools/rc-tools.go @@ -8,6 +8,7 @@ import ( "github.com/cyrilix/robocar-tools/dkimpt" "github.com/cyrilix/robocar-tools/part" "github.com/cyrilix/robocar-tools/pkg/data" + "github.com/cyrilix/robocar-tools/pkg/display" "github.com/cyrilix/robocar-tools/pkg/models" "github.com/cyrilix/robocar-tools/pkg/train" "github.com/cyrilix/robocar-tools/record" @@ -58,16 +59,28 @@ func main() { flag.BoolVar(&debug, "debug", false, "Display debug logs") displayFlags := flag.NewFlagSet("display", flag.ExitOnError) - cli.InitMqttFlagSet(displayFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain) - displayFlags.StringVar(&frameTopic, "mqtt-topic-frame", os.Getenv("MQTT_TOPIC_FRAME"), "Mqtt topic that contains frame to display, use MQTT_TOPIC_FRAME if args not set") - displayFlags.StringVar(&framePath, "frame-path", "", "Directory path where to read jpeg frame to inject in frame topic") - displayFlags.IntVar(&fps, "frame-per-second", 25, "Video frame per second of frame to publish") - displayFlags.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains detected objects, use MQTT_TOPIC_OBJECTS if args not set") - displayFlags.BoolVar(&withObjects, "with-objects", false, "Display detected objects") + displayFlags.Usage = func(){ + fmt.Printf("Usage of %s %s:\n", os.Args[0], displayFlags.Name()) + fmt.Printf(" camera\n \tLive from car camera\n") + fmt.Printf(" record\n \tLive from published records\n") + } - displayFlags.StringVar(&roadTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_ROAD"), "Mqtt topic that contains road description, use MQTT_TOPIC_ROAD if args not set") - displayFlags.BoolVar(&withRoad, "with-road", false, "Display detected road") + displayRecordFlags := flag.NewFlagSet("record", flag.ExitOnError) + cli.InitMqttFlagSet(displayRecordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain) + displayRecordFlags.StringVar(&recordTopic, "mqtt-topic-records", os.Getenv("MQTT_TOPIC_RECORDS"), "Mqtt topic that contains record data for training, use MQTT_TOPIC_RECORDS if args not set") + + displayCameraFlags := flag.NewFlagSet("camera", flag.ExitOnError) + cli.InitMqttFlagSet(displayCameraFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain) + displayCameraFlags.StringVar(&frameTopic, "mqtt-topic-frame", os.Getenv("MQTT_TOPIC_FRAME"), "Mqtt topic that contains frame to display, use MQTT_TOPIC_FRAME if args not set") + displayCameraFlags.StringVar(&framePath, "frame-path", "", "Directory path where to read jpeg frame to inject in frame topic") + displayCameraFlags.IntVar(&fps, "frame-per-second", 25, "Video frame per second of frame to publish") + + displayCameraFlags.StringVar(&objectsTopic, "mqtt-topic-objects", os.Getenv("MQTT_TOPIC_OBJECTS"), "Mqtt topic that contains detected objects, use MQTT_TOPIC_OBJECTS if args not set") + displayCameraFlags.BoolVar(&withObjects, "with-objects", false, "Display detected objects") + + displayCameraFlags.StringVar(&roadTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_ROAD"), "Mqtt topic that contains road description, use MQTT_TOPIC_ROAD if args not set") + displayCameraFlags.BoolVar(&withRoad, "with-road", false, "Display detected road") recordFlags := flag.NewFlagSet("record", flag.ExitOnError) cli.InitMqttFlagSet(recordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain) @@ -90,6 +103,7 @@ func main() { } var modelPath, roleArn, trainJobName string + var horizon int var withFlipImage bool var trainImageHeight, trainImageWidth int var enableSpotTraining bool @@ -105,6 +119,7 @@ func main() { trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height") trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width") + trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)") trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training") trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError) @@ -113,6 +128,9 @@ func main() { trainArchiveFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Path where records files are stored, use RECORD_PATH if args not set") trainArchiveFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set") trainArchiveFlags.IntVar(&trainSliceSize, "slice-size", trainSliceSize, "Number of record to shift with image, use TRAIN_SLICE_SIZE if args not set") + trainArchiveFlags.IntVar(&trainImageWidth, "image-width", 0, "Resize image width") + trainArchiveFlags.IntVar(&trainImageHeight, "image-height", 0, "Resize image height") + trainArchiveFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)") trainArchiveFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive") @@ -160,12 +178,32 @@ func main() { displayFlags.PrintDefaults() os.Exit(0) } - client, err := cli.Connect(mqttBroker, username, password, clientId) - if err != nil { - zap.S().Fatalf("unable to connect to mqtt bus: %v", err) + switch displayFlags.Arg(0) { + case displayRecordFlags.Name(): + if err := displayRecordFlags.Parse(os.Args[3:]); err == flag.ErrHelp { + displayRecordFlags.PrintDefaults() + os.Exit(0) + } + client, err := cli.Connect(mqttBroker, username, password, clientId) + if err != nil { + zap.S().Fatalf("unable to connect to mqtt bus: %v", err) + } + runDisplayRecord(client, recordTopic) + case displayCameraFlags.Name(): + if err := displayCameraFlags.Parse(os.Args[3:]); err == flag.ErrHelp { + displayCameraFlags.PrintDefaults() + os.Exit(0) + } + client, err := cli.Connect(mqttBroker, username, password, clientId) + if err != nil { + zap.S().Fatalf("unable to connect to mqtt bus: %v", err) + } + defer client.Disconnect(50) + runDisplay(client, framePath, frameTopic, fps, objectsTopic, roadTopic, withObjects, withRoad) + default: + displayFlags.PrintDefaults() + os.Exit(0) } - defer client.Disconnect(50) - runDisplay(client, framePath, frameTopic, fps, objectsTopic, roadTopic, withObjects, withRoad) case recordFlags.Name(): if err := recordFlags.Parse(os.Args[2:]); err == flag.ErrHelp { recordFlags.PrintDefaults() @@ -200,14 +238,13 @@ func main() { trainingRunFlags.PrintDefaults() os.Exit(0) } - runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, withFlipImage, modelPath, - trainImageHeight, trainImageWidth, enableSpotTraining) + runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining) case trainArchiveFlags.Name(): if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp { trainArchiveFlags.PrintDefaults() os.Exit(0) } - runTrainArchive(recordsPath, trainArchiveName, trainSliceSize, withFlipImage) + runTrainArchive(recordsPath, trainArchiveName, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage) default: trainingFlags.PrintDefaults() os.Exit(0) @@ -266,9 +303,9 @@ func runRecord(client mqtt.Client, recordsDir, recordTopic string) { } } -func runTrainArchive(basedir, archiveName string, sliceSize int, withFlipImage bool) { +func runTrainArchive(basedir, archiveName string, sliceSize int, imgWidth, imgHeight int, horizon int, withFlipImage bool) { - err := data.WriteArchive(basedir, archiveName, sliceSize, withFlipImage) + err := data.WriteArchive(basedir, archiveName, sliceSize, imgWidth, imgHeight, horizon, withFlipImage) if err != nil { zap.S().Fatalf("unable to build archive file %v: %v", archiveName, err) } @@ -283,7 +320,16 @@ func runImportDonkeyRecords(basedir, destdir string) { zap.S().Fatalf("unable to import files from %v to %v: %v", basedir, destdir, err) } } +func runDisplayRecord(client mqtt.Client, recordTopic string){ + r := display.NewRecordDisplay(client, recordTopic) + defer r.Stop() + cli.HandleExit(r) + err := r.Start() + if err != nil { + zap.S().Fatalf("unable to start service: %v", err) + } +} func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic string, roadTopic string, withObjects bool, withRoad bool) { if framePath != "" { @@ -310,8 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int } } -func runTraining(bucketName string, ociImage string, roleArn string, jobName, dataDir string, sliceSize int, withFlipImage bool, - outputModel string, imgHeight int, imgWidth int, enableSpotTraining bool) { +func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) { l := zap.S() if bucketName == "" { @@ -335,8 +380,7 @@ func runTraining(bucketName string, ociImage string, roleArn string, jobName, da } training := train.New(bucketName, ociImage, roleArn) - err := training.TrainDir(context.Background(), jobName, dataDir, imgHeight, imgWidth, sliceSize, withFlipImage, - outputModel, enableSpotTraining) + err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining) if err != nil { l.Fatalf("unable to run training: %v", err) diff --git a/dkimpt/import.go b/dkimpt/import.go index b6f2dba..a1ef565 100644 --- a/dkimpt/import.go +++ b/dkimpt/import.go @@ -50,7 +50,7 @@ func ImportDonkeyRecords(basedir string, destDir string) error { return fmt.Errorf("unable to find index in cam image name %v: %v", img.Name(), err) } zap.S().Debugf("found image with index %v", idx) - records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.RecorNameFormat, idx))) + records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.FileNameFormat, idx))) imgCams = append(imgCams, path.Join(basedir, dirItem.Name(), camSubDir, img.Name())) } diff --git a/part/part.go b/part/part.go index 7a50481..5dda929 100644 --- a/part/part.go +++ b/part/part.go @@ -181,6 +181,28 @@ func (p *FramePart) drawRoad(img *gocv.Mat, road *events.RoadMessage) { color.RGBA{R: 255, G: 0, B: 0, A: 128}, -1) + p.drawRoadText(img, road) + +} +func (p *FramePart) drawRoadText(img *gocv.Mat, road *events.RoadMessage) { + gocv.PutText( + img, + fmt.Sprintf("Confidence: %.3f", road.Ellipse.Confidence), + image.Point{X: 20, Y: 20}, + gocv.FontHersheyPlain, + 1., + color.RGBA{R: 0, G: 255, B: 0, A: 255}, + 1, + ) + gocv.PutText( + img, + fmt.Sprintf("Angle ellipse: %.3f", road.Ellipse.Angle), + image.Point{X: 20, Y: 40}, + gocv.FontHersheyPlain, + 1., + color.RGBA{R: 0, G: 255, B: 0, A: 255}, + 1, + ) } func StopService(name string, client mqtt.Client, topics ...string) { diff --git a/pkg/data/data.go b/pkg/data/data.go index cc068d6..c6493f8 100644 --- a/pkg/data/data.go +++ b/pkg/data/data.go @@ -19,8 +19,8 @@ import ( var camSubDir = "cam" -func WriteArchive(basedir string, archiveName string, sliceSize int, flipImages bool) error { - content, err := BuildArchive(basedir, sliceSize, flipImages) +func WriteArchive(basedir string, archiveName string, sliceSize int, imgWidth, imgHeight int, horizon int, flipImages bool) error { + content, err := BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, flipImages) if err != nil { return fmt.Errorf("unable to build archive: %w", err) } @@ -34,7 +34,7 @@ func WriteArchive(basedir string, archiveName string, sliceSize int, flipImages return nil } -func BuildArchive(basedir string, sliceSize int, flipImages bool) ([]byte, error) { +func BuildArchive(basedir string, sliceSize int, imgWidth, imgHeight int, horizon int, flipImages bool) ([]byte, error) { l := zap.S() l.Infof("build zip archive from %s\n", basedir) dirItems, err := ioutil.ReadDir(basedir) @@ -59,7 +59,7 @@ func BuildArchive(basedir string, sliceSize int, flipImages bool) ([]byte, error return nil, fmt.Errorf("unable to find index in cam image name %v: %w", img.Name(), err) } l.Debugf("found image with index %v", idx) - records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.RecorNameFormat, idx))) + records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.FileNameFormat, idx))) imgCams = append(imgCams, path.Join(basedir, dirItem.Name(), camSubDir, img.Name())) } } @@ -73,12 +73,12 @@ func BuildArchive(basedir string, sliceSize int, flipImages bool) ([]byte, error // Create a new zip archive. w := zip.NewWriter(buf) - err = buildArchiveContent(w, imgCams, records, false) + err = buildArchiveContent(w, imgCams, records, imgWidth, imgHeight, horizon, false) if err != nil { return nil, fmt.Errorf("unable to build archive: %w", err) } if flipImages { - err = buildArchiveContent(w, imgCams, records, true) + err = buildArchiveContent(w, imgCams, records, imgWidth, imgHeight, horizon, true) if err != nil { return nil, fmt.Errorf("unable to build archive: %w", err) } @@ -132,13 +132,13 @@ func findNamedMatches(regex *regexp.Regexp, str string) map[string]string { return results } -func buildArchiveContent(w *zip.Writer, imgFiles []string, recordFiles []string, withFlipImages bool) error { +func buildArchiveContent(w *zip.Writer, imgFiles []string, recordFiles []string, imgWidth, imgHeight int, horizon int, withFlipImages bool) error { err := addJsonFiles(recordFiles, imgFiles, withFlipImages, w) if err != nil { return fmt.Errorf("unable to write json files in zip archive: %w", err) } - err = addCamImages(imgFiles, withFlipImages, w) + err = addCamImages(imgFiles, withFlipImages, w, imgWidth, imgHeight, horizon) if err != nil { return fmt.Errorf("unable to cam files in zip archive: %w", err) } @@ -146,7 +146,7 @@ func buildArchiveContent(w *zip.Writer, imgFiles []string, recordFiles []string, return err } -func addCamImages(imgFiles []string, flipImage bool, w *zip.Writer) error { +func addCamImages(imgFiles []string, flipImage bool, w *zip.Writer, imgWidth, imgHeight int, horizon int) error { for _, im := range imgFiles { imgContent, err := ioutil.ReadFile(im) if err != nil { @@ -154,17 +154,29 @@ func addCamImages(imgFiles []string, flipImage bool, w *zip.Writer) error { } _, imgName := path.Split(im) - if flipImage { + if flipImage || imgWidth > 0 && imgHeight > 0 || horizon > 0 { img, _, err := image.Decode(bytes.NewReader(imgContent)) if err != nil { - zap.S().Fatalf("unable to decode peg image: %v", err) + zap.S().Fatalf("unable to decode jpeg image: %v", err) } - imgFlip := imaging.FlipH(img) - var bytesBuff bytes.Buffer - err = jpeg.Encode(&bytesBuff, imgFlip, nil) + if imgWidth > 0 && imgHeight > 0 { + bounds := img.Bounds() + if bounds.Dx() != imgWidth || bounds.Dy() != imgWidth { + zap.S().Debugf("resize image %v from %dx%d to %dx%d", im, bounds.Dx(), bounds.Dy(), imgWidth, imgHeight) + img = imaging.Resize(img, imgWidth, imgHeight, imaging.NearestNeighbor) + } + } + if flipImage { + img = imaging.FlipH(img) + imgName = fmt.Sprintf("flip_%s", imgName) + } + if horizon > 0 { + img = imaging.Crop(img, image.Rect(0, horizon, img.Bounds().Dx(), img.Bounds().Dy())) + } + var bytesBuff bytes.Buffer + err = jpeg.Encode(&bytesBuff, img, nil) imgContent = bytesBuff.Bytes() - imgName = fmt.Sprintf("flip_%s", imgName) } err = addToArchive(w, imgName, imgContent) diff --git a/pkg/data/data_test.go b/pkg/data/data_test.go index 0cffacd..c0c925d 100644 --- a/pkg/data/data_test.go +++ b/pkg/data/data_test.go @@ -29,7 +29,7 @@ func TestBuildArchive(t *testing.T) { expectedRecordFiles, expectedImgFiles := expectedFiles() - err = WriteArchive("testdata", archive, 0) + err = WriteArchive("testdata", archive, 0, 160, 120, 0, false) if err != nil { t.Errorf("unable to build archive: %v", err) } diff --git a/pkg/display/record.go b/pkg/display/record.go new file mode 100644 index 0000000..f4576da --- /dev/null +++ b/pkg/display/record.go @@ -0,0 +1,130 @@ +package display + +import ( + "fmt" + "github.com/cyrilix/robocar-protobuf/go/events" + mqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/golang/protobuf/proto" + "go.uber.org/zap" + "gocv.io/x/gocv" + "image" + "image/color" +) + +func NewRecordDisplay(client mqtt.Client, recordTopic string) *Record { + return &Record{ + client: client, + recordTopic: recordTopic, + window: gocv.NewWindow("recordTopic"), + recordChan: make(chan *events.RecordMessage), + cancel: make(chan interface{}), + } + +} + +type Record struct { + client mqtt.Client + recordTopic string + + window *gocv.Window + + recordChan chan *events.RecordMessage + cancel chan interface{} +} + +func (r *Record) Start() error { + if err := r.registerCallbacks(); err != nil { + return fmt.Errorf("unable to start service: %v", err) + } + + var rec *events.RecordMessage + var objectsMsg events.ObjectsMessage + var roadMsg events.RoadMessage + + for { + select { + case newRecord := <-r.recordChan: + rec = newRecord + case <-r.cancel: + return nil + } + go r.drawRecord(rec, &objectsMsg, &roadMsg) + } +} + +func (r *Record) Stop() { + defer r.window.Close() + + close(r.cancel) + + StopService("record-display", r.client, r.recordTopic) +} + +func (r *Record) onRecord(_ mqtt.Client, message mqtt.Message) { + var msg events.RecordMessage + err := proto.Unmarshal(message.Payload(), &msg) + if err != nil { + zap.S().Errorf("unable to unmarshal protobuf FrameMessage: %v", err) + return + } + message.Ack() + r.recordChan <- &msg +} + +func (r *Record) registerCallbacks() error { + err := RegisterCallback(r.client, r.recordTopic, r.onRecord) + if err != nil { + return err + } + + return nil +} + +func (r *Record) drawRecord(rec *events.RecordMessage, objects *events.ObjectsMessage, road *events.RoadMessage) { + + img, err := gocv.IMDecode(rec.GetFrame().GetFrame(), gocv.IMReadUnchanged) + if err != nil { + zap.S().Errorf("unable to decode image: %v", err) + return + } + defer img.Close() + + steering := rec.GetSteering().GetSteering() + r.drawSteering(&img, steering) + + r.window.IMShow(img) + r.window.WaitKey(1) +} + + +func (r *Record) drawSteering(img *gocv.Mat, steering float32) { + gocv.PutText( + img, + fmt.Sprintf("Steering: %.3f", steering), + image.Point{X: 20, Y: 20}, + gocv.FontHersheyPlain, + 1., + color.RGBA{R: 0, G: 255, B: 0, A: 255}, + 1, + ) +} + +func StopService(name string, client mqtt.Client, topics ...string) { + zap.S().Infof("Stop %s service", name) + token := client.Unsubscribe(topics...) + token.Wait() + if token.Error() != nil { + zap.S().Errorf("unable to unsubscribe service: %v", token.Error()) + } + client.Disconnect(50) +} + +func RegisterCallback(client mqtt.Client, topic string, callback mqtt.MessageHandler) error { + zap.S().Infof("Register callback on topic %v", topic) + token := client.Subscribe(topic, 0, callback) + token.Wait() + if token.Error() != nil { + return fmt.Errorf("unable to register callback on topic %s: %v", topic, token.Error()) + } + return nil +} diff --git a/pkg/train/train.go b/pkg/train/train.go index 41f057d..704faab 100644 --- a/pkg/train/train.go +++ b/pkg/train/train.go @@ -35,10 +35,10 @@ type Training struct { outputBucket string } -func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgHeight, imgWidth int, sliceSize int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { +func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { l := zap.S() l.Infof("run training with data from %s", basedir) - archive, err := data.BuildArchive(basedir, sliceSize, withFlipImage) + archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage) if err != nil { return fmt.Errorf("unable to build data archive: %w", err) } @@ -110,7 +110,7 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in S3OutputPath: aws.String(t.outputBucket), }, ResourceConfig: &types.ResourceConfig{ - InstanceCount: 1, + InstanceCount: 1, //InstanceType: types.TrainingInstanceTypeMlP2Xlarge, InstanceType: types.TrainingInstanceTypeMlG4dnXlarge, VolumeSizeInGB: 1, @@ -168,7 +168,7 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in } switch status.TrainingJobStatus { case types.TrainingJobStatusInProgress: - l.Infof("job in progress: %v - %v - %v", status.TrainingJobStatus, status.SecondaryStatus, *status.SecondaryStatusTransitions[len(status.SecondaryStatusTransitions) - 1].StatusMessage) + l.Infof("job in progress: %v - %v - %v", status.TrainingJobStatus, status.SecondaryStatus, *status.SecondaryStatusTransitions[len(status.SecondaryStatusTransitions)-1].StatusMessage) continue case types.TrainingJobStatusFailed: return fmt.Errorf("job %s finished with status %v", jobName, status.TrainingJobStatus) @@ -198,5 +198,5 @@ func ListJob(ctx context.Context) error { for _, job := range jobs.TrainingJobSummaries { fmt.Printf("%s\t\t%s\n", *job.TrainingJobName, job.TrainingJobStatus) } - return nil -} \ No newline at end of file + return nil +} diff --git a/record/record.go b/record/record.go index 38e614d..d721a01 100644 --- a/record/record.go +++ b/record/record.go @@ -33,7 +33,7 @@ type Recorder struct { cancel chan interface{} } -var RecorNameFormat = "record_%s.json" +var FileNameFormat = "record_%s.json" func (r *Recorder) Start() error { err := service.RegisterCallback(r.client, r.recordTopic, r.onRecordMsg) @@ -75,7 +75,7 @@ func (r *Recorder) onRecordMsg(_ mqtt.Client, message mqtt.Message) { } jsonDir := fmt.Sprintf("%s/", recordDir) - recordName := fmt.Sprintf("%s/%s", jsonDir, fmt.Sprintf(RecorNameFormat, msg.GetFrame().GetId().GetId())) + recordName := fmt.Sprintf("%s/%s", jsonDir, fmt.Sprintf(FileNameFormat, msg.GetFrame().GetId().GetId())) err = os.MkdirAll(jsonDir, os.FileMode(0755)) if err != nil { l.Errorf("unable to create %v directory: %v", jsonDir, err)