feat(training): add flag to configure model type training
This commit is contained in:
parent
b7b4bd76a0
commit
c4ff4b46b0
@ -102,7 +102,7 @@ func main() {
|
|||||||
fmt.Printf(" run\n \tRun training job\n")
|
fmt.Printf(" run\n \tRun training job\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
var modelPath, roleArn, trainJobName string
|
var modelPath, roleArn, trainJobName, modelType string
|
||||||
var horizon int
|
var horizon int
|
||||||
var withFlipImage bool
|
var withFlipImage bool
|
||||||
var trainImageHeight, trainImageWidth int
|
var trainImageHeight, trainImageWidth int
|
||||||
@ -120,6 +120,7 @@ func main() {
|
|||||||
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
|
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
|
||||||
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
|
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
|
||||||
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
|
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
|
||||||
|
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
|
||||||
|
|
||||||
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
|
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
|
||||||
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
|
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
|
||||||
@ -237,7 +238,7 @@ func main() {
|
|||||||
trainingRunFlags.PrintDefaults()
|
trainingRunFlags.PrintDefaults()
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
|
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
|
||||||
case trainArchiveFlags.Name():
|
case trainArchiveFlags.Name():
|
||||||
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
|
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
|
||||||
trainArchiveFlags.PrintDefaults()
|
trainArchiveFlags.PrintDefaults()
|
||||||
@ -355,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
|
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
|
||||||
|
|
||||||
l := zap.S()
|
l := zap.S()
|
||||||
if bucketName == "" {
|
if bucketName == "" {
|
||||||
@ -378,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi
|
|||||||
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
|
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if modelType == train.ModelTypeUnknown {
|
||||||
|
l.Fatalf("invalid model type: %v", modelType)
|
||||||
|
}
|
||||||
|
|
||||||
training := train.New(bucketName, ociImage, roleArn)
|
training := train.New(bucketName, ociImage, roleArn)
|
||||||
err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
|
err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Fatalf("unable to run training: %v", err)
|
l.Fatalf("unable to run training: %v", err)
|
||||||
|
@ -12,9 +12,40 @@ import (
|
|||||||
"github.com/cyrilix/robocar-tools/pkg/models"
|
"github.com/cyrilix/robocar-tools/pkg/models"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ModelType identifies the kind of model a training job is asked to build.
// Its textual form (see String) is what gets passed as the "model_type"
// training parameter and what the command line accepts.
type ModelType int

const (
	// ModelTypeUnknown is the zero value. ParseModelType returns it for any
	// name it does not recognise, so callers can use it to detect bad input.
	ModelTypeUnknown ModelType = iota
	// ModelTypeCategorical is the model type spelled "categorical".
	ModelTypeCategorical
	// ModelTypeLinear is the model type spelled "linear".
	ModelTypeLinear
)

// ParseModelType converts a textual model type name into a ModelType.
//
// Matching is case-insensitive and ignores leading and trailing white space,
// so values such as "Linear" or " categorical\n" are accepted. Any other
// input, including the empty string, yields ModelTypeUnknown; no error is
// returned, the caller is expected to check for ModelTypeUnknown.
func ParseModelType(s string) ModelType {
	// Normalise once so that every case below compares canonical spellings.
	switch strings.ToLower(strings.TrimSpace(s)) {
	case "categorical":
		return ModelTypeCategorical
	case "linear":
		return ModelTypeLinear
	default:
		return ModelTypeUnknown
	}
}

// String returns the canonical lowercase name of the model type, the same
// spelling ParseModelType accepts. Values outside the declared constants,
// as well as ModelTypeUnknown itself, are rendered as "unknown".
func (m ModelType) String() string {
	switch m {
	case ModelTypeCategorical:
		return "categorical"
	case ModelTypeLinear:
		return "linear"
	default:
		return "unknown"
	}
}
|
||||||
|
|
||||||
func New(bucketName string, ociImage, roleArn string) *Training {
|
func New(bucketName string, ociImage, roleArn string) *Training {
|
||||||
return &Training{
|
return &Training{
|
||||||
config: awsutils.MustLoadConfig(),
|
config: awsutils.MustLoadConfig(),
|
||||||
@ -35,7 +66,7 @@ type Training struct {
|
|||||||
outputBucket string
|
outputBucket string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
|
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
|
||||||
l := zap.S()
|
l := zap.S()
|
||||||
l.Infof("run training with data from %s", basedir)
|
l.Infof("run training with data from %s", basedir)
|
||||||
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
|
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
|
||||||
@ -50,14 +81,7 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWid
|
|||||||
}
|
}
|
||||||
l.Info("")
|
l.Info("")
|
||||||
|
|
||||||
err = t.runTraining(
|
err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType)
|
||||||
ctx,
|
|
||||||
jobName,
|
|
||||||
sliceSize,
|
|
||||||
imgHeight,
|
|
||||||
imgWidth,
|
|
||||||
enableSpotTraining,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to run training: %w", err)
|
return fmt.Errorf("unable to run training: %w", err)
|
||||||
}
|
}
|
||||||
@ -95,7 +119,7 @@ func List(bucketName string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error {
|
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error {
|
||||||
l := zap.S()
|
l := zap.S()
|
||||||
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
|
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
|
||||||
l.Infof("Start training job '%s'", jobName)
|
l.Infof("Start training job '%s'", jobName)
|
||||||
@ -125,6 +149,9 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
|
|||||||
"slide_size": strconv.Itoa(slideSize),
|
"slide_size": strconv.Itoa(slideSize),
|
||||||
"img_height": strconv.Itoa(imgHeight),
|
"img_height": strconv.Itoa(imgHeight),
|
||||||
"img_width": strconv.Itoa(imgWidth),
|
"img_width": strconv.Itoa(imgWidth),
|
||||||
|
"batch_size": strconv.Itoa(32),
|
||||||
|
"model_type": modelType.String(),
|
||||||
|
"horizon": strconv.Itoa(horizon),
|
||||||
},
|
},
|
||||||
InputDataConfig: []types.Channel{
|
InputDataConfig: []types.Channel{
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user