feat(training): new features
* add flip-image option * add command to list models * add option to override image size when training is launched * add option to disable aws spot instance
This commit is contained in:
@ -8,6 +8,7 @@ import (
|
||||
"github.com/cyrilix/robocar-tools/dkimpt"
|
||||
"github.com/cyrilix/robocar-tools/part"
|
||||
"github.com/cyrilix/robocar-tools/pkg/data"
|
||||
"github.com/cyrilix/robocar-tools/pkg/models"
|
||||
"github.com/cyrilix/robocar-tools/pkg/train"
|
||||
"github.com/cyrilix/robocar-tools/record"
|
||||
"github.com/cyrilix/robocar-tools/video"
|
||||
@ -42,6 +43,7 @@ func main() {
|
||||
fmt.Printf(" display\n \tDisplay events on live frames\n")
|
||||
fmt.Printf(" record \n \tRecord event for tensorflow training\n")
|
||||
fmt.Printf(" training \n \tManage training\n")
|
||||
fmt.Printf(" models \n \tManage models\n")
|
||||
fmt.Printf(" import-donkey-records \n \tCopy donkeycar records to new format\n")
|
||||
}
|
||||
|
||||
@ -88,6 +90,9 @@ func main() {
|
||||
}
|
||||
|
||||
var modelPath, roleArn, trainJobName string
|
||||
var withFlipImage bool
|
||||
var trainImageHeight, trainImageWidth int
|
||||
var enableSpotTraining bool
|
||||
trainingRunFlags := flag.NewFlagSet("run", flag.ExitOnError)
|
||||
trainingRunFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
|
||||
trainingRunFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Input data path where records and img files are stored, use RECORD_PATH if arg not set")
|
||||
@ -96,14 +101,36 @@ func main() {
|
||||
trainingRunFlags.StringVar(&ociImage, "oci-image", os.Getenv("RC_TRAIN_OCI_IMAGE"), "OCI image to run (required), use RC_TRAIN_OCI_IMAGE if args not set")
|
||||
trainingRunFlags.StringVar(&roleArn, "role-arn", os.Getenv("RC_TRAIN_ROLE"), "AWS ARN role to use to run training (required), use RC_TRAIN_ROLE if arg not set")
|
||||
trainingRunFlags.StringVar(&trainJobName, "job-name", "", "Training job name (required)")
|
||||
trainingRunFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive")
|
||||
|
||||
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
|
||||
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
|
||||
|
||||
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
|
||||
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
|
||||
|
||||
trainArchiveFlags := flag.NewFlagSet("archive", flag.ExitOnError)
|
||||
trainArchiveFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Path where records files are stored, use RECORD_PATH if args not set")
|
||||
trainArchiveFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set")
|
||||
trainArchiveFlags.IntVar(&trainSliceSize, "slice-size", trainSliceSize, "Number of record to shift with image, use TRAIN_SLICE_SIZE if args not set")
|
||||
trainArchiveFlags.BoolVar(&withFlipImage, "with-flip-image", withFlipImage, "Flip horiontal image and reverse steering to increase data into training archive")
|
||||
|
||||
|
||||
modelsFlags := flag.NewFlagSet("models", flag.ExitOnError)
|
||||
modelsFlags.Usage = func() {
|
||||
fmt.Printf("Usage of %s %s:\n", os.Args[0], modelsFlags.Name())
|
||||
fmt.Printf(" list\n \tList existing models\n")
|
||||
fmt.Printf(" download\n \tDownload existing models\n")
|
||||
}
|
||||
|
||||
modelsListFlags := flag.NewFlagSet("list", flag.ExitOnError)
|
||||
modelsListFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
|
||||
|
||||
var modelPathBucket string
|
||||
modelsDownloadFlags := flag.NewFlagSet("download", flag.ExitOnError)
|
||||
modelsDownloadFlags.StringVar(&bucket, "bucket", os.Getenv("RC_TRAIN_BUCKET"), "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set")
|
||||
modelsDownloadFlags.StringVar(&modelPathBucket, "model", "", "S3 Model key into bucket (mandatory)")
|
||||
modelsDownloadFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set")
|
||||
flag.Parse()
|
||||
|
||||
config := zap.NewDevelopmentConfig()
|
||||
@ -173,18 +200,47 @@ func main() {
|
||||
trainingRunFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
}
|
||||
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, modelPath)
|
||||
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, withFlipImage, modelPath,
|
||||
trainImageHeight, trainImageWidth, enableSpotTraining)
|
||||
case trainArchiveFlags.Name():
|
||||
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
|
||||
trainArchiveFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
}
|
||||
runTrainArchive(recordsPath, trainArchiveName, trainSliceSize)
|
||||
|
||||
|
||||
runTrainArchive(recordsPath, trainArchiveName, trainSliceSize, withFlipImage)
|
||||
default:
|
||||
trainingFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
|
||||
|
||||
}
|
||||
case modelsFlags.Name():
|
||||
|
||||
if err := modelsFlags.Parse(os.Args[2:]); err == flag.ErrHelp {
|
||||
modelsFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
}
|
||||
switch modelsFlags.Arg(0) {
|
||||
case modelsListFlags.Name():
|
||||
if err:= modelsListFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
|
||||
modelsListFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
}
|
||||
runModelsList(bucket)
|
||||
case modelsDownloadFlags.Name():
|
||||
if err:= modelsDownloadFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
|
||||
modelsDownloadFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
}
|
||||
if trainArchiveName == "" {
|
||||
zap.S().Error("output model file is mandatory")
|
||||
modelsDownloadFlags.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
runModelsDownload(bucket, modelPathBucket, trainArchiveName)
|
||||
default:
|
||||
modelsFlags.PrintDefaults()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
default:
|
||||
@ -210,9 +266,9 @@ func runRecord(client mqtt.Client, recordsDir, recordTopic string) {
|
||||
}
|
||||
}
|
||||
|
||||
func runTrainArchive(basedir, archiveName string, sliceSize int) {
|
||||
func runTrainArchive(basedir, archiveName string, sliceSize int, withFlipImage bool) {
|
||||
|
||||
err := data.WriteArchive(basedir, archiveName, sliceSize)
|
||||
err := data.WriteArchive(basedir, archiveName, sliceSize, withFlipImage)
|
||||
if err != nil {
|
||||
zap.S().Fatalf("unable to build archive file %v: %v", archiveName, err)
|
||||
}
|
||||
@ -254,7 +310,9 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
|
||||
}
|
||||
}
|
||||
|
||||
func runTraining(bucketName string, ociImage string, roleArn string, jobName, dataDir string, sliceSize int, outputModel string) {
|
||||
func runTraining(bucketName string, ociImage string, roleArn string, jobName, dataDir string, sliceSize int, withFlipImage bool,
|
||||
outputModel string, imgHeight int, imgWidth int, enableSpotTraining bool) {
|
||||
|
||||
l := zap.S()
|
||||
if bucketName == "" {
|
||||
l.Fatalf("no bucket define, see help")
|
||||
@ -277,7 +335,8 @@ func runTraining(bucketName string, ociImage string, roleArn string, jobName, da
|
||||
}
|
||||
|
||||
training := train.New(bucketName, ociImage, roleArn)
|
||||
err := training.TrainDir(context.Background(), jobName, dataDir, sliceSize, outputModel)
|
||||
err := training.TrainDir(context.Background(), jobName, dataDir, imgHeight, imgWidth, sliceSize, withFlipImage,
|
||||
outputModel, enableSpotTraining)
|
||||
|
||||
if err != nil {
|
||||
l.Fatalf("unable to run training: %v", err)
|
||||
@ -289,4 +348,19 @@ func runTrainList() {
|
||||
if err != nil {
|
||||
zap.S().Fatalf("unable to list training jobs: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func runModelsList(bucketName string) {
|
||||
err := models.ListModels(context.Background(), bucketName)
|
||||
if err != nil {
|
||||
zap.S().Fatalf("unable to list models: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func runModelsDownload(bucketName, modelPath, output string) {
|
||||
err := models.DownloadArchiveToFile(context.Background(), bucketName, modelPath, output)
|
||||
if err != nil {
|
||||
zap.S().Fatalf("unable to download model: %s", err)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user