Compare commits

..

No commits in common. "b0e91a5994699217cfc44f44d5564dc5e100cb0e" and "b7b4bd76a0c48e4855cc1a9dbe810ea2c77b72d6" have entirely different histories.

3 changed files with 22 additions and 55 deletions

View File

@ -102,7 +102,7 @@ func main() {
fmt.Printf(" run\n \tRun training job\n") fmt.Printf(" run\n \tRun training job\n")
} }
var modelPath, roleArn, trainJobName, modelType string var modelPath, roleArn, trainJobName string
var horizon int var horizon int
var withFlipImage bool var withFlipImage bool
var trainImageHeight, trainImageWidth int var trainImageHeight, trainImageWidth int
@ -120,7 +120,6 @@ func main() {
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height") trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width") trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)") trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training") trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError) trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
@ -238,7 +237,7 @@ func main() {
trainingRunFlags.PrintDefaults() trainingRunFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
} }
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining) runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
case trainArchiveFlags.Name(): case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp { if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults() trainArchiveFlags.PrintDefaults()
@ -356,7 +355,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
} }
} }
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) { func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
l := zap.S() l := zap.S()
if bucketName == "" { if bucketName == "" {
@ -379,12 +378,8 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelTy
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed") l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
} }
if modelType == train.ModelTypeUnknown {
l.Fatalf("invalid model type: %v", modelType)
}
training := train.New(bucketName, ociImage, roleArn) training := train.New(bucketName, ociImage, roleArn)
err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining) err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
if err != nil { if err != nil {
l.Fatalf("unable to run training: %v", err) l.Fatalf("unable to run training: %v", err)

View File

@ -12,40 +12,9 @@ import (
"github.com/cyrilix/robocar-tools/pkg/models" "github.com/cyrilix/robocar-tools/pkg/models"
"go.uber.org/zap" "go.uber.org/zap"
"strconv" "strconv"
"strings"
"time" "time"
) )
// ModelType identifies the type of model to build during training.
type ModelType int

// Known model types. ModelTypeUnknown is the zero value and is the result
// ParseModelType yields for any name it does not recognise.
const (
	ModelTypeUnknown ModelType = iota
	ModelTypeCategorical
	ModelTypeLinear
)

// ParseModelType converts a model type name into its ModelType value.
// The name is lower-cased before comparison, so matching is not sensitive
// to letter case. Any name other than "categorical" or "linear" yields
// ModelTypeUnknown.
func ParseModelType(s string) ModelType {
	name := strings.ToLower(s)
	// Only the concrete types are candidates: the literal name "unknown"
	// must not be treated as a successful parse.
	for _, candidate := range []ModelType{ModelTypeCategorical, ModelTypeLinear} {
		if candidate.String() == name {
			return candidate
		}
	}
	return ModelTypeUnknown
}

// String returns the lowercase name of the model type, or "unknown" for
// ModelTypeUnknown and for any value outside the declared constants.
func (m ModelType) String() string {
	if m == ModelTypeCategorical {
		return "categorical"
	}
	if m == ModelTypeLinear {
		return "linear"
	}
	return "unknown"
}
func New(bucketName string, ociImage, roleArn string) *Training { func New(bucketName string, ociImage, roleArn string) *Training {
return &Training{ return &Training{
config: awsutils.MustLoadConfig(), config: awsutils.MustLoadConfig(),
@ -66,7 +35,7 @@ type Training struct {
outputBucket string outputBucket string
} }
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
l := zap.S() l := zap.S()
l.Infof("run training with data from %s", basedir) l.Infof("run training with data from %s", basedir)
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage) archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
@ -81,7 +50,14 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelT
} }
l.Info("") l.Info("")
err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType) err = t.runTraining(
ctx,
jobName,
sliceSize,
imgHeight,
imgWidth,
enableSpotTraining,
)
if err != nil { if err != nil {
return fmt.Errorf("unable to run training: %w", err) return fmt.Errorf("unable to run training: %w", err)
} }
@ -119,7 +95,7 @@ func List(bucketName string) error {
return nil return nil
} }
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error { func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error {
l := zap.S() l := zap.S()
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig()) client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
l.Infof("Start training job '%s'", jobName) l.Infof("Start training job '%s'", jobName)
@ -149,9 +125,6 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, i
"slide_size": strconv.Itoa(slideSize), "slide_size": strconv.Itoa(slideSize),
"img_height": strconv.Itoa(imgHeight), "img_height": strconv.Itoa(imgHeight),
"img_width": strconv.Itoa(imgWidth), "img_width": strconv.Itoa(imgWidth),
"batch_size": strconv.Itoa(32),
"model_type": modelType.String(),
"horizon": strconv.Itoa(horizon),
}, },
InputDataConfig: []types.Channel{ InputDataConfig: []types.Channel{
{ {

View File

@ -3,8 +3,8 @@
set +e set +e
set +x set +x
RECORDS_PATH=~/src/robocars/data/viva20/viva12/ RECORDS_PATH=~/robocar/record-sim4-2
#TRAINING_OPTS="--horizon=50" #TRAINING_OPTS="--horizon=20"
TRAINING_OPTS="" TRAINING_OPTS=""
MODEL_TYPE="categorical" MODEL_TYPE="categorical"
#MODEL_TYPE="linear" #MODEL_TYPE="linear"
@ -12,9 +12,8 @@ IMG_WIDTH=160
IMG_HEIGHT=120 IMG_HEIGHT=120
HORIZON=20 HORIZON=20
TRAINING_DIR=~/src/robocars/trainings TRAINING_DATA_DIR=/tmp/data
TRAINING_DATA_DIR=${TRAINING_DIR}/data TRAINING_OUTPUT_DIR=/tmp/output
TRAINING_OUTPUT_DIR=${TRAINING_DIR}/output
TRAIN_ARCHIVE=${TRAINING_DATA_DIR}/train.zip TRAIN_ARCHIVE=${TRAINING_DATA_DIR}/train.zip
####################### #######################
@ -32,9 +31,9 @@ go run ./cmd/rc-tools training archive \
printf "\n\nRun training\n\n" printf "\n\nRun training\n\n"
podman run --rm -it \ podman run --rm -it \
-v /trainings/data:/opt/ml/input/data/train \ -v /tmp/data:/opt/ml/input/data/train \
-v /trainings/output:/opt/ml/model/ \ -v /tmp/output:/opt/ml/model/ \
localhost/tensorflow_without_gpu:old \ localhost/tensorflow_without_gpu \
python /opt/ml/code/train.py \ python /opt/ml/code/train.py \
--img_height=${IMG_HEIGHT} \ --img_height=${IMG_HEIGHT} \
--img_width=${IMG_WIDTH} \ --img_width=${IMG_WIDTH} \