feat(training): new features
* add flip-image option
* add command to list models
* add option to override image size when training is launched
* add option to disable aws spot instance
This commit is contained in:
		@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"github.com/cyrilix/robocar-tools/dkimpt"
 | 
			
		||||
	"github.com/cyrilix/robocar-tools/part"
 | 
			
		||||
	"github.com/cyrilix/robocar-tools/pkg/data"
 | 
			
		||||
	"github.com/cyrilix/robocar-tools/pkg/models"
 | 
			
		||||
	"github.com/cyrilix/robocar-tools/pkg/train"
 | 
			
		||||
	"github.com/cyrilix/robocar-tools/record"
 | 
			
		||||
	"github.com/cyrilix/robocar-tools/video"
 | 
			
		||||
@@ -42,6 +43,7 @@ func main() {
 | 
			
		||||
		fmt.Printf("  display\n  \tDisplay events on live frames\n")
 | 
			
		||||
		fmt.Printf("  record \n  \tRecord event for tensorflow training\n")
 | 
			
		||||
		fmt.Printf("  training  \n  \tManage training\n")
 | 
			
		||||
		fmt.Printf("  models  \n  \tManage models\n")
 | 
			
		||||
		fmt.Printf("  import-donkey-records \n  \tCopy donkeycar records to new format\n")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -88,6 +90,9 @@ func main() {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var modelPath, roleArn, trainJobName string
 | 
			
		||||
	var withFlipImage bool
 | 
			
		||||
	var trainImageHeight, trainImageWidth int
 | 
			
		||||
	var enableSpotTraining bool
 | 
			
		||||
	trainingRunFlags := flag.NewFlagSet("run", flag.ExitOnError)
 | 
			
		||||
	trainingRunFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
 | 
			
		||||
	trainingRunFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Input data path where records and img files are stored, use RECORD_PATH if arg not set")
 | 
			
		||||
@@ -96,14 +101,36 @@ func main() {
 | 
			
		||||
	trainingRunFlags.StringVar(&ociImage, "oci-image", os.Getenv("RC_TRAIN_OCI_IMAGE"), "OCI image to run (required), use RC_TRAIN_OCI_IMAGE if args not set")
 | 
			
		||||
	trainingRunFlags.StringVar(&roleArn, "role-arn", os.Getenv("RC_TRAIN_ROLE"), "AWS ARN role to use to run training (required), use RC_TRAIN_ROLE if arg not set")
 | 
			
		||||
	trainingRunFlags.StringVar(&trainJobName, "job-name", "", "Training job name (required)")
 | 
			
		||||
	trainingRunFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive")
 | 
			
		||||
 | 
			
		||||
	trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
 | 
			
		||||
	trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
 | 
			
		||||
 | 
			
		||||
	trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
 | 
			
		||||
	trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
 | 
			
		||||
 | 
			
		||||
	trainArchiveFlags := flag.NewFlagSet("archive", flag.ExitOnError)
 | 
			
		||||
	trainArchiveFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Path where records files are stored, use RECORD_PATH if args not set")
 | 
			
		||||
	trainArchiveFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set")
 | 
			
		||||
	trainArchiveFlags.IntVar(&trainSliceSize, "slice-size", trainSliceSize, "Number of record to shift with image, use TRAIN_SLICE_SIZE if args not set")
 | 
			
		||||
	trainArchiveFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
	modelsFlags := flag.NewFlagSet("models", flag.ExitOnError)
 | 
			
		||||
	modelsFlags.Usage = func() {
 | 
			
		||||
		fmt.Printf("Usage of %s %s:\n", os.Args[0], modelsFlags.Name())
 | 
			
		||||
		fmt.Printf("  list\n  \tList existing models\n")
 | 
			
		||||
		fmt.Printf("  download\n  \tDownload existing models\n")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	modelsListFlags := flag.NewFlagSet("list", flag.ExitOnError)
 | 
			
		||||
	modelsListFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
 | 
			
		||||
 | 
			
		||||
	var modelPathBucket string
 | 
			
		||||
	modelsDownloadFlags := flag.NewFlagSet("download", flag.ExitOnError)
 | 
			
		||||
	modelsDownloadFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
 | 
			
		||||
	modelsDownloadFlags.StringVar(&modelPathBucket, "model", "", "S3 Model key into bucket (mandatory)")
 | 
			
		||||
	modelsDownloadFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set")
 | 
			
		||||
	flag.Parse()
 | 
			
		||||
 | 
			
		||||
	config := zap.NewDevelopmentConfig()
 | 
			
		||||
@@ -173,18 +200,47 @@ func main() {
 | 
			
		||||
				trainingRunFlags.PrintDefaults()
 | 
			
		||||
				os.Exit(0)
 | 
			
		||||
			}
 | 
			
		||||
			runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, modelPath)
 | 
			
		||||
			runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, withFlipImage, modelPath,
 | 
			
		||||
				trainImageHeight, trainImageWidth, enableSpotTraining)
 | 
			
		||||
		case trainArchiveFlags.Name():
 | 
			
		||||
			if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
 | 
			
		||||
				trainArchiveFlags.PrintDefaults()
 | 
			
		||||
				os.Exit(0)
 | 
			
		||||
			}
 | 
			
		||||
			runTrainArchive(recordsPath, trainArchiveName, trainSliceSize)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
			runTrainArchive(recordsPath, trainArchiveName, trainSliceSize, withFlipImage)
 | 
			
		||||
		default:
 | 
			
		||||
			trainingFlags.PrintDefaults()
 | 
			
		||||
			os.Exit(0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
	case modelsFlags.Name():
 | 
			
		||||
 | 
			
		||||
		if err := modelsFlags.Parse(os.Args[2:]); err == flag.ErrHelp {
 | 
			
		||||
			modelsFlags.PrintDefaults()
 | 
			
		||||
			os.Exit(0)
 | 
			
		||||
		}
 | 
			
		||||
		switch modelsFlags.Arg(0) {
 | 
			
		||||
		case modelsListFlags.Name():
 | 
			
		||||
			if err:= modelsListFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
 | 
			
		||||
				modelsListFlags.PrintDefaults()
 | 
			
		||||
				os.Exit(0)
 | 
			
		||||
			}
 | 
			
		||||
			runModelsList(bucket)
 | 
			
		||||
		case modelsDownloadFlags.Name():
 | 
			
		||||
			if err:= modelsDownloadFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
 | 
			
		||||
				modelsDownloadFlags.PrintDefaults()
 | 
			
		||||
				os.Exit(0)
 | 
			
		||||
			}
 | 
			
		||||
			if trainArchiveName == "" {
 | 
			
		||||
				zap.S().Error("output model file is mandatory")
 | 
			
		||||
				modelsDownloadFlags.PrintDefaults()
 | 
			
		||||
				os.Exit(1)
 | 
			
		||||
			}
 | 
			
		||||
			runModelsDownload(bucket, modelPathBucket, trainArchiveName)
 | 
			
		||||
		default:
 | 
			
		||||
			modelsFlags.PrintDefaults()
 | 
			
		||||
			os.Exit(0)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
@@ -210,9 +266,9 @@ func runRecord(client mqtt.Client, recordsDir, recordTopic string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runTrainArchive(basedir, archiveName string, sliceSize int) {
 | 
			
		||||
func runTrainArchive(basedir, archiveName string, sliceSize int, withFlipImage bool) {
 | 
			
		||||
 | 
			
		||||
	err := data.WriteArchive(basedir, archiveName, sliceSize)
 | 
			
		||||
	err := data.WriteArchive(basedir, archiveName, sliceSize, withFlipImage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.S().Fatalf("unable to build archive file %v: %v", archiveName, err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -254,7 +310,9 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runTraining(bucketName string, ociImage string, roleArn string, jobName, dataDir string, sliceSize int, outputModel string) {
 | 
			
		||||
func runTraining(bucketName string, ociImage string, roleArn string, jobName, dataDir string, sliceSize int, withFlipImage bool,
 | 
			
		||||
	outputModel string, imgHeight int, imgWidth int, enableSpotTraining bool) {
 | 
			
		||||
 | 
			
		||||
	l := zap.S()
 | 
			
		||||
	if bucketName == "" {
 | 
			
		||||
		l.Fatalf("no bucket define, see help")
 | 
			
		||||
@@ -277,7 +335,8 @@ func runTraining(bucketName string, ociImage string, roleArn string, jobName, da
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	training := train.New(bucketName, ociImage, roleArn)
 | 
			
		||||
	err := training.TrainDir(context.Background(), jobName, dataDir, sliceSize, outputModel)
 | 
			
		||||
	err := training.TrainDir(context.Background(), jobName, dataDir, imgHeight, imgWidth, sliceSize, withFlipImage,
 | 
			
		||||
		outputModel, enableSpotTraining)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		l.Fatalf("unable to run training: %v", err)
 | 
			
		||||
@@ -289,4 +348,19 @@ func runTrainList() {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.S().Fatalf("unable to list training jobs: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runModelsList(bucketName string) {
 | 
			
		||||
	err := models.ListModels(context.Background(), bucketName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.S().Fatalf("unable to list models: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runModelsDownload(bucketName, modelPath, output string) {
 | 
			
		||||
	err := models.DownloadArchiveToFile(context.Background(), bucketName, modelPath, output)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.S().Fatalf("unable to download model: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user