feat: load model from oci image

This commit is contained in:
2023-05-05 17:07:29 +02:00
parent b57698380e
commit 9fb01f7be9
441 changed files with 61395 additions and 15356 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/cyrilix/robocar-base/cli"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/oci"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
"go.uber.org/zap"
@ -26,7 +27,7 @@ var (
func main() {
var mqttBroker, username, password, clientId string
var cameraTopic, steeringTopic string
var modelPath string
var modelPath, modelsDir, ociRef string
var edgeVerbosity int
var imgWidth, imgHeight, horizon int
@ -36,6 +37,8 @@ func main() {
cli.InitMqttFlags(DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
flag.StringVar(&modelPath, "model", "", "path to model file")
flag.StringVar(&ociRef, "oci-model", "", "oci image to pull")
flag.StringVar(&modelsDir, "models-dir", "/tmp/robocar/models", "path where to store model file")
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
@ -65,16 +68,33 @@ func main() {
cleanup := metrics.Init(context.Background())
defer cleanup()
if modelPath == "" {
zap.L().Error("model path is mandatory")
if modelPath == "" && ociRef == "" {
zap.L().Error("model path or oci image is mandatory")
flag.PrintDefaults()
os.Exit(1)
}
modelType, width, height, horizonFromName, err := parseModelName(modelPath)
if err != nil {
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
if modelPath != "" && ociRef != "" {
zap.L().Error("model path and oci image are exclusives")
flag.PrintDefaults()
os.Exit(1)
}
var modelType tools.ModelType
var width, height, horizonFromName int
if modelPath != "" {
modelType, width, height, horizonFromName, err = parseModelName(modelPath)
if err != nil {
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
}
} else {
modelPath, modelType, width, height, horizonFromName, err = oci.PullOciImage(ociRef, modelsDir)
if err != nil {
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
}
}
if imgWidth == 0 {
imgWidth = width
}
@ -90,6 +110,11 @@ func main() {
os.Exit(1)
}
if ociRef == "" {
zap.S().Infof("model path : %v", modelPath)
} else {
zap.S().Infof("oci image model : %v", ociRef)
}
zap.S().Infof("model type : %v", modelType)
zap.S().Infof("model for image width : %v", imgWidth)
zap.S().Infof("model for image height: %v", imgHeight)

View File

@ -36,6 +36,15 @@ func Test_parseModelName(t *testing.T) {
wantHorizon: 10,
wantErr: false,
},
{
name: "linear with prefix",
args: args{modelPath: "/tmp/model_angle_linear_120x160h10.tflite"},
wantModelType: tools.ModelTypeLinear,
wantImgWidth: 120,
wantImgHeight: 160,
wantHorizon: 10,
wantErr: false,
},
{
name: "bad-model",
args: args{modelPath: "/tmp/model_123_120x160h10.tflite"},