2021-10-10 19:03:57 +00:00
package main
import (
2021-12-02 08:16:50 +00:00
"context"
2021-10-10 19:03:57 +00:00
"flag"
2022-06-10 11:50:22 +00:00
"fmt"
2021-10-10 19:03:57 +00:00
"github.com/cyrilix/robocar-base/cli"
2021-12-02 08:16:50 +00:00
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
2023-05-05 15:07:29 +00:00
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/oci"
2021-10-10 19:03:57 +00:00
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
2022-06-10 11:50:22 +00:00
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
2021-10-10 19:03:57 +00:00
"go.uber.org/zap"
"log"
"os"
2022-06-10 11:50:22 +00:00
"regexp"
"strconv"
2021-10-10 19:03:57 +00:00
)
// DefaultClientId is the MQTT client identifier used when none is supplied
// on the command line.
const DefaultClientId = "robocar-steering-tflite-edgetpu"
2022-06-10 11:50:22 +00:00
var (
	// modelNameRegex extracts the model configuration encoded in an Edge TPU
	// model file name, e.g. "model_categorical_160x120h20_edgetpu.tflite":
	//   - type:      "categorical" or "linear"
	//   - imgWidth:  image width expected by the model
	//   - imgHeight: image height expected by the model
	//   - horizon:   upper zone of the image to crop
	// The extension dot is escaped so only real ".tflite" names match.
	modelNameRegex = regexp.MustCompile(`.*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\d+)x(?P<imgHeight>\d+)h(?P<horizon>\d+)_edgetpu\.tflite$`)
)
2021-10-10 19:03:57 +00:00
func main ( ) {
var mqttBroker , username , password , clientId string
var cameraTopic , steeringTopic string
2023-05-28 12:39:49 +00:00
var modelPath , modelsDir , ociRegistry , ociRepository , ociTag string
2021-10-10 19:03:57 +00:00
var edgeVerbosity int
2021-11-25 16:12:16 +00:00
var imgWidth , imgHeight , horizon int
2021-10-10 19:03:57 +00:00
mqttQos := cli . InitIntFlag ( "MQTT_QOS" , 0 )
_ , mqttRetain := os . LookupEnv ( "MQTT_RETAIN" )
cli . InitMqttFlags ( DefaultClientId , & mqttBroker , & username , & password , & clientId , & mqttQos , & mqttRetain )
flag . StringVar ( & modelPath , "model" , "" , "path to model file" )
2023-05-28 12:39:49 +00:00
flag . StringVar ( & ociRegistry , "oci-model-registry" , "" , "oci registry where to fetch model" )
flag . StringVar ( & ociRepository , "oci-model-repository" , "" , "oci repository where to fetch model" )
flag . StringVar ( & ociTag , "oci-model-tag" , "" , "oci tag name for model to pull" )
2023-05-05 15:07:29 +00:00
flag . StringVar ( & modelsDir , "models-dir" , "/tmp/robocar/models" , "path where to store model file" )
2021-10-10 19:03:57 +00:00
flag . StringVar ( & steeringTopic , "mqtt-topic-road" , os . Getenv ( "MQTT_TOPIC_STEERING" ) , "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set" )
flag . StringVar ( & cameraTopic , "mqtt-topic-camera" , os . Getenv ( "MQTT_TOPIC_CAMERA" ) , "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set" )
flag . IntVar ( & edgeVerbosity , "edge-verbosity" , 0 , "Edge TPU Verbosity" )
2022-06-10 11:50:22 +00:00
flag . IntVar ( & imgWidth , "img-width" , 0 , "image width expected by model" )
flag . IntVar ( & imgHeight , "img-height" , 0 , "image height expected by model" )
2021-11-25 16:12:16 +00:00
flag . IntVar ( & horizon , "horizon" , 0 , "upper zone to crop from image. Models expect size 'imgHeight - horizon'" )
2021-12-18 16:57:31 +00:00
logLevel := zap . LevelFlag ( "log" , zap . InfoLevel , "log level" )
2021-10-10 19:03:57 +00:00
flag . Parse ( )
2021-12-18 16:57:31 +00:00
2021-10-10 19:03:57 +00:00
if len ( os . Args ) <= 1 {
flag . PrintDefaults ( )
os . Exit ( 1 )
}
2021-10-12 15:34:47 +00:00
config := zap . NewDevelopmentConfig ( )
2021-12-18 16:57:31 +00:00
config . Level = zap . NewAtomicLevelAt ( * logLevel )
2021-10-12 15:34:47 +00:00
lgr , err := config . Build ( )
if err != nil {
log . Fatalf ( "unable to init logger: %v" , err )
}
defer func ( ) {
if err := lgr . Sync ( ) ; err != nil {
log . Printf ( "unable to Sync logger: %v\n" , err )
}
} ( )
zap . ReplaceGlobals ( lgr )
2021-12-02 08:16:50 +00:00
cleanup := metrics . Init ( context . Background ( ) )
defer cleanup ( )
2023-05-28 12:39:49 +00:00
if modelPath == "" && ociRepository == "" {
2023-05-05 15:07:29 +00:00
zap . L ( ) . Error ( "model path or oci image is mandatory" )
2021-10-10 19:03:57 +00:00
flag . PrintDefaults ( )
os . Exit ( 1 )
}
2023-05-28 12:39:49 +00:00
if modelPath != "" && ociRepository != "" {
2023-05-05 15:07:29 +00:00
zap . L ( ) . Error ( "model path and oci image are exclusives" )
flag . PrintDefaults ( )
os . Exit ( 1 )
}
var modelType tools . ModelType
var width , height , horizonFromName int
if modelPath != "" {
modelType , width , height , horizonFromName , err = parseModelName ( modelPath )
if err != nil {
zap . S ( ) . Panicf ( "bad model name '%v', unable to detect configuration from name pattern: %v" , modelPath , err )
}
} else {
2023-05-28 12:39:49 +00:00
ctx := context . Background ( )
modelPath , modelType , width , height , horizonFromName , err = oci . PullOciImage ( ctx , ociRegistry , ociRepository , ociTag , modelsDir )
2023-05-05 15:07:29 +00:00
if err != nil {
zap . S ( ) . Panicf ( "bad model name '%v', unable to detect configuration from name pattern: %v" , modelPath , err )
}
2022-06-10 11:50:22 +00:00
}
2023-05-05 15:07:29 +00:00
2022-06-10 11:50:22 +00:00
if imgWidth == 0 {
imgWidth = width
}
if imgHeight == 0 {
imgHeight = height
}
if horizonFromName == 0 {
horizon = horizonFromName
}
2021-11-25 16:12:16 +00:00
if imgWidth <= 0 || imgHeight <= 0 {
zap . L ( ) . Error ( "img-width and img-height are mandatory" )
flag . PrintDefaults ( )
os . Exit ( 1 )
}
2021-10-10 19:03:57 +00:00
2023-05-28 12:39:49 +00:00
if ociRepository == "" {
2023-05-05 15:07:29 +00:00
zap . S ( ) . Infof ( "model path : %v" , modelPath )
} else {
2023-05-28 12:39:49 +00:00
zap . S ( ) . Infof ( "oci image model : %v/%v:%v" , ociRegistry , ociRepository , ociTag )
2023-05-05 15:07:29 +00:00
}
2022-06-10 11:50:22 +00:00
zap . S ( ) . Infof ( "model type : %v" , modelType )
zap . S ( ) . Infof ( "model for image width : %v" , imgWidth )
zap . S ( ) . Infof ( "model for image height: %v" , imgHeight )
zap . S ( ) . Infof ( "model with horizon : %v" , horizon )
2021-10-10 19:03:57 +00:00
client , err := cli . Connect ( mqttBroker , username , password , clientId )
if err != nil {
zap . L ( ) . Fatal ( "unable to connect to mqtt bus" , zap . Error ( err ) )
}
defer client . Disconnect ( 50 )
2022-06-10 11:50:22 +00:00
p := steering . NewPart ( client , modelType , modelPath , steeringTopic , cameraTopic , edgeVerbosity , imgWidth , imgHeight , horizon )
2021-10-10 19:03:57 +00:00
defer p . Stop ( )
cli . HandleExit ( p )
err = p . Start ( )
if err != nil {
zap . L ( ) . Fatal ( "unable to start service" , zap . Error ( err ) )
}
}
2022-06-10 11:50:22 +00:00
// parseModelName extracts the model configuration encoded in a model file
// name matched by modelNameRegex, e.g. "model_categorical_160x120h20_edgetpu.tflite".
//
// It returns the model type, the image width and height expected by the model,
// and the horizon (upper zone of the image to crop). On error, all returned
// values other than err are zero values. An error is returned when the name
// does not match the pattern, when the model type is unknown, or when a
// numeric field cannot be converted to an int.
func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
	match := modelNameRegex.FindStringSubmatch(modelPath)
	if match == nil {
		return modelType, 0, 0, 0, fmt.Errorf("model name %q does not match pattern %q", modelPath, modelNameRegex.String())
	}

	// Index captured values by group name. Unnamed groups all map to the ""
	// key and are not read below.
	names := modelNameRegex.SubexpNames()
	results := make(map[string]string, len(match))
	for i, value := range match {
		results[names[i]] = value
	}

	parsedType := tools.ParseModelType(results["type"])
	if parsedType == tools.ModelTypeUnknown {
		return modelType, 0, 0, 0, fmt.Errorf("unknown model type '%v'", results["type"])
	}

	var width, height, hz int
	if width, err = strconv.Atoi(results["imgWidth"]); err != nil {
		return modelType, 0, 0, 0, fmt.Errorf("unable to convert image width '%v' to integer: %w", results["imgWidth"], err)
	}
	if height, err = strconv.Atoi(results["imgHeight"]); err != nil {
		return modelType, 0, 0, 0, fmt.Errorf("unable to convert image height '%v' to integer: %w", results["imgHeight"], err)
	}
	if hz, err = strconv.Atoi(results["horizon"]); err != nil {
		return modelType, 0, 0, 0, fmt.Errorf("unable to convert horizon '%v' to integer: %w", results["horizon"], err)
	}

	return parsedType, width, height, hz, nil
}