From c4ff4b46b075de698b996eb3dbeef3768e93bbf8 Mon Sep 17 00:00:00 2001 From: Cyrille Nofficial Date: Thu, 6 Oct 2022 14:05:07 +0200 Subject: [PATCH] feat(training): add flag toconfigure model type training --- cmd/rc-tools/rc-tools.go | 13 +++++++---- pkg/train/train.go | 47 +++++++++++++++++++++++++++++++--------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/cmd/rc-tools/rc-tools.go b/cmd/rc-tools/rc-tools.go index 0542a98..bdb7c0d 100644 --- a/cmd/rc-tools/rc-tools.go +++ b/cmd/rc-tools/rc-tools.go @@ -102,7 +102,7 @@ func main() { fmt.Printf(" run\n \tRun training job\n") } - var modelPath, roleArn, trainJobName string + var modelPath, roleArn, trainJobName, modelType string var horizon int var withFlipImage bool var trainImageHeight, trainImageWidth int @@ -120,6 +120,7 @@ func main() { trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height") trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width") trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)") + trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build") trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training") trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError) @@ -237,7 +238,7 @@ func main() { trainingRunFlags.PrintDefaults() os.Exit(0) } - runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining) + runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining) case trainArchiveFlags.Name(): if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp { trainArchiveFlags.PrintDefaults() @@ -355,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int } } -func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) { +func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) { l := zap.S() if bucketName == "" { @@ -378,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed") } + if modelType == train.ModelTypeUnknown { + l.Fatalf("invalid model type: %v", modelType) + } + training := train.New(bucketName, ociImage, roleArn) - err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining) + err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining) if err != nil { l.Fatalf("unable to run training: %v", err) diff --git a/pkg/train/train.go b/pkg/train/train.go index 704faab..d982453 100644 --- a/pkg/train/train.go +++ b/pkg/train/train.go @@ -12,9 +12,40 @@ import ( "github.com/cyrilix/robocar-tools/pkg/models" "go.uber.org/zap" "strconv" + "strings" "time" ) +type ModelType int + +func ParseModelType(s string) ModelType { + switch strings.ToLower(s) { + case "categorical": + return ModelTypeCategorical + case "linear": + return ModelTypeLinear + default: + return ModelTypeUnknown + } +} + +func (m ModelType) String() string { + switch m { + case ModelTypeCategorical: + return "categorical" + case ModelTypeLinear: + return "linear" + default: + return "unknown" + } +} + +const ( + ModelTypeUnknown ModelType = iota + ModelTypeCategorical + ModelTypeLinear +) + func New(bucketName string, ociImage, roleArn string) *Training { return &Training{ config: awsutils.MustLoadConfig(), @@ -35,7 +66,7 @@ type Training struct { outputBucket string } -func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { +func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { l := zap.S() l.Infof("run training with data from %s", basedir) archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage) @@ -50,14 +81,7 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWid } l.Info("") - err = t.runTraining( - ctx, - jobName, - sliceSize, - imgHeight, - imgWidth, - enableSpotTraining, - ) + err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType) if err != nil { return fmt.Errorf("unable to run training: %w", err) } @@ -95,7 +119,7 @@ func List(bucketName string) error { return nil } -func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error { +func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error { l := zap.S() client := sagemaker.NewFromConfig(awsutils.MustLoadConfig()) l.Infof("Start training job '%s'", jobName) @@ -125,6 +149,9 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in "slide_size": strconv.Itoa(slideSize), "img_height": strconv.Itoa(imgHeight), "img_width": strconv.Itoa(imgWidth), + "batch_size": strconv.Itoa(32), + "model_type": modelType.String(), + "horizon": strconv.Itoa(horizon), }, InputDataConfig: []types.Channel{ {