feat(training): add flag toconfigure model type training
This commit is contained in:
@ -102,7 +102,7 @@ func main() {
|
||||
fmt.Printf(" run\n \tRun training job\n")
|
||||
}
|
||||
|
||||
var modelPath, roleArn, trainJobName string
|
||||
var modelPath, roleArn, trainJobName, modelType string
|
||||
var horizon int
|
||||
var withFlipImage bool
|
||||
var trainImageHeight, trainImageWidth int
|
||||
@ -120,6 +120,7 @@ func main() {
|
||||
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
|
||||
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
|
||||
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
|
||||
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
|
||||
|
||||
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
|
||||
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
|
||||
@ -237,7 +238,7 @@ func main() {
|
||||
trainingRunFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
}
|
||||
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
|
||||
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
|
||||
case trainArchiveFlags.Name():
|
||||
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
|
||||
trainArchiveFlags.PrintDefaults()
|
||||
@ -355,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
|
||||
}
|
||||
}
|
||||
|
||||
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
|
||||
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
|
||||
|
||||
l := zap.S()
|
||||
if bucketName == "" {
|
||||
@ -378,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi
|
||||
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
|
||||
}
|
||||
|
||||
if modelType == train.ModelTypeUnknown {
|
||||
l.Fatalf("invalid model type: %v", modelType)
|
||||
}
|
||||
|
||||
training := train.New(bucketName, ociImage, roleArn)
|
||||
err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
|
||||
err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
|
||||
|
||||
if err != nil {
|
||||
l.Fatalf("unable to run training: %v", err)
|
||||
|
Reference in New Issue
Block a user