feat(training): add flag toconfigure model type training
This commit is contained in:
parent
b7b4bd76a0
commit
c4ff4b46b0
@ -102,7 +102,7 @@ func main() {
|
||||
fmt.Printf(" run\n \tRun training job\n")
|
||||
}
|
||||
|
||||
var modelPath, roleArn, trainJobName string
|
||||
var modelPath, roleArn, trainJobName, modelType string
|
||||
var horizon int
|
||||
var withFlipImage bool
|
||||
var trainImageHeight, trainImageWidth int
|
||||
@ -120,6 +120,7 @@ func main() {
|
||||
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
|
||||
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
|
||||
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
|
||||
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
|
||||
|
||||
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
|
||||
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
|
||||
@ -237,7 +238,7 @@ func main() {
|
||||
trainingRunFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
}
|
||||
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
|
||||
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
|
||||
case trainArchiveFlags.Name():
|
||||
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
|
||||
trainArchiveFlags.PrintDefaults()
|
||||
@ -355,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
|
||||
}
|
||||
}
|
||||
|
||||
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
|
||||
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
|
||||
|
||||
l := zap.S()
|
||||
if bucketName == "" {
|
||||
@ -378,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi
|
||||
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
|
||||
}
|
||||
|
||||
if modelType == train.ModelTypeUnknown {
|
||||
l.Fatalf("invalid model type: %v", modelType)
|
||||
}
|
||||
|
||||
training := train.New(bucketName, ociImage, roleArn)
|
||||
err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
|
||||
err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
|
||||
|
||||
if err != nil {
|
||||
l.Fatalf("unable to run training: %v", err)
|
||||
|
@ -12,9 +12,40 @@ import (
|
||||
"github.com/cyrilix/robocar-tools/pkg/models"
|
||||
"go.uber.org/zap"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ModelType int
|
||||
|
||||
func ParseModelType(s string) ModelType {
|
||||
switch strings.ToLower(s) {
|
||||
case "categorical":
|
||||
return ModelTypeCategorical
|
||||
case "linear":
|
||||
return ModelTypeLinear
|
||||
default:
|
||||
return ModelTypeUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (m ModelType) String() string {
|
||||
switch m {
|
||||
case ModelTypeCategorical:
|
||||
return "categorical"
|
||||
case ModelTypeLinear:
|
||||
return "linear"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ModelTypeUnknown ModelType = iota
|
||||
ModelTypeCategorical
|
||||
ModelTypeLinear
|
||||
)
|
||||
|
||||
func New(bucketName string, ociImage, roleArn string) *Training {
|
||||
return &Training{
|
||||
config: awsutils.MustLoadConfig(),
|
||||
@ -35,7 +66,7 @@ type Training struct {
|
||||
outputBucket string
|
||||
}
|
||||
|
||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
|
||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
|
||||
l := zap.S()
|
||||
l.Infof("run training with data from %s", basedir)
|
||||
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
|
||||
@ -50,14 +81,7 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWid
|
||||
}
|
||||
l.Info("")
|
||||
|
||||
err = t.runTraining(
|
||||
ctx,
|
||||
jobName,
|
||||
sliceSize,
|
||||
imgHeight,
|
||||
imgWidth,
|
||||
enableSpotTraining,
|
||||
)
|
||||
err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to run training: %w", err)
|
||||
}
|
||||
@ -95,7 +119,7 @@ func List(bucketName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error {
|
||||
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error {
|
||||
l := zap.S()
|
||||
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
|
||||
l.Infof("Start training job '%s'", jobName)
|
||||
@ -125,6 +149,9 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
|
||||
"slide_size": strconv.Itoa(slideSize),
|
||||
"img_height": strconv.Itoa(imgHeight),
|
||||
"img_width": strconv.Itoa(imgWidth),
|
||||
"batch_size": strconv.Itoa(32),
|
||||
"model_type": modelType.String(),
|
||||
"horizon": strconv.Itoa(horizon),
|
||||
},
|
||||
InputDataConfig: []types.Channel{
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user