Compare commits

..

No commits in common. "master" and "feat/object_detection" have entirely different histories.

4 changed files with 51 additions and 131 deletions

View File

@ -28,8 +28,8 @@ func main() {
var mqttBroker, username, password, clientId string var mqttBroker, username, password, clientId string
var framePath string var framePath string
var fps int var fps int
var frameTopic, objectsTopic, roadTopic, recordTopic, throttleFeedbackTopic string var frameTopic, objectsTopic, roadTopic, recordTopic string
var withObjects, withRoad, withThrottleFeedback bool var withObjects, withRoad bool
var recordsPath string var recordsPath string
var trainArchiveName string var trainArchiveName string
var trainSliceSize int var trainSliceSize int
@ -81,9 +81,6 @@ func main() {
displayCameraFlags.StringVar(&roadTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_ROAD"), "Mqtt topic that contains road description, use MQTT_TOPIC_ROAD if args not set") displayCameraFlags.StringVar(&roadTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_ROAD"), "Mqtt topic that contains road description, use MQTT_TOPIC_ROAD if args not set")
displayCameraFlags.BoolVar(&withRoad, "with-road", false, "Display detected road") displayCameraFlags.BoolVar(&withRoad, "with-road", false, "Display detected road")
displayCameraFlags.StringVar(&throttleFeedbackTopic, "mqtt-topic-throttle-feedback", os.Getenv("MQTT_TOPIC_THROTTLE_FEEDBACK"), "Mqtt topic where to publish throttle feedback, use MQTT_TOPIC_THROTTLE_FEEDBACK if args not set")
displayCameraFlags.BoolVar(&withThrottleFeedback, "with-throttle-feedback", false, "Display throttle feedback")
recordFlags := flag.NewFlagSet("record", flag.ExitOnError) recordFlags := flag.NewFlagSet("record", flag.ExitOnError)
cli.InitMqttFlagSet(recordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain) cli.InitMqttFlagSet(recordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
recordFlags.StringVar(&recordTopic, "mqtt-topic-records", os.Getenv("MQTT_TOPIC_RECORDS"), "Mqtt topic that contains record data for training, use MQTT_TOPIC_RECORDS if args not set") recordFlags.StringVar(&recordTopic, "mqtt-topic-records", os.Getenv("MQTT_TOPIC_RECORDS"), "Mqtt topic that contains record data for training, use MQTT_TOPIC_RECORDS if args not set")
@ -102,7 +99,7 @@ func main() {
fmt.Printf(" run\n \tRun training job\n") fmt.Printf(" run\n \tRun training job\n")
} }
var modelPath, roleArn, trainJobName, modelType string var modelPath, roleArn, trainJobName string
var horizon int var horizon int
var withFlipImage bool var withFlipImage bool
var trainImageHeight, trainImageWidth int var trainImageHeight, trainImageWidth int
@ -120,7 +117,6 @@ func main() {
trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height") trainingRunFlags.IntVar(&trainImageHeight, "image-height", 128, "Pixels image height")
trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width") trainingRunFlags.IntVar(&trainImageWidth, "image-width", 160, "Pixels image width")
trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)") trainingRunFlags.IntVar(&horizon, "horizon", 0, "Upper zone image to crop (in pixels)")
trainingRunFlags.StringVar(&modelType, "model-type", train.ModelTypeCategorical.String(), "Type model to build")
trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training") trainingRunFlags.BoolVar(&enableSpotTraining, "enable-spot-training", true, "Train models using managed spot training")
trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError) trainingListJobFlags := flag.NewFlagSet("list", flag.ExitOnError)
@ -199,7 +195,7 @@ func main() {
zap.S().Fatalf("unable to connect to mqtt bus: %v", err) zap.S().Fatalf("unable to connect to mqtt bus: %v", err)
} }
defer client.Disconnect(50) defer client.Disconnect(50)
runDisplay(client, framePath, frameTopic, fps, objectsTopic, roadTopic, throttleFeedbackTopic, withObjects, withRoad, withThrottleFeedback) runDisplay(client, framePath, frameTopic, fps, objectsTopic, roadTopic, withObjects, withRoad)
default: default:
displayFlags.PrintDefaults() displayFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
@ -238,7 +234,7 @@ func main() {
trainingRunFlags.PrintDefaults() trainingRunFlags.PrintDefaults()
os.Exit(0) os.Exit(0)
} }
runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, train.ParseModelType(modelType), trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining) runTraining(bucket, ociImage, roleArn, trainJobName, recordsPath, trainSliceSize, trainImageWidth, trainImageHeight, horizon, withFlipImage, modelPath, enableSpotTraining)
case trainArchiveFlags.Name(): case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp { if err := trainArchiveFlags.Parse(os.Args[3:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults() trainArchiveFlags.PrintDefaults()
@ -329,8 +325,7 @@ func runDisplayRecord(client mqtt.Client, recordTopic string) {
zap.S().Fatalf("unable to start service: %v", err) zap.S().Fatalf("unable to start service: %v", err)
} }
} }
func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic, roadTopic, throttleFeedbackTopic string, func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic string, roadTopic string, withObjects bool, withRoad bool) {
withObjects, withRoad, withThrottleFeedback bool) {
if framePath != "" { if framePath != "" {
camera, err := video.NewCameraFake(client, frameTopic, framePath, fps) camera, err := video.NewCameraFake(client, frameTopic, framePath, fps)
@ -344,8 +339,8 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
} }
p := part.NewPart(client, frameTopic, p := part.NewPart(client, frameTopic,
objectsTopic, roadTopic, throttleFeedbackTopic, objectsTopic, roadTopic,
withObjects, withRoad, withThrottleFeedback) withObjects, withRoad)
defer p.Stop() defer p.Stop()
cli.HandleExit(p) cli.HandleExit(p)
@ -356,7 +351,7 @@ func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int
} }
} }
func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelType train.ModelType, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) { func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, sliceSize, imgWidth, imgHeight int, horizon int, withFlipImage bool, outputModel string, enableSpotTraining bool) {
l := zap.S() l := zap.S()
if bucketName == "" { if bucketName == "" {
@ -379,12 +374,8 @@ func runTraining(bucketName, ociImage, roleArn, jobName, dataDir string, modelTy
l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed") l.Fatalf("invalid value for sie-slice, only '0' or '2' are allowed")
} }
if modelType == train.ModelTypeUnknown {
l.Fatalf("invalid model type: %v", modelType)
}
training := train.New(bucketName, ociImage, roleArn) training := train.New(bucketName, ociImage, roleArn)
err := training.TrainDir(context.Background(), jobName, dataDir, modelType, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining) err := training.TrainDir(context.Background(), jobName, dataDir, imgWidth, imgHeight, sliceSize, horizon, withFlipImage, outputModel, enableSpotTraining)
if err != nil { if err != nil {
l.Fatalf("unable to run training: %v", err) l.Fatalf("unable to run training: %v", err)

View File

@ -14,41 +14,35 @@ import (
"time" "time"
) )
func NewPart(client mqtt.Client, frameTopic, objectsTopic, roadTopic, throttleFeedbackTopic string, func NewPart(client mqtt.Client, frameTopic, objectsTopic, roadTopic string, withObjects, withRoad bool) *FramePart {
withObjects, withRoad, withThrottleFeedback bool) *FramePart {
return &FramePart{ return &FramePart{
client: client, client: client,
frameTopic: frameTopic, frameTopic: frameTopic,
objectsTopic: objectsTopic, objectsTopic: objectsTopic,
roadTopic: roadTopic, roadTopic: roadTopic,
throttleFeedbackTopic: throttleFeedbackTopic, window: gocv.NewWindow("frameTopic"),
window: gocv.NewWindow("frameTopic"), withObjects: withObjects,
withObjects: withObjects, withRoad: withRoad,
withRoad: withRoad, imgChan: make(chan gocv.Mat),
withThrottleFeedback: withThrottleFeedback, objectsChan: make(chan events.ObjectsMessage),
imgChan: make(chan gocv.Mat), roadChan: make(chan events.RoadMessage),
objectsChan: make(chan events.ObjectsMessage), cancel: make(chan interface{}),
roadChan: make(chan events.RoadMessage),
throttleFeedbackChan: make(chan events.ThrottleMessage),
cancel: make(chan interface{}),
} }
} }
type FramePart struct { type FramePart struct {
client mqtt.Client client mqtt.Client
frameTopic, objectsTopic, roadTopic, throttleFeedbackTopic string frameTopic, objectsTopic, roadTopic string
window *gocv.Window window *gocv.Window
withObjects bool withObjects bool
withRoad bool withRoad bool
withThrottleFeedback bool
imgChan chan gocv.Mat imgChan chan gocv.Mat
objectsChan chan events.ObjectsMessage objectsChan chan events.ObjectsMessage
roadChan chan events.RoadMessage roadChan chan events.RoadMessage
throttleFeedbackChan chan events.ThrottleMessage cancel chan interface{}
cancel chan interface{}
} }
func (p *FramePart) Start() error { func (p *FramePart) Start() error {
@ -59,8 +53,6 @@ func (p *FramePart) Start() error {
var img = gocv.NewMat() var img = gocv.NewMat()
var objectsMsg events.ObjectsMessage var objectsMsg events.ObjectsMessage
var roadMsg events.RoadMessage var roadMsg events.RoadMessage
var throttleFeedbackMsg events.ThrottleMessage
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(1 * time.Second)
for { for {
select { select {
@ -73,13 +65,11 @@ func (p *FramePart) Start() error {
objectsMsg = objects objectsMsg = objects
case road := <-p.roadChan: case road := <-p.roadChan:
roadMsg = road roadMsg = road
case throttleFeedback := <-p.throttleFeedbackChan:
throttleFeedbackMsg = throttleFeedback
case <-p.cancel: case <-p.cancel:
img.Close() img.Close()
return nil return nil
} }
p.drawFrame(&img, &objectsMsg, &roadMsg, &throttleFeedbackMsg) p.drawFrame(&img, &objectsMsg, &roadMsg)
ticker.Reset(1 * time.Second) ticker.Reset(1 * time.Second)
} }
} }
@ -135,18 +125,6 @@ func (p *FramePart) onRoad(_ mqtt.Client, message mqtt.Message) {
p.roadChan <- msg p.roadChan <- msg
} }
func (p *FramePart) onThrottleFeedback(_ mqtt.Client, message mqtt.Message) {
var msg events.ThrottleMessage
err := proto.Unmarshal(message.Payload(), &msg)
if err != nil {
zap.S().Errorf("unable to unmarshal msg %T: %v", msg, err)
return
}
p.throttleFeedbackChan <- msg
}
func (p *FramePart) registerCallbacks() error { func (p *FramePart) registerCallbacks() error {
err := RegisterCallback(p.client, p.frameTopic, p.onFrame) err := RegisterCallback(p.client, p.frameTopic, p.onFrame)
if err != nil { if err != nil {
@ -166,16 +144,10 @@ func (p *FramePart) registerCallbacks() error {
return err return err
} }
} }
if p.withThrottleFeedback {
err := service.RegisterCallback(p.client, p.throttleFeedbackTopic, p.onThrottleFeedback)
if err != nil {
return err
}
}
return nil return nil
} }
func (p *FramePart) drawFrame(img *gocv.Mat, objects *events.ObjectsMessage, road *events.RoadMessage, tf *events.ThrottleMessage) { func (p *FramePart) drawFrame(img *gocv.Mat, objects *events.ObjectsMessage, road *events.RoadMessage) {
if p.withObjects { if p.withObjects {
p.drawObjects(img, objects) p.drawObjects(img, objects)
@ -183,9 +155,6 @@ func (p *FramePart) drawFrame(img *gocv.Mat, objects *events.ObjectsMessage, roa
if p.withRoad { if p.withRoad {
p.drawRoad(img, road) p.drawRoad(img, road)
} }
if p.withThrottleFeedback {
p.drawThrottleFeedbackText(img, tf)
}
p.window.IMShow(*img) p.window.IMShow(*img)
p.window.WaitKey(1) p.window.WaitKey(1)
@ -247,18 +216,6 @@ func (p *FramePart) drawRoadText(img *gocv.Mat, road *events.RoadMessage) {
) )
} }
func (p *FramePart) drawThrottleFeedbackText(img *gocv.Mat, tf *events.ThrottleMessage) {
gocv.PutText(
img,
fmt.Sprintf("Throttle feedback: %.3f", tf.Throttle),
image.Point{X: 5, Y: 20},
gocv.FontHersheyPlain,
0.6,
color.RGBA{R: 0, G: 255, B: 255, A: 255},
1,
)
}
func StopService(name string, client mqtt.Client, topics ...string) { func StopService(name string, client mqtt.Client, topics ...string) {
zap.S().Infof("Stop %s service", name) zap.S().Infof("Stop %s service", name)
token := client.Unsubscribe(topics...) token := client.Unsubscribe(topics...)

View File

@ -12,40 +12,9 @@ import (
"github.com/cyrilix/robocar-tools/pkg/models" "github.com/cyrilix/robocar-tools/pkg/models"
"go.uber.org/zap" "go.uber.org/zap"
"strconv" "strconv"
"strings"
"time" "time"
) )
type ModelType int
func ParseModelType(s string) ModelType {
switch strings.ToLower(s) {
case "categorical":
return ModelTypeCategorical
case "linear":
return ModelTypeLinear
default:
return ModelTypeUnknown
}
}
func (m ModelType) String() string {
switch m {
case ModelTypeCategorical:
return "categorical"
case ModelTypeLinear:
return "linear"
default:
return "unknown"
}
}
const (
ModelTypeUnknown ModelType = iota
ModelTypeCategorical
ModelTypeLinear
)
func New(bucketName string, ociImage, roleArn string) *Training { func New(bucketName string, ociImage, roleArn string) *Training {
return &Training{ return &Training{
config: awsutils.MustLoadConfig(), config: awsutils.MustLoadConfig(),
@ -66,7 +35,7 @@ type Training struct {
outputBucket string outputBucket string
} }
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error { func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
l := zap.S() l := zap.S()
l.Infof("run training with data from %s", basedir) l.Infof("run training with data from %s", basedir)
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage) archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
@ -81,7 +50,14 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelT
} }
l.Info("") l.Info("")
err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType) err = t.runTraining(
ctx,
jobName,
sliceSize,
imgHeight,
imgWidth,
enableSpotTraining,
)
if err != nil { if err != nil {
return fmt.Errorf("unable to run training: %w", err) return fmt.Errorf("unable to run training: %w", err)
} }
@ -119,7 +95,7 @@ func List(bucketName string) error {
return nil return nil
} }
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error { func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error {
l := zap.S() l := zap.S()
client := sagemaker.NewFromConfig(awsutils.MustLoadConfig()) client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
l.Infof("Start training job '%s'", jobName) l.Infof("Start training job '%s'", jobName)
@ -149,9 +125,6 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, i
"slide_size": strconv.Itoa(slideSize), "slide_size": strconv.Itoa(slideSize),
"img_height": strconv.Itoa(imgHeight), "img_height": strconv.Itoa(imgHeight),
"img_width": strconv.Itoa(imgWidth), "img_width": strconv.Itoa(imgWidth),
"batch_size": strconv.Itoa(32),
"model_type": modelType.String(),
"horizon": strconv.Itoa(horizon),
}, },
InputDataConfig: []types.Channel{ InputDataConfig: []types.Channel{
{ {

View File

@ -3,8 +3,8 @@
set +e set +e
set +x set +x
RECORDS_PATH=~/src/robocars/data/viva20/viva12/ RECORDS_PATH=~/robocar/record-sim4-2
#TRAINING_OPTS="--horizon=50" #TRAINING_OPTS="--horizon=20"
TRAINING_OPTS="" TRAINING_OPTS=""
MODEL_TYPE="categorical" MODEL_TYPE="categorical"
#MODEL_TYPE="linear" #MODEL_TYPE="linear"
@ -12,9 +12,8 @@ IMG_WIDTH=160
IMG_HEIGHT=120 IMG_HEIGHT=120
HORIZON=20 HORIZON=20
TRAINING_DIR=~/src/robocars/trainings TRAINING_DATA_DIR=/tmp/data
TRAINING_DATA_DIR=${TRAINING_DIR}/data TRAINING_OUTPUT_DIR=/tmp/output
TRAINING_OUTPUT_DIR=${TRAINING_DIR}/output
TRAIN_ARCHIVE=${TRAINING_DATA_DIR}/train.zip TRAIN_ARCHIVE=${TRAINING_DATA_DIR}/train.zip
####################### #######################
@ -31,10 +30,10 @@ go run ./cmd/rc-tools training archive \
-image-width ${IMG_WIDTH} -image-width ${IMG_WIDTH}
printf "\n\nRun training\n\n" printf "\n\nRun training\n\n"
podman run --rm -it \ podman run --rm -it \
-v /trainings/data:/opt/ml/input/data/train \ -v /tmp/data:/opt/ml/input/data/train \
-v /trainings/output:/opt/ml/model/ \ -v /tmp/output:/opt/ml/model/ \
localhost/tensorflow_without_gpu:old \ localhost/tensorflow_without_gpu \
python /opt/ml/code/train.py \ python /opt/ml/code/train.py \
--img_height=${IMG_HEIGHT} \ --img_height=${IMG_HEIGHT} \
--img_width=${IMG_WIDTH} \ --img_width=${IMG_WIDTH} \