Compare commits

...

2 Commits

Author SHA1 Message Date
Cyrille Nofficial
b0e91a5994 feat(training): bash script to run training 2022-10-06 14:09:41 +02:00
Cyrille Nofficial
c4ff4b46b0 feat(training): add flag toconfigure model type training 2022-10-06 14:07:44 +02:00
3 changed files with 55 additions and 22 deletions

View File

@ -102,7 +102,7 @@ func main() {
fmt.Printf(" run\n \tRun training job\n") fmt.Printf(" run\n \tRun training job\n")
} }
var modelPath, roleArn, trainJobName string var modelPath, roleArn, trainJobName, modelType string
var horizon int var horizon int
var withFlipImage bool var withFlipImage bool
var trainImageHeight, trainImageWidth int var trainImageHeight, trainImageWidth int
@ -120,6 +120,7 @@ func main() {
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height") trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width") trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)") trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training") trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError) trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
@ -237,7 +238,7 @@ func main() {
trainingRunFlags.PrintDefaults() trainingRunFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
} }
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining) runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
case trainArchiveFlags.Name(): case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp { if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults() trainArchiveFlags.PrintDefaults()
@ -355,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
} }
} }
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) { func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
l := zap.S() l := zap.S()
if bucketName == "" { if bucketName == "" {
@ -378,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed") l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
} }
if modelType == train.ModelTypeUnknown {
l.Fatalf("invalid model type: %v", modelType)
}
training := train.New(bucketName, ociImage, roleArn) training := train.New(bucketName, ociImage, roleArn)
err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining) err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
if err != nil { if err != nil {
l.Fatalf("unable to run training: %v", err) l.Fatalf("unable to run training: %v", err)

View File

@ -12,9 +12,40 @@ import (
"github.com/cyrilix/robocar-tools/pkg/models" "github.com/cyrilix/robocar-tools/pkg/models"
"go.uber.org/zap" "go.uber.org/zap"
"strconv" "strconv"
"strings"
"time" "time"
) )
type ModelType int
func ParseModelType(s string) ModelType {
switch strings.ToLower(s) {
case "categorical":
return ModelTypeCategorical
case "linear":
return ModelTypeLinear
default:
return ModelTypeUnknown
}
}
func (m ModelType) String() string {
switch m {
case ModelTypeCategorical:
return "categorical"
case ModelTypeLinear:
return "linear"
default:
return "unknown"
}
}
const (
ModelTypeUnknown ModelType = iota
ModelTypeCategorical
ModelTypeLinear
)
func New(bucketName string, ociImage, roleArn string) *Training { func New(bucketName string, ociImage, roleArn string) *Training {
return &Training{ return &Training{
config: awsutils.MustLoadConfig(), config: awsutils.MustLoadConfig(),
@ -35,7 +66,7 @@ type Training struct {
outputBucket string outputBucket string
} }
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
l := zap.S() l := zap.S()
l.Infof("run training with data from %s", basedir) l.Infof("run training with data from %s", basedir)
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage) archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
@ -50,14 +81,7 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWid
} }
l.Info("") l.Info("")
err = t.runTraining( err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType)
ctx,
jobName,
sliceSize,
imgHeight,
imgWidth,
enableSpotTraining,
)
if err != nil { if err != nil {
return fmt.Errorf("unable to run training: %w", err) return fmt.Errorf("unable to run training: %w", err)
} }
@ -95,7 +119,7 @@ func List(bucketName string) error {
return nil return nil
} }
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error { func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error {
l := zap.S() l := zap.S()
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig()) client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
l.Infof("Start training job '%s'", jobName) l.Infof("Start training job '%s'", jobName)
@ -125,6 +149,9 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
"slide_size": strconv.Itoa(slideSize), "slide_size": strconv.Itoa(slideSize),
"img_height": strconv.Itoa(imgHeight), "img_height": strconv.Itoa(imgHeight),
"img_width": strconv.Itoa(imgWidth), "img_width": strconv.Itoa(imgWidth),
"batch_size": strconv.Itoa(32),
"model_type": modelType.String(),
"horizon": strconv.Itoa(horizon),
}, },
InputDataConfig: []types.Channel{ InputDataConfig: []types.Channel{
{ {

View File

@ -3,8 +3,8 @@
set +e set +e
set +x set +x
RECORDS_PATH=~/robocar/record-sim4-2 RECORDS_PATH=~/src/robocars/data/viva20/viva12/
#TRAINING_OPTS="--horizon=20" #TRAINING_OPTS="--horizon=50"
TRAINING_OPTS="" TRAINING_OPTS=""
MODEL_TYPE="categorical" MODEL_TYPE="categorical"
#MODEL_TYPE="linear" #MODEL_TYPE="linear"
@ -12,8 +12,9 @@ IMG_WIDTH=160
IMG_HEIGHT=120 IMG_HEIGHT=120
HORIZON=20 HORIZON=20
TRAINING_DATA_DIR=/tmp/data TRAINING_DIR=~/src/robocars/trainings
TRAINING_OUTPUT_DIR=/tmp/output TRAINING_DATA_DIR=${TRAINING_DIR}/data
TRAINING_OUTPUT_DIR=${TRAINING_DIR}/output
TRAIN_ARCHIVE=${TRAINING_DATA_DIR}/train.zip TRAIN_ARCHIVE=${TRAINING_DATA_DIR}/train.zip
####################### #######################
@ -30,10 +31,10 @@ go run ./cmd/rc-tools training archive \
-image-width ${IMG_WIDTH} -image-width ${IMG_WIDTH}
printf "\n\nRun training\n\n" printf "\n\nRun training\n\n"
podman run --rm -it \ podman run --rm -it \
-v /tmp/data:/opt/ml/input/data/train \ -v /trainings/data:/opt/ml/input/data/train \
-v /tmp/output:/opt/ml/model/ \ -v /trainings/output:/opt/ml/model/ \
localhost/tensorflow_without_gpu \ localhost/tensorflow_without_gpu:old \
python /opt/ml/code/train.py \ python /opt/ml/code/train.py \
--img_height=${IMG_HEIGHT} \ --img_height=${IMG_HEIGHT} \
--img_width=${IMG_WIDTH} \ --img_width=${IMG_WIDTH} \