2020-01-28 22:18:35 +00:00
package main
import (
2021-10-17 17:15:44 +00:00
"context"
2020-01-28 22:18:35 +00:00
"flag"
"fmt"
"github.com/cyrilix/robocar-base/cli"
2021-10-17 17:15:44 +00:00
"github.com/cyrilix/robocar-tools/dkimpt"
2020-01-28 22:18:35 +00:00
"github.com/cyrilix/robocar-tools/part"
2021-10-17 17:15:44 +00:00
"github.com/cyrilix/robocar-tools/pkg/data"
2021-11-24 18:31:16 +00:00
"github.com/cyrilix/robocar-tools/pkg/models"
2021-10-17 17:15:44 +00:00
"github.com/cyrilix/robocar-tools/pkg/train"
2020-01-28 22:18:35 +00:00
"github.com/cyrilix/robocar-tools/record"
"github.com/cyrilix/robocar-tools/video"
mqtt "github.com/eclipse/paho.mqtt.golang"
2021-10-17 17:15:44 +00:00
"go.uber.org/zap"
"log"
2020-01-28 22:18:35 +00:00
"os"
)
// DefaultClientId is passed to cli.InitMqttFlagSet as the default MQTT
// client identifier for the display and record subcommands.
const DefaultClientId = "robocar-tools"

