Compare commits

..

3 Commits

Author SHA1 Message Date
Cyrille Nofficial
b0e91a5994 feat(training): bash script to run training 2022-10-06 14:09:41 +02:00
Cyrille Nofficial
c4ff4b46b0 feat(training): add flag to configure model type training 2022-10-06 14:07:44 +02:00
b7b4bd76a0 feat(display): add option to display throttle feedback 2022-09-05 19:34:10 +02:00
4 changed files with 131 additions and 51 deletions

View File

@ -28,8 +28,8 @@ func main() {
var mqttBroker, username, password, clientId string
var framePath string
var fps int
var frameTopic, objectsTopic, roadTopic, recordTopic string
var withObjects, withRoad bool
var frameTopic, objectsTopic, roadTopic, recordTopic, throttleFeedbackTopic string
var withObjects, withRoad, withThrottleFeedback bool
var recordsPath string
var trainArchiveName string
var trainSliceSize int
@ -81,6 +81,9 @@ func main() {
displayCameraFlags.StringVar(&roadTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_ROAD"), "Mqtt topic that contains road description, use MQTT_TOPIC_ROAD if args not set")
displayCameraFlags.BoolVar(&withRoad, "with-road", false, "Display detected road")
displayCameraFlags.StringVar(&throttleFeedbackTopic, "mqtt-topic-throttle-feedback", os.Getenv("MQTT_TOPIC_THROTTLE_FEEDBACK"), "Mqtt topic where to publish throttle feedback, use MQTT_TOPIC_THROTTLE_FEEDBACK if args not set")
displayCameraFlags.BoolVar(&withThrottleFeedback, "with-throttle-feedback", false, "Display throttle feedback")
recordFlags := flag.NewFlagSet("record", flag.ExitOnError)
cli.InitMqttFlagSet(recordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
recordFlags.StringVar(&recordTopic, "mqtt-topic-records", os.Getenv("MQTT_TOPIC_RECORDS"), "Mqtt topic that contains record data for training, use MQTT_TOPIC_RECORDS if args not set")
@ -99,7 +102,7 @@ func main() {
fmt.Printf(" run\n \tRun training job\n")
}
var modelPath, roleArn, trainJobName string
var modelPath, roleArn, trainJobName, modelType string
var horizon int
var withFlipImage bool
var trainImageHeight, trainImageWidth int
@ -117,6 +120,7 @@ func main() {
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
@ -195,7 +199,7 @@ func main() {
zap.S().Fatalf("unable to connect to mqtt bus: %v", err)
}
defer client.Disconnect(50)
runDisplay(client, framePath, frameTopic, fps, objectsTopic, roadTopic, withObjects, withRoad)
runDisplay(client, framePath, frameTopic, fps, objectsTopic, roadTopic, throttleFeedbackTopic, withObjects, withRoad, withThrottleFeedback)
default:
displayFlags.PrintDefaults()
os.Exit(0)
@ -234,7 +238,7 @@ func main() {
trainingRunFlags.PrintDefaults()
os.Exit(0)
}
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults()
@ -325,7 +329,8 @@ func runDisplayRecord(client mqtt.Client, recordTopic string) {
zap.S().Fatalf("unable to start service: %v", err)
}
}
func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic string, roadTopic string, withObjects bool, withRoad bool) {
func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic, roadTopic, throttleFeedbackTopic string,
withObjects, withRoad, withThrottleFeedback bool) {
if framePath != "" {
camera, err := video.NewCameraFake(client, frameTopic, framePath, fps)
@ -339,8 +344,8 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
}
p := part.NewPart(client, frameTopic,
objectsTopic, roadTopic,
withObjects, withRoad)
objectsTopic, roadTopic, throttleFeedbackTopic,
withObjects, withRoad, withThrottleFeedback)
defer p.Stop()
cli.HandleExit(p)
@ -351,7 +356,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
}
}
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
l := zap.S()
if bucketName == "" {
@ -374,8 +379,12 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSi
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
}
if modelType == train.ModelTypeUnknown {
l.Fatalf("invalid model type: %v", modelType)
}
training := train.New(bucketName, ociImage, roleArn)
err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
if err != nil {
l.Fatalf("unable to run training: %v", err)

View File

@ -14,35 +14,41 @@ import (
"time"
)
func NewPart(client mqtt.Client, frameTopic, objectsTopic, roadTopic string, withObjects, withRoad bool) *FramePart {
func NewPart(client mqtt.Client, frameTopic, objectsTopic, roadTopic, throttleFeedbackTopic string,
withObjects, withRoad, withThrottleFeedback bool) *FramePart {
return &FramePart{
client: client,
frameTopic: frameTopic,
objectsTopic: objectsTopic,
roadTopic: roadTopic,
window: gocv.NewWindow("frameTopic"),
withObjects: withObjects,
withRoad: withRoad,
imgChan: make(chan gocv.Mat),
objectsChan: make(chan events.ObjectsMessage),
roadChan: make(chan events.RoadMessage),
cancel: make(chan interface{}),
client: client,
frameTopic: frameTopic,
objectsTopic: objectsTopic,
roadTopic: roadTopic,
throttleFeedbackTopic: throttleFeedbackTopic,
window: gocv.NewWindow("frameTopic"),
withObjects: withObjects,
withRoad: withRoad,
withThrottleFeedback: withThrottleFeedback,
imgChan: make(chan gocv.Mat),
objectsChan: make(chan events.ObjectsMessage),
roadChan: make(chan events.RoadMessage),
throttleFeedbackChan: make(chan events.ThrottleMessage),
cancel: make(chan interface{}),
}
}
type FramePart struct {
client mqtt.Client
frameTopic, objectsTopic, roadTopic string
client mqtt.Client
frameTopic, objectsTopic, roadTopic, throttleFeedbackTopic string
window *gocv.Window
withObjects bool
withRoad bool
window *gocv.Window
withObjects bool
withRoad bool
withThrottleFeedback bool
imgChan chan gocv.Mat
objectsChan chan events.ObjectsMessage
roadChan chan events.RoadMessage
cancel chan interface{}
imgChan chan gocv.Mat
objectsChan chan events.ObjectsMessage
roadChan chan events.RoadMessage
throttleFeedbackChan chan events.ThrottleMessage
cancel chan interface{}
}
func (p *FramePart) Start() error {
@ -53,6 +59,8 @@ func (p *FramePart) Start() error {
var img = gocv.NewMat()
var objectsMsg events.ObjectsMessage
var roadMsg events.RoadMessage
var throttleFeedbackMsg events.ThrottleMessage
ticker := time.NewTicker(1 * time.Second)
for {
select {
@ -65,11 +73,13 @@ func (p *FramePart) Start() error {
objectsMsg = objects
case road := <-p.roadChan:
roadMsg = road
case throttleFeedback := <-p.throttleFeedbackChan:
throttleFeedbackMsg = throttleFeedback
case <-p.cancel:
img.Close()
return nil
}
p.drawFrame(&img, &objectsMsg, &roadMsg)
p.drawFrame(&img, &objectsMsg, &roadMsg, &throttleFeedbackMsg)
ticker.Reset(1 * time.Second)
}
}
@ -125,6 +135,18 @@ func (p *FramePart) onRoad(_ mqtt.Client, message mqtt.Message) {
p.roadChan <- msg
}
// onThrottleFeedback is the MQTT callback for throttle feedback messages.
// It decodes the payload as an events.ThrottleMessage and forwards it on
// p.throttleFeedbackChan. A payload that cannot be decoded is logged and
// dropped.
func (p *FramePart) onThrottleFeedback(_ mqtt.Client, message mqtt.Message) {
	var feedback events.ThrottleMessage
	if err := proto.Unmarshal(message.Payload(), &feedback); err != nil {
		zap.S().Errorf("unable to unmarshal msg %T: %v", feedback, err)
		return
	}

	// Plain channel send: this returns only once a receiver takes the message.
	p.throttleFeedbackChan <- feedback
}
func (p *FramePart) registerCallbacks() error {
err := RegisterCallback(p.client, p.frameTopic, p.onFrame)
if err != nil {
@ -144,10 +166,16 @@ func (p *FramePart) registerCallbacks() error {
return err
}
}
if p.withThrottleFeedback {
err := service.RegisterCallback(p.client, p.throttleFeedbackTopic, p.onThrottleFeedback)
if err != nil {
return err
}
}
return nil
}
func (p *FramePart) drawFrame(img *gocv.Mat, objects *events.ObjectsMessage, road *events.RoadMessage) {
func (p *FramePart) drawFrame(img *gocv.Mat, objects *events.ObjectsMessage, road *events.RoadMessage, tf *events.ThrottleMessage) {
if p.withObjects {
p.drawObjects(img, objects)
@ -155,6 +183,9 @@ func (p *FramePart) drawFrame(img *gocv.Mat, objects *events.ObjectsMessage, roa
if p.withRoad {
p.drawRoad(img, road)
}
if p.withThrottleFeedback {
p.drawThrottleFeedbackText(img, tf)
}
p.window.IMShow(*img)
p.window.WaitKey(1)
@ -216,6 +247,18 @@ func (p *FramePart) drawRoadText(img *gocv.Mat, road *events.RoadMessage) {
)
}
// drawThrottleFeedbackText writes a "Throttle feedback: <value>" label onto
// img, with the throttle value of tf formatted to three decimals.
func (p *FramePart) drawThrottleFeedbackText(img *gocv.Mat, tf *events.ThrottleMessage) {
	label := fmt.Sprintf("Throttle feedback: %.3f", tf.Throttle)
	origin := image.Point{X: 5, Y: 20}
	textColor := color.RGBA{R: 0, G: 255, B: 255, A: 255}

	gocv.PutText(img, label, origin, gocv.FontHersheyPlain, 0.6, textColor, 1)
}
func StopService(name string, client mqtt.Client, topics ...string) {
zap.S().Infof("Stop %s service", name)
token := client.Unsubscribe(topics...)

View File

@ -12,9 +12,40 @@ import (
"github.com/cyrilix/robocar-tools/pkg/models"
"go.uber.org/zap"
"strconv"
"strings"
"time"
)
// ModelType identifies the kind of model a training job builds.
type ModelType int

// Known model types. ModelTypeUnknown is the zero value and is what
// ParseModelType returns for a name it does not recognise.
const (
	ModelTypeUnknown ModelType = iota
	ModelTypeCategorical
	ModelTypeLinear
)

// ParseModelType converts a model type name into its ModelType value.
// The name is lower-cased before comparison, so matching ignores case.
// Any name other than "categorical" or "linear" yields ModelTypeUnknown.
func ParseModelType(s string) ModelType {
	name := strings.ToLower(s)
	for _, candidate := range []ModelType{ModelTypeCategorical, ModelTypeLinear} {
		if candidate.String() == name {
			return candidate
		}
	}
	return ModelTypeUnknown
}

// String returns the lower-case name of the model type, or "unknown" for
// any value that is not one of the named model types.
func (m ModelType) String() string {
	if m == ModelTypeCategorical {
		return "categorical"
	}
	if m == ModelTypeLinear {
		return "linear"
	}
	return "unknown"
}
func New(bucketName string, ociImage, roleArn string) *Training {
return &Training{
config: awsutils.MustLoadConfig(),
@ -35,7 +66,7 @@ type Training struct {
outputBucket string
}
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
l := zap.S()
l.Infof("run training with data from %s", basedir)
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
@ -50,14 +81,7 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWid
}
l.Info("")
err = t.runTraining(
ctx,
jobName,
sliceSize,
imgHeight,
imgWidth,
enableSpotTraining,
)
err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType)
if err != nil {
return fmt.Errorf("unable to run training: %w", err)
}
@ -95,7 +119,7 @@ func List(bucketName string) error {
return nil
}
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error {
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error {
l := zap.S()
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
l.Infof("Start training job '%s'", jobName)
@ -125,6 +149,9 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
"slide_size": strconv.Itoa(slideSize),
"img_height": strconv.Itoa(imgHeight),
"img_width": strconv.Itoa(imgWidth),
"batch_size": strconv.Itoa(32),
"model_type": modelType.String(),
"horizon": strconv.Itoa(horizon),
},
InputDataConfig: []types.Channel{
{

View File

@ -3,8 +3,8 @@
set +e
set +x
RECORDS_PATH=~/robocar/record-sim4-2
#TRAINING_OPTS="--horizon=20"
RECORDS_PATH=~/src/robocars/data/viva20/viva12/
#TRAINING_OPTS="--horizon=50"
TRAINING_OPTS=""
MODEL_TYPE="categorical"
#MODEL_TYPE="linear"
@ -12,8 +12,9 @@ IMG_WIDTH=160
IMG_HEIGHT=120
HORIZON=20
TRAINING_DATA_DIR=/tmp/data
TRAINING_OUTPUT_DIR=/tmp/output
TRAINING_DIR=~/src/robocars/trainings
TRAINING_DATA_DIR=${TRAINING_DIR}/data
TRAINING_OUTPUT_DIR=${TRAINING_DIR}/output
TRAIN_ARCHIVE=${TRAINING_DATA_DIR}/train.zip
#######################
@ -30,10 +31,10 @@ go run ./cmd/rc-tools training archive \
-image-width ${IMG_WIDTH}
printf "\n\nRun training\n\n"
podman run --rm -it \
-v /tmp/data:/opt/ml/input/data/train \
-v /tmp/output:/opt/ml/model/ \
localhost/tensorflow_without_gpu \
podman run --rm -it \
-v /trainings/data:/opt/ml/input/data/train \
-v /trainings/output:/opt/ml/model/ \
localhost/tensorflow_without_gpu:old \
python /opt/ml/code/train.py \
--img_height=${IMG_HEIGHT} \
--img_width=${IMG_WIDTH} \