feat(training): add flag toconfigure model type training
This commit is contained in:
		@@ -102,7 +102,7 @@ func main() {
 | 
			
		||||
		fmt.Printf("  run\n  \tRun training job\n")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var modelPath, roleArn, trainJobName string
 | 
			
		||||
	var modelPath, roleArn, trainJobName, modelType string
 | 
			
		||||
	var horizon int
 | 
			
		||||
	var withFlipImage bool
 | 
			
		||||
	var trainImageHeight, trainImageWidth int
 | 
			
		||||
@@ -120,6 +120,7 @@ func main() {
 | 
			
		||||
	trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
 | 
			
		||||
	trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
 | 
			
		||||
	trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
 | 
			
		||||
	trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
 | 
			
		||||
 | 
			
		||||
	trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
 | 
			
		||||
	trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
 | 
			
		||||
@@ -237,7 +238,7 @@ func main() {
 | 
			
		||||
				trainingRunFlags.PrintDefaults()
 | 
			
		||||
				os.Exit(0)
 | 
			
		||||
			}
 | 
			
		||||
			runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
 | 
			
		||||
			runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
 | 
			
		||||
		case trainArchiveFlags.Name():
 | 
			
		||||
			if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
 | 
			
		||||
				trainArchiveFlags.PrintDefaults()
 | 
			
		||||
@@ -355,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
 | 
			
		||||
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
 | 
			
		||||
 | 
			
		||||
	l := zap.S()
 | 
			
		||||
	if bucketName == "" {
 | 
			
		||||
@@ -378,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi
 | 
			
		||||
		l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if modelType == train.ModelTypeUnknown {
 | 
			
		||||
		l.Fatalf("invalid model type: %v", modelType)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	training := train.New(bucketName, ociImage, roleArn)
 | 
			
		||||
	err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
 | 
			
		||||
	err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		l.Fatalf("unable to run training: %v", err)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user