feat(training): new features

* add flip-image option
* add command to list models
* add option to override image size when training is launched
* add option to disable AWS spot instances
This commit is contained in:
2021-11-24 19:31:16 +01:00
parent c69e1c20ef
commit 456f327788
37 changed files with 7232 additions and 93 deletions

View File

@ -8,6 +8,7 @@ import (
"github.com/cyrilix/robocar-tools/dkimpt"
"github.com/cyrilix/robocar-tools/part"
"github.com/cyrilix/robocar-tools/pkg/data"
"github.com/cyrilix/robocar-tools/pkg/models"
"github.com/cyrilix/robocar-tools/pkg/train"
"github.com/cyrilix/robocar-tools/record"
"github.com/cyrilix/robocar-tools/video"
@ -42,6 +43,7 @@ func main() {
fmt.Printf(" display\n \tDisplay events on live frames\n")
fmt.Printf(" record \n \tRecord event for tensorflow training\n")
fmt.Printf(" training \n \tManage training\n")
fmt.Printf(" models \n \tManage models\n")
fmt.Printf(" import-donkey-records \n \tCopy donkeycar records to new format\n")
}
@ -88,6 +90,9 @@ func main() {
}
var modelPath, roleArn, trainJobName string
var withFlipImage bool
var trainImageHeight, trainImageWidth int
var enableSpotTraining bool
trainingRunFlags := flag.NewFlagSet("run", flag.ExitOnError)
trainingRunFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
trainingRunFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Input data path where records and img files are stored, use RECORD_PATH if arg not set")
@ -96,14 +101,36 @@ func main() {
trainingRunFlags.StringVar(&ociImage, "oci-image", os.Getenv("RC_TRAIN_OCI_IMAGE"), "OCI image to run (required), use RC_TRAIN_OCI_IMAGE if args not set")
trainingRunFlags.StringVar(&roleArn, "role-arn", os.Getenv("RC_TRAIN_ROLE"), "AWS ARN role to use to run training (required), use RC_TRAIN_ROLE if arg not set")
trainingRunFlags.StringVar(&trainJobName, "job-name", "", "Training job name (required)")
trainingRunFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive")
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
trainArchiveFlags := flag.NewFlagSet("archive", flag.ExitOnError)
trainArchiveFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Path where records files are stored, use RECORD_PATH if args not set")
trainArchiveFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set")
trainArchiveFlags.IntVar(&trainSliceSize, "slice-size", trainSliceSize, "Number of record to shift with image, use TRAIN_SLICE_SIZE if args not set")
trainArchiveFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive")
modelsFlags := flag.NewFlagSet("models", flag.ExitOnError)
modelsFlags.Usage = func() {
fmt.Printf("Usage of %s %s:\n", os.Args[0], modelsFlags.Name())
fmt.Printf(" list\n \tList existing models\n")
fmt.Printf(" download\n \tDownload existing models\n")
}
modelsListFlags := flag.NewFlagSet("list", flag.ExitOnError)
modelsListFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
var modelPathBucket string
modelsDownloadFlags := flag.NewFlagSet("download", flag.ExitOnError)
modelsDownloadFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
modelsDownloadFlags.StringVar(&modelPathBucket, "model", "", "S3 Model key into bucket (mandatory)")
modelsDownloadFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set")
flag.Parse()
config := zap.NewDevelopmentConfig()
@ -173,18 +200,47 @@ func main() {
trainingRunFlags.PrintDefaults()
os.Exit(0)
}
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, modelPath)
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, withFlipImage, modelPath,
trainImageHeight, trainImageWidth, enableSpotTraining)
case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults()
os.Exit(0)
}
runTrainArchive(recordsPath, trainArchiveName, trainSliceSize)
runTrainArchive(recordsPath, trainArchiveName, trainSliceSize, withFlipImage)
default:
trainingFlags.PrintDefaults()
os.Exit(0)
}
case modelsFlags.Name():
if err := modelsFlags.Parse(os.Args[2:]); err == flag.ErrHelp {
modelsFlags.PrintDefaults()
os.Exit(0)
}
switch modelsFlags.Arg(0) {
case modelsListFlags.Name():
if err:= modelsListFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
modelsListFlags.PrintDefaults()
os.Exit(0)
}
runModelsList(bucket)
case modelsDownloadFlags.Name():
if err:= modelsDownloadFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
modelsDownloadFlags.PrintDefaults()
os.Exit(0)
}
if trainArchiveName == "" {
zap.S().Error("output model file is mandatory")
modelsDownloadFlags.PrintDefaults()
os.Exit(1)
}
runModelsDownload(bucket, modelPathBucket, trainArchiveName)
default:
modelsFlags.PrintDefaults()
os.Exit(0)
}
default:
@ -210,9 +266,9 @@ func runRecord(client mqtt.Client, recordsDir, recordTopic string) {
}
}
func runTrainArchive(basedir, archiveName string, sliceSize int) {
func runTrainArchive(basedir, archiveName string, sliceSize int, withFlipImage bool) {
err := data.WriteArchive(basedir, archiveName, sliceSize)
err := data.WriteArchive(basedir, archiveName, sliceSize, withFlipImage)
if err != nil {
zap.S().Fatalf("unable to build archive file %v: %v", archiveName, err)
}
@ -254,7 +310,9 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
}
}
func runTraining(bucketName string, ociImage string, roleArn string, jobName, dataDir string, sliceSize int, outputModel string) {
func runTraining(bucketName string, ociImage string, roleArn string, jobName, dataDir string, sliceSize int, withFlipImage bool,
outputModel string, imgHeight int, imgWidth int, enableSpotTraining bool) {
l := zap.S()
if bucketName == "" {
l.Fatalf("no bucket define, see help")
@ -277,7 +335,8 @@ func runTraining(bucketName string, ociImage string, roleArn string, jobName, da
}
training := train.New(bucketName, ociImage, roleArn)
err := training.TrainDir(context.Background(), jobName, dataDir, sliceSize, outputModel)
err := training.TrainDir(context.Background(), jobName, dataDir, imgHeight, imgWidth, sliceSize, withFlipImage,
outputModel, enableSpotTraining)
if err != nil {
l.Fatalf("unable to run training: %v", err)
@ -289,4 +348,19 @@ func runTrainList() {
if err != nil {
zap.S().Fatalf("unable to list training jobs: %w", err)
}
}
func runModelsList(bucketName string) {
err := models.ListModels(context.Background(), bucketName)
if err != nil {
zap.S().Fatalf("unable to list models: %s", err)
}
}
func runModelsDownload(bucketName, modelPath, output string) {
err := models.DownloadArchiveToFile(context.Background(), bucketName, modelPath, output)
if err != nil {
zap.S().Fatalf("unable to download model: %s", err)
}
}