feat(training): add flag toconfigure model type training

This commit is contained in:
Cyrille Nofficial 2022-10-06 14:05:07 +02:00
parent b7b4bd76a0
commit c4ff4b46b0
2 changed files with 46 additions and 14 deletions

View File

@ -102,7 +102,7 @@ func main() {
fmt.Printf(" run\n \tRun training job\n")
}
var modelPath, roleArn, trainJobName string
var modelPath, roleArn, trainJobName, modelType string
var horizon int
var withFlipImage bool
var trainImageHeight, trainImageWidth int
@ -120,6 +120,7 @@ func main() {
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
@ -237,7 +238,7 @@ func main() {
trainingRunFlags.PrintDefaults()
os.Exit(0)
}
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults()
@ -355,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
}
}
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
l := zap.S()
if bucketName == "" {
@ -378,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
}
if modelType == train.ModelTypeUnknown {
l.Fatalf("invalid model type: %v", modelType)
}
training := train.New(bucketName, ociImage, roleArn)
err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
if err != nil {
l.Fatalf("unable to run training: %v", err)

View File

@ -12,9 +12,40 @@ import (
"github.com/cyrilix/robocar-tools/pkg/models"
"go.uber.org/zap"
"strconv"
"strings"
"time"
)
type ModelType int
func ParseModelType(s string) ModelType {
switch strings.ToLower(s) {
case "categorical":
return ModelTypeCategorical
case "linear":
return ModelTypeLinear
default:
return ModelTypeUnknown
}
}
func (m ModelType) String() string {
switch m {
case ModelTypeCategorical:
return "categorical"
case ModelTypeLinear:
return "linear"
default:
return "unknown"
}
}
const (
ModelTypeUnknown ModelType = iota
ModelTypeCategorical
ModelTypeLinear
)
func New(bucketName string, ociImage, roleArn string) *Training {
return &Training{
config: awsutils.MustLoadConfig(),
@ -35,7 +66,7 @@ type Training struct {
outputBucket string
}
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
l := zap.S()
l.Infof("run training with data from %s", basedir)
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
@ -50,14 +81,7 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWid
}
l.Info("")
err = t.runTraining(
ctx,
jobName,
sliceSize,
imgHeight,
imgWidth,
enableSpotTraining,
)
err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType)
if err != nil {
return fmt.Errorf("unable to run training: %w", err)
}
@ -95,7 +119,7 @@ func List(bucketName string) error {
return nil
}
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error {
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error {
l := zap.S()
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
l.Infof("Start training job '%s'", jobName)
@ -125,6 +149,9 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
"slide_size": strconv.Itoa(slideSize),
"img_height": strconv.Itoa(imgHeight),
"img_width": strconv.Itoa(imgWidth),
"batch_size": strconv.Itoa(32),
"model_type": modelType.String(),
"horizon": strconv.Itoa(horizon),
},
InputDataConfig: []types.Channel{
{