feat(training): add flag toconfigure model type training

This commit is contained in:
Cyrille Nofficial 2022-10-06 14:05:07 +02:00
parent b7b4bd76a0
commit c4ff4b46b0
2 changed files with 46 additions and 14 deletions

View File

@ -102,7 +102,7 @@ func main() {
fmt.Printf(" run\n \tRun training job\n") fmt.Printf(" run\n \tRun training job\n")
} }
var modelPath, roleArn, trainJobName string var modelPath, roleArn, trainJobName, modelType string
var horizon int var horizon int
var withFlipImage bool var withFlipImage bool
var trainImageHeight, trainImageWidth int var trainImageHeight, trainImageWidth int
@ -120,6 +120,7 @@ func main() {
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height") trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width") trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)") trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training") trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError) trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
@ -237,7 +238,7 @@ func main() {
trainingRunFlags.PrintDefaults() trainingRunFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
} }
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining) runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
case trainArchiveFlags.Name(): case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp { if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults() trainArchiveFlags.PrintDefaults()
@ -355,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
} }
} }
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) { func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
l := zap.S() l := zap.S()
if bucketName == "" { if bucketName == "" {
@ -378,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed") l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
} }
if modelType == train.ModelTypeUnknown {
l.Fatalf("invalid model type: %v", modelType)
}
training := train.New(bucketName, ociImage, roleArn) training := train.New(bucketName, ociImage, roleArn)
err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining) err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
if err != nil { if err != nil {
l.Fatalf("unable to run training: %v", err) l.Fatalf("unable to run training: %v", err)

View File

@ -12,9 +12,40 @@ import (
"github.com/cyrilix/robocar-tools/pkg/models" "github.com/cyrilix/robocar-tools/pkg/models"
"go.uber.org/zap" "go.uber.org/zap"
"strconv" "strconv"
"strings"
"time" "time"
) )
type ModelType int
func ParseModelType(s string) ModelType {
switch strings.ToLower(s) {
case "categorical":
return ModelTypeCategorical
case "linear":
return ModelTypeLinear
default:
return ModelTypeUnknown
}
}
func (m ModelType) String() string {
switch m {
case ModelTypeCategorical:
return "categorical"
case ModelTypeLinear:
return "linear"
default:
return "unknown"
}
}
const (
ModelTypeUnknown ModelType = iota
ModelTypeCategorical
ModelTypeLinear
)
func New(bucketName string, ociImage, roleArn string) *Training { func New(bucketName string, ociImage, roleArn string) *Training {
return &Training{ return &Training{
config: awsutils.MustLoadConfig(), config: awsutils.MustLoadConfig(),
@ -35,7 +66,7 @@ type Training struct {
outputBucket string outputBucket string
} }
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
l := zap.S() l := zap.S()
l.Infof("run training with data from %s", basedir) l.Infof("run training with data from %s", basedir)
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage) archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
@ -50,14 +81,7 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWid
} }
l.Info("") l.Info("")
err = t.runTraining( err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType)
ctx,
jobName,
sliceSize,
imgHeight,
imgWidth,
enableSpotTraining,
)
if err != nil { if err != nil {
return fmt.Errorf("unable to run training: %w", err) return fmt.Errorf("unable to run training: %w", err)
} }
@ -95,7 +119,7 @@ func List(bucketName string) error {
return nil return nil
} }
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error { func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error {
l := zap.S() l := zap.S()
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig()) client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
l.Infof("Start training job '%s'", jobName) l.Infof("Start training job '%s'", jobName)
@ -125,6 +149,9 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
"slide_size": strconv.Itoa(slideSize), "slide_size": strconv.Itoa(slideSize),
"img_height": strconv.Itoa(imgHeight), "img_height": strconv.Itoa(imgHeight),
"img_width": strconv.Itoa(imgWidth), "img_width": strconv.Itoa(imgWidth),
"batch_size": strconv.Itoa(32),
"model_type": modelType.String(),
"horizon": strconv.Itoa(horizon),
}, },
InputDataConfig: []types.Channel{ InputDataConfig: []types.Channel{
{ {