// DefaultTrainSliceSize is the slice size used when the RC_TRAIN_SLICE_SIZE
// environment variable is not set.
const DefaultTrainSliceSize = 0
func main ( ) {
var mqttBroker , username , password , clientId string
var framePath string
var fps int
var frameTopic , objectsTopic , roadTopic , recordTopic string
var withObjects , withRoad bool
2020-02-16 18:14:38 +00:00
var recordsPath string
var trainArchiveName string
var trainSliceSize int
2021-10-17 17:15:44 +00:00
var bucket , ociImage string
var debug bool
2020-01-28 22:18:35 +00:00
mqttQos := cli . InitIntFlag ( "MQTT_QOS" , 0 )
_ , mqttRetain := os . LookupEnv ( "MQTT_RETAIN" )
flag . Usage = func ( ) {
fmt . Printf ( "Usage of %s:\n" , os . Args [ 0 ] )
fmt . Printf ( " display\n \tDisplay events on live frames\n" )
fmt . Printf ( " record \n \tRecord event for tensorflow training\n" )
2021-10-17 17:15:44 +00:00
fmt . Printf ( " training \n \tManage training\n" )
2021-11-24 18:31:16 +00:00
fmt . Printf ( " models \n \tManage models\n" )
2021-10-17 17:15:44 +00:00
fmt . Printf ( " import-donkey-records \n \tCopy donkeycar records to new format\n" )
2020-01-28 22:18:35 +00:00
}
2021-10-17 17:15:44 +00:00
err := cli . SetIntDefaultValueFromEnv ( & trainSliceSize , "RC_TRAIN_SLICE_SIZE" , DefaultTrainSliceSize )
if err != nil {
log . Printf ( "unable to init TRAIN_SLICE_SIZE: %v" , err )
}
cli . SetDefaultValueFromEnv ( & ociImage , "TRAIN_OCI_IMAGE" , "" )
cli . SetDefaultValueFromEnv ( & bucket , "TRAIN_BUCKET" , "" )
flag . BoolVar ( & debug , "debug" , false , "Display debug logs" )
2020-01-28 22:18:35 +00:00
displayFlags := flag . NewFlagSet ( "display" , flag . ExitOnError )
cli . InitMqttFlagSet ( displayFlags , DefaultClientId , & mqttBroker , & username , & password , & clientId , & mqttQos , & mqttRetain )
displayFlags . StringVar ( & frameTopic , "mqtt-topic-frame" , os . Getenv ( "MQTT_TOPIC_FRAME" ) , "Mqtt topic that contains frame to display, use MQTT_TOPIC_FRAME if args not set" )
displayFlags . StringVar ( & framePath , "frame-path" , "" , "Directory path where to read jpeg frame to inject in frame topic" )
displayFlags . IntVar ( & fps , "frame-per-second" , 25 , "Video frame per second of frame to publish" )
displayFlags . StringVar ( & objectsTopic , "mqtt-topic-objects" , os . Getenv ( "MQTT_TOPIC_OBJECTS" ) , "Mqtt topic that contains detected objects, use MQTT_TOPIC_OBJECTS if args not set" )
displayFlags . BoolVar ( & withObjects , "with-objects" , false , "Display detected objects" )
displayFlags . StringVar ( & roadTopic , "mqtt-topic-road" , os . Getenv ( "MQTT_TOPIC_ROAD" ) , "Mqtt topic that contains road description, use MQTT_TOPIC_ROAD if args not set" )
displayFlags . BoolVar ( & withRoad , "with-road" , false , "Display detected road" )
recordFlags := flag . NewFlagSet ( "record" , flag . ExitOnError )
cli . InitMqttFlagSet ( recordFlags , DefaultClientId , & mqttBroker , & username , & password , & clientId , & mqttQos , & mqttRetain )
recordFlags . StringVar ( & recordTopic , "mqtt-topic-records" , os . Getenv ( "MQTT_TOPIC_RECORDS" ) , "Mqtt topic that contains record data for training, use MQTT_TOPIC_RECORDS if args not set" )
2020-02-16 18:14:38 +00:00
recordFlags . StringVar ( & recordsPath , "record-path" , os . Getenv ( "RECORD_PATH" ) , "Path where to write records files, use RECORD_PATH if args not set" )
2021-10-17 17:15:44 +00:00
var basedir , destdir string
impdkFlags := flag . NewFlagSet ( "import-donkey-records" , flag . ExitOnError )
impdkFlags . StringVar ( & basedir , "from" , " ", " source directory " )
impdkFlags . StringVar ( & destdir , "to" , "" , "destination directory" )
trainingFlags := flag . NewFlagSet ( "training" , flag . ExitOnError )
trainingFlags . Usage = func ( ) {
fmt . Printf ( "Usage of %s %s:\n" , os . Args [ 0 ] , trainingFlags . Name ( ) )
fmt . Printf ( " list\n \tList existing training jobs\n" )
fmt . Printf ( " archive\n \tBuild tar.gz archive for training\n" )
fmt . Printf ( " run\n \tRun training job\n" )
2020-02-16 18:14:38 +00:00
}
2021-10-17 17:15:44 +00:00
var modelPath , roleArn , trainJobName string
2021-11-24 18:31:16 +00:00
var withFlipImage bool
var trainImageHeight , trainImageWidth int
var enableSpotTraining bool
2021-10-17 17:15:44 +00:00
trainingRunFlags := flag . NewFlagSet ( "run" , flag . ExitOnError )
trainingRunFlags . StringVar ( & bucket , "bucket" , os . Getenv ( "RC_TRAIN_BUCKET" ) , "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set" )
trainingRunFlags . StringVar ( & recordsPath , "record-path" , os . Getenv ( "RECORD_PATH" ) , "Input data path where records and img files are stored, use RECORD_PATH if arg not set" )
trainingRunFlags . StringVar ( & modelPath , "output-model-path" , "" , "Path where to write output model archive" )
trainingRunFlags . IntVar ( & trainSliceSize , "slice-size" , trainSliceSize , "Number of record to shift with image, use RC_TRAIN_SLICE_SIZE if args not set" )
trainingRunFlags . StringVar ( & ociImage , "oci-image" , os . Getenv ( "RC_TRAIN_OCI_IMAGE" ) , "OCI image to run (required), use RC_TRAIN_OCI_IMAGE if args not set" )
trainingRunFlags . StringVar ( & roleArn , "role-arn" , os . Getenv ( "RC_TRAIN_ROLE" ) , "AWS ARN role to use to run training (required), use RC_TRAIN_ROLE if arg not set" )
trainingRunFlags . StringVar ( & trainJobName , "job-name" , "" , "Training job name (required)" )
2021-11-24 18:31:16 +00:00
trainingRunFlags . BoolVar ( & withFlipImage , "with-flip-image" , withFlipImage , "Flip horiontal image and reverse steering to increase data into training archive" )
2021-10-17 17:15:44 +00:00
2021-11-24 18:31:16 +00:00
trainingRunFlags . IntVar ( & trainImageHeight , "image-height" , 128 , "Pixels image height" )
trainingRunFlags . IntVar ( & trainImageWidth , "image-width" , 160 , "Pixels image width" )
trainingRunFlags . BoolVar ( & enableSpotTraining , "enable-spot-training" , true , "Train models using managed spot training" )
2021-10-17 17:15:44 +00:00
trainingListJobFlags := flag . NewFlagSet ( "list" , flag . ExitOnError )
trainArchiveFlags := flag . NewFlagSet ( "archive" , flag . ExitOnError )
2020-02-16 18:14:38 +00:00
trainArchiveFlags . StringVar ( & recordsPath , "record-path" , os . Getenv ( "RECORD_PATH" ) , "Path where records files are stored, use RECORD_PATH if args not set" )
trainArchiveFlags . StringVar ( & trainArchiveName , "output" , os . Getenv ( "TRAIN_ARCHIVE_NAME" ) , "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set" )
trainArchiveFlags . IntVar ( & trainSliceSize , "slice-size" , trainSliceSize , "Number of record to shift with image, use TRAIN_SLICE_SIZE if args not set" )
2021-11-24 18:31:16 +00:00
trainArchiveFlags . BoolVar ( & withFlipImage , "with-flip-image" , withFlipImage , "Flip horiontal image and reverse steering to increase data into training archive" )
modelsFlags := flag . NewFlagSet ( "models" , flag . ExitOnError )
modelsFlags . Usage = func ( ) {
fmt . Printf ( "Usage of %s %s:\n" , os . Args [ 0 ] , modelsFlags . Name ( ) )
fmt . Printf ( " list\n \tList existing models\n" )
fmt . Printf ( " download\n \tDownload existing models\n" )
}
2020-01-28 22:18:35 +00:00
2021-11-24 18:31:16 +00:00
modelsListFlags := flag . NewFlagSet ( "list" , flag . ExitOnError )
modelsListFlags . StringVar ( & bucket , "bucket" , os . Getenv ( "RC_TRAIN_BUCKET" ) , "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set" )
var modelPathBucket string
modelsDownloadFlags := flag . NewFlagSet ( "download" , flag . ExitOnError )
modelsDownloadFlags . StringVar ( & bucket , "bucket" , os . Getenv ( "RC_TRAIN_BUCKET" ) , "AWS bucket where store data required, use RC_TRAIN_BUCKET if arg not set" )
modelsDownloadFlags . StringVar ( & modelPathBucket , "model" , "" , "S3 Model key into bucket (mandatory)" )
modelsDownloadFlags . StringVar ( & trainArchiveName , "output" , os . Getenv ( "TRAIN_ARCHIVE_NAME" ) , "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set" )
2020-01-28 22:18:35 +00:00
flag . Parse ( )
2021-10-17 17:15:44 +00:00
config := zap . NewDevelopmentConfig ( )
if debug {
config . Level = zap . NewAtomicLevelAt ( zap . DebugLevel )
} else {
config . Level = zap . NewAtomicLevelAt ( zap . InfoLevel )
}
lgr , err := config . Build ( )
if err != nil {
log . Fatalf ( "unable to init logger: %v" , err )
}
defer func ( ) {
if err := lgr . Sync ( ) ; err != nil {
log . Printf ( "unable to Sync logger: %v\n" , err )
}
} ( )
zap . ReplaceGlobals ( lgr )
2020-01-28 22:18:35 +00:00
// Switch on the subcommand
// Parse the flags for appropriate FlagSet
// FlagSet.Parse() requires a set of arguments to parse as input
// os.Args[2:] will be all arguments starting after the subcommand at os.Args[1]
switch flag . Arg ( 0 ) {
case displayFlags . Name ( ) :
if err := displayFlags . Parse ( os . Args [ 2 : ] ) ; err == flag . ErrHelp {
displayFlags . PrintDefaults ( )
os . Exit ( 0 )
}
client , err := cli . Connect ( mqttBroker , username , password , clientId )
if err != nil {
2021-10-17 17:15:44 +00:00
zap . S ( ) . Fatalf ( "unable to connect to mqtt bus: %v" , err )
2020-01-28 22:18:35 +00:00
}
defer client . Disconnect ( 50 )
runDisplay ( client , framePath , frameTopic , fps , objectsTopic , roadTopic , withObjects , withRoad )
case recordFlags . Name ( ) :
if err := recordFlags . Parse ( os . Args [ 2 : ] ) ; err == flag . ErrHelp {
recordFlags . PrintDefaults ( )
os . Exit ( 0 )
}
client , err := cli . Connect ( mqttBroker , username , password , clientId )
if err != nil {
log . Fatalf ( "unable to connect to mqtt bus: %v" , err )
}
defer client . Disconnect ( 50 )
2020-02-16 18:14:38 +00:00
runRecord ( client , recordsPath , recordTopic )
2021-10-17 17:15:44 +00:00
case impdkFlags . Name ( ) :
if err := impdkFlags . Parse ( os . Args [ 2 : ] ) ; err == flag . ErrHelp {
impdkFlags . PrintDefaults ( )
os . Exit ( 0 )
}
runImportDonkeyRecords ( basedir , destdir )
case trainingFlags . Name ( ) :
if err := trainingFlags . Parse ( os . Args [ 2 : ] ) ; err == flag . ErrHelp {
trainingFlags . PrintDefaults ( )
os . Exit ( 0 )
}
switch trainingFlags . Arg ( 0 ) {
case trainingListJobFlags . Name ( ) :
if err := trainingListJobFlags . Parse ( os . Args [ 3 : ] ) ; err == flag . ErrHelp {
trainingListJobFlags . PrintDefaults ( )
os . Exit ( 0 )
}
runTrainList ( )
case trainingRunFlags . Name ( ) :
if err := trainingRunFlags . Parse ( os . Args [ 3 : ] ) ; err == flag . ErrHelp {
trainingRunFlags . PrintDefaults ( )
os . Exit ( 0 )
}
2021-11-24 18:31:16 +00:00
runTraining ( bucket , ociImage , roleArn , trainJobName , recordsPath , trainSliceSize , withFlipImage , modelPath ,
trainImageHeight , trainImageWidth , enableSpotTraining )
2021-10-17 17:15:44 +00:00
case trainArchiveFlags . Name ( ) :
if err := trainArchiveFlags . Parse ( os . Args [ 3 : ] ) ; err == flag . ErrHelp {
trainArchiveFlags . PrintDefaults ( )
os . Exit ( 0 )
}
2021-11-24 18:31:16 +00:00
runTrainArchive ( recordsPath , trainArchiveName , trainSliceSize , withFlipImage )
default :
trainingFlags . PrintDefaults ( )
os . Exit ( 0 )
2021-10-17 17:15:44 +00:00
2021-11-24 18:31:16 +00:00
}
case modelsFlags . Name ( ) :
2021-10-17 17:15:44 +00:00
2021-11-24 18:31:16 +00:00
if err := modelsFlags . Parse ( os . Args [ 2 : ] ) ; err == flag . ErrHelp {
modelsFlags . PrintDefaults ( )
os . Exit ( 0 )
}
switch modelsFlags . Arg ( 0 ) {
case modelsListFlags . Name ( ) :
if err := modelsListFlags . Parse ( os . Args [ 3 : ] ) ; err == flag . ErrHelp {
modelsListFlags . PrintDefaults ( )
os . Exit ( 0 )
}
runModelsList ( bucket )
case modelsDownloadFlags . Name ( ) :
if err := modelsDownloadFlags . Parse ( os . Args [ 3 : ] ) ; err == flag . ErrHelp {
modelsDownloadFlags . PrintDefaults ( )
os . Exit ( 0 )
}
if trainArchiveName == "" {
zap . S ( ) . Error ( "output model file is mandatory" )
modelsDownloadFlags . PrintDefaults ( )
os . Exit ( 1 )
}
runModelsDownload ( bucket , modelPathBucket , trainArchiveName )
2021-10-17 17:15:44 +00:00
default :
2021-11-24 18:31:16 +00:00
modelsFlags . PrintDefaults ( )
2020-02-16 18:14:38 +00:00
os . Exit ( 0 )
}
2021-10-17 17:15:44 +00:00
2020-01-28 22:18:35 +00:00
default :
flag . PrintDefaults ( )
os . Exit ( 1 )
}
}
2020-02-16 18:14:38 +00:00
func runRecord ( client mqtt . Client , recordsDir , recordTopic string ) {
2020-01-28 22:18:35 +00:00
2020-02-16 18:14:38 +00:00
r , err := record . New ( client , recordsDir , recordTopic )
if err != nil {
2021-10-17 17:15:44 +00:00
zap . S ( ) . Fatalf ( "unable to init record part: %v" , err )
2020-02-16 18:14:38 +00:00
}
2020-01-28 22:18:35 +00:00
defer r . Stop ( )
cli . HandleExit ( r )
2020-02-16 18:14:38 +00:00
err = r . Start ( )
2020-01-28 22:18:35 +00:00
if err != nil {
2021-10-17 17:15:44 +00:00
zap . S ( ) . Fatalf ( "unable to start service: %v" , err )
2020-01-28 22:18:35 +00:00
}
}
2021-11-24 18:31:16 +00:00
func runTrainArchive ( basedir , archiveName string , sliceSize int , withFlipImage bool ) {
2020-02-16 18:14:38 +00:00
2021-11-24 18:31:16 +00:00
err := data . WriteArchive ( basedir , archiveName , sliceSize , withFlipImage )
2021-10-17 17:15:44 +00:00
if err != nil {
zap . S ( ) . Fatalf ( "unable to build archive file %v: %v" , archiveName , err )
}
}
func runImportDonkeyRecords ( basedir , destdir string ) {
if destdir == "" || basedir == "" {
zap . S ( ) . Fatal ( "invalid arg" )
}
err := dkimpt . ImportDonkeyRecords ( basedir , destdir )
2020-02-16 18:14:38 +00:00
if err != nil {
2021-10-17 17:15:44 +00:00
zap . S ( ) . Fatalf ( "unable to import files from %v to %v: %v" , basedir , destdir , err )
2020-02-16 18:14:38 +00:00
}
}
2020-01-28 22:18:35 +00:00
func runDisplay ( client mqtt . Client , framePath string , frameTopic string , fps int , objectsTopic string , roadTopic string , withObjects bool , withRoad bool ) {
if framePath != "" {
camera , err := video . NewCameraFake ( client , frameTopic , framePath , fps )
if err != nil {
log . Fatalf ( "unable to load fake camera: %v" , err )
}
if err = camera . Start ( ) ; err != nil {
log . Fatalf ( "unable to start fake camera: %v" , err )
}
defer camera . Stop ( )
}
p := part . NewPart ( client , frameTopic ,
objectsTopic , roadTopic ,
withObjects , withRoad )
defer p . Stop ( )
cli . HandleExit ( p )
err := p . Start ( )
if err != nil {
2021-10-17 17:15:44 +00:00
zap . S ( ) . Fatalf ( "unable to start service: %v" , err )
2020-01-28 22:18:35 +00:00
}
}
2021-10-17 17:15:44 +00:00
2021-11-24 18:31:16 +00:00
func runTraining ( bucketName string , ociImage string , roleArn string , jobName , dataDir string , sliceSize int , withFlipImage bool ,
outputModel string , imgHeight int , imgWidth int , enableSpotTraining bool ) {
2021-10-17 17:15:44 +00:00
l := zap . S ( )
if bucketName == "" {
l . Fatalf ( "no bucket define, see help" )
}
if ociImage == "" {
l . Fatalf ( "no oci image define, see help" )
}
if jobName == "" {
l . Fatalf ( "no job name define, see help" )
}
if dataDir == "" {
l . Fatalf ( "no training data define, see help" )
}
if outputModel == "" {
l . Fatalf ( "no output model path define, see help" )
}
if sliceSize != 0 && sliceSize != 2 {
l . Fatalf ( "invalid value for sie-slice, only '0' or '2' are allowed" )
}
training := train . New ( bucketName , ociImage , roleArn )
2021-11-24 18:31:16 +00:00
err := training . TrainDir ( context . Background ( ) , jobName , dataDir , imgHeight , imgWidth , sliceSize , withFlipImage ,
outputModel , enableSpotTraining )
2021-10-17 17:15:44 +00:00
if err != nil {
l . Fatalf ( "unable to run training: %v" , err )
}
}
func runTrainList ( ) {
err := train . ListJob ( context . Background ( ) )
if err != nil {
zap . S ( ) . Fatalf ( "unable to list training jobs: %w" , err )
}
2021-11-24 18:31:16 +00:00
}
// runModelsList lists the models stored in bucketName through
// models.ListModels. A listing error terminates the process.
func runModelsList(bucketName string) {
	ctx := context.Background()
	if listErr := models.ListModels(ctx, bucketName); listErr != nil {
		zap.S().Fatalf("unable to list models: %s", listErr)
	}
}
func runModelsDownload ( bucketName , modelPath , output string ) {
err := models . DownloadArchiveToFile ( context . Background ( ) , bucketName , modelPath , output )
if err != nil {
zap . S ( ) . Fatalf ( "unable to download model: %s" , err )
}
2021-10-17 17:15:44 +00:00
}