Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
b57698380e | |||
c1221d2e0e | |||
24d39461b4 | |||
2e8bf4252b | |||
bc90193edf | |||
acf6066fea |
@ -4,8 +4,8 @@ IMAGE_NAME=robocar-steering-tflite-edgetpu
|
|||||||
TAG=$(git describe)
|
TAG=$(git describe)
|
||||||
FULL_IMAGE_NAME=docker.io/cyrilix/${IMAGE_NAME}:${TAG}
|
FULL_IMAGE_NAME=docker.io/cyrilix/${IMAGE_NAME}:${TAG}
|
||||||
BINARY=rc-steering
|
BINARY=rc-steering
|
||||||
TFLITE_VERSION=2.6.0
|
TFLITE_VERSION=2.10.0
|
||||||
GOLANG_VERSION=1.18
|
GOLANG_VERSION=1.19
|
||||||
|
|
||||||
GOTAGS="-tags netgo"
|
GOTAGS="-tags netgo"
|
||||||
BUILDER_CONTAINER="${IMAGE_NAME}-builder"
|
BUILDER_CONTAINER="${IMAGE_NAME}-builder"
|
||||||
@ -35,6 +35,20 @@ image_build_binaries(){
|
|||||||
buildah run $containerName ln -s /usr/lib/aarch64-linux-gnu/libedgetpu.so.1 /usr/lib/aarch64-linux-gnu/libedgetpu.so
|
buildah run $containerName ln -s /usr/lib/aarch64-linux-gnu/libedgetpu.so.1 /usr/lib/aarch64-linux-gnu/libedgetpu.so
|
||||||
|
|
||||||
printf "Compile for linux/amd64\n"
|
printf "Compile for linux/amd64\n"
|
||||||
|
LIB_ARCH=x86_64-linux-gnu
|
||||||
|
LIB_FLAGS="-L /usr/local/lib/${LIB_ARCH} \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/base -labsl_base -labsl_throw_delegate -labsl_raw_logging_internal -labsl_spinlock_wait -labsl_malloc_internal -labsl_log_severity -labsl_strerror \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/status -labsl_status \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/hash -labsl_hash -labsl_city -labsl_low_level_hash \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/flags -labsl_flags -labsl_flags_internal -labsl_flags_marshalling -labsl_flags_reflection -labsl_flags_config -labsl_flags_program_name -labsl_flags_private_handle_accessor -labsl_flags_commandlineflag -labsl_flags_commandlineflag_internal\
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/types -labsl_bad_variant_access -labsl_bad_optional_access -labsl_bad_any_cast_impl \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/strings -labsl_strings -labsl_str_format_internal -labsl_cord -labsl_cordz_info -labsl_cord_internal -labsl_cordz_functions -labsl_cordz_handle -labsl_strings_internal \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/time -labsl_time -labsl_time_zone -labsl_civil_time \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/numeric -labsl_int128 \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/synchronization -labsl_synchronization -labsl_graphcycles_internal\
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/debugging -labsl_stacktrace -labsl_symbolize -labsl_debugging_internal -labsl_demangle_internal \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/profiling -labsl_exponential_biased \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/container -labsl_raw_hash_set -labsl_hashtablez_sampler"
|
||||||
buildah run \
|
buildah run \
|
||||||
--env CGO_ENABLED=1 \
|
--env CGO_ENABLED=1 \
|
||||||
--env CC=gcc \
|
--env CC=gcc \
|
||||||
@ -43,25 +57,27 @@ image_build_binaries(){
|
|||||||
--env GOARCH=amd64 \
|
--env GOARCH=amd64 \
|
||||||
--env GOARM=${GOARM} \
|
--env GOARM=${GOARM} \
|
||||||
--env CGO_CPPFLAGS="-I/usr/local/include" \
|
--env CGO_CPPFLAGS="-I/usr/local/include" \
|
||||||
--env CGO_LDFLAGS="-L /usr/local/lib/x86_64-linux-gnu -L /usr/lib/x86_64-linux-gnu" \
|
--env CGO_LDFLAGS="${LIB_FLAGS}" \
|
||||||
$containerName \
|
$containerName \
|
||||||
go build -a -o rc-steering.amd64 ./cmd/rc-steering
|
go build -a -o rc-steering.amd64 ./cmd/rc-steering
|
||||||
#--env CGO_CXXFLAGS="--std=c++1z" \
|
#--env CGO_CXXFLAGS="--std=c++1z" \
|
||||||
|
|
||||||
printf "Compile for linux/arm/v7\n"
|
|
||||||
buildah run \
|
|
||||||
--env CGO_ENABLED=1 \
|
|
||||||
--env CC=arm-linux-gnueabihf-gcc \
|
|
||||||
--env CXX=arm-linux-gnueabihf-g++ \
|
|
||||||
--env GOOS=linux \
|
|
||||||
--env GOARCH=arm \
|
|
||||||
--env GOARM=7 \
|
|
||||||
--env CGO_CPPFLAGS="-I/usr/local/include" \
|
|
||||||
--env CGO_LDFLAGS="-L /usr/lib/arm-linux-gnueabihf -L /usr/local/lib/arm-linux-gnueabihf" \
|
|
||||||
$containerName \
|
|
||||||
go build -a -o rc-steering.armhf ./cmd/rc-steering
|
|
||||||
|
|
||||||
printf "Compile for linux/arm64\n"
|
printf "Compile for linux/arm64\n"
|
||||||
|
LIB_ARCH=aarch64-linux-gnu
|
||||||
|
LIB_FLAGS="-L /usr/local/lib/${LIB_ARCH} \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/base -labsl_base -labsl_throw_delegate -labsl_raw_logging_internal -labsl_spinlock_wait -labsl_malloc_internal -labsl_log_severity -labsl_strerror\
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/status -labsl_status \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/hash -labsl_hash -labsl_city -labsl_low_level_hash \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/flags -labsl_flags -labsl_flags_internal -labsl_flags_marshalling -labsl_flags_reflection -labsl_flags_config -labsl_flags_program_name -labsl_flags_private_handle_accessor -labsl_flags_commandlineflag -labsl_flags_commandlineflag_internal\
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/types -labsl_bad_variant_access -labsl_bad_optional_access -labsl_bad_any_cast_impl \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/strings -labsl_strings -labsl_str_format_internal -labsl_cord -labsl_cordz_info -labsl_cord_internal -labsl_cordz_functions -labsl_cordz_handle -labsl_strings_internal \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/time -labsl_time -labsl_time_zone -labsl_civil_time \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/numeric -labsl_int128 \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/synchronization -labsl_synchronization -labsl_graphcycles_internal\
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/debugging -labsl_stacktrace -labsl_symbolize -labsl_debugging_internal -labsl_demangle_internal \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/profiling -labsl_exponential_biased \
|
||||||
|
-L/usr/local/lib/${LIB_ARCH}/absl/container -labsl_raw_hash_set -labsl_hashtablez_sampler"
|
||||||
buildah run \
|
buildah run \
|
||||||
--env CGO_ENABLED=1 \
|
--env CGO_ENABLED=1 \
|
||||||
--env CC=aarch64-linux-gnu-gcc \
|
--env CC=aarch64-linux-gnu-gcc \
|
||||||
@ -69,7 +85,7 @@ image_build_binaries(){
|
|||||||
--env GOOS=linux \
|
--env GOOS=linux \
|
||||||
--env GOARCH=arm64 \
|
--env GOARCH=arm64 \
|
||||||
--env CGO_CPPFLAGS="-I/usr/local/include" \
|
--env CGO_CPPFLAGS="-I/usr/local/include" \
|
||||||
--env CGO_LDFLAGS="-L /usr/lib/aarch64-linux-gnu -L /usr/local/lib/aarch64-linux-gnu" \
|
--env CGO_LDFLAGS="${LIB_FLAGS}" \
|
||||||
$containerName \
|
$containerName \
|
||||||
go build -a -o rc-steering.arm64 ./cmd/rc-steering
|
go build -a -o rc-steering.arm64 ./cmd/rc-steering
|
||||||
}
|
}
|
||||||
@ -110,7 +126,6 @@ image_build_binaries
|
|||||||
|
|
||||||
image_build linux/amd64
|
image_build linux/amd64
|
||||||
image_build linux/arm64
|
image_build linux/arm64
|
||||||
image_build linux/arm/v7
|
|
||||||
|
|
||||||
|
|
||||||
# push image
|
# push image
|
||||||
|
@ -3,18 +3,26 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"github.com/cyrilix/robocar-base/cli"
|
"github.com/cyrilix/robocar-base/cli"
|
||||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
||||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
|
||||||
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DefaultClientId = "robocar-steering-tflite-edgetpu"
|
DefaultClientId = "robocar-steering-tflite-edgetpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
modelNameRegex = regexp.MustCompile(".*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\\d+)x(?P<imgHeight>\\d+)h(?P<horizon>\\d+)_edgetpu.tflite$")
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var mqttBroker, username, password, clientId string
|
var mqttBroker, username, password, clientId string
|
||||||
var cameraTopic, steeringTopic string
|
var cameraTopic, steeringTopic string
|
||||||
@ -31,8 +39,8 @@ func main() {
|
|||||||
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
|
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
|
||||||
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
|
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
|
||||||
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
|
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
|
||||||
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)")
|
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model")
|
||||||
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)")
|
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model")
|
||||||
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
|
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
|
||||||
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
|
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
@ -63,19 +71,37 @@ func main() {
|
|||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
modelType, width, height, horizonFromName, err := parseModelName(modelPath)
|
||||||
|
if err != nil {
|
||||||
|
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
|
||||||
|
}
|
||||||
|
if imgWidth == 0 {
|
||||||
|
imgWidth = width
|
||||||
|
}
|
||||||
|
if imgHeight == 0 {
|
||||||
|
imgHeight = height
|
||||||
|
}
|
||||||
|
if horizonFromName == 0 {
|
||||||
|
horizon = horizonFromName
|
||||||
|
}
|
||||||
if imgWidth <= 0 || imgHeight <= 0 {
|
if imgWidth <= 0 || imgHeight <= 0 {
|
||||||
zap.L().Error("img-width and img-height are mandatory")
|
zap.L().Error("img-width and img-height are mandatory")
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
zap.S().Infof("model type : %v", modelType)
|
||||||
|
zap.S().Infof("model for image width : %v", imgWidth)
|
||||||
|
zap.S().Infof("model for image height: %v", imgHeight)
|
||||||
|
zap.S().Infof("model with horizon : %v", horizon)
|
||||||
|
|
||||||
client, err := cli.Connect(mqttBroker, username, password, clientId)
|
client, err := cli.Connect(mqttBroker, username, password, clientId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
|
zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
|
||||||
}
|
}
|
||||||
defer client.Disconnect(50)
|
defer client.Disconnect(50)
|
||||||
|
|
||||||
p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
|
p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
|
|
||||||
cli.HandleExit(p)
|
cli.HandleExit(p)
|
||||||
@ -85,3 +111,33 @@ func main() {
|
|||||||
zap.L().Fatal("unable to start service", zap.Error(err))
|
zap.L().Fatal("unable to start service", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
|
||||||
|
match := modelNameRegex.FindStringSubmatch(modelPath)
|
||||||
|
|
||||||
|
results := map[string]string{}
|
||||||
|
for i, name := range match {
|
||||||
|
results[modelNameRegex.SubexpNames()[i]] = name
|
||||||
|
}
|
||||||
|
modelType = tools.ParseModelType(results["type"])
|
||||||
|
if modelType == tools.ModelTypeUnknown {
|
||||||
|
err = fmt.Errorf("unknown model type '%v'", results["type"])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
imgWidth, err = strconv.Atoi(results["imgWidth"])
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
imgHeight, err = strconv.Atoi(results["imgHeight"])
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
horizon, err = strconv.Atoi(results["horizon"])
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
97
cmd/rc-steering/rc-steering_test.go
Normal file
97
cmd/rc-steering/rc-steering_test.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseModelName(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
modelPath string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantModelType tools.ModelType
|
||||||
|
wantImgWidth int
|
||||||
|
wantImgHeight int
|
||||||
|
wantHorizon int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "categorical",
|
||||||
|
args: args{modelPath: "/tmp/model_categorical_120x160h10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeCategorical,
|
||||||
|
wantImgWidth: 120,
|
||||||
|
wantImgHeight: 160,
|
||||||
|
wantHorizon: 10,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "linear",
|
||||||
|
args: args{modelPath: "/tmp/model_linear_120x160h10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeLinear,
|
||||||
|
wantImgWidth: 120,
|
||||||
|
wantImgHeight: 160,
|
||||||
|
wantHorizon: 10,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad-model",
|
||||||
|
args: args{modelPath: "/tmp/model_123_120x160h10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeUnknown,
|
||||||
|
wantImgWidth: 0,
|
||||||
|
wantImgHeight: 0,
|
||||||
|
wantHorizon: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad-width",
|
||||||
|
args: args{modelPath: "/tmp/model_categorical_ax160h10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeUnknown,
|
||||||
|
wantImgWidth: 0,
|
||||||
|
wantImgHeight: 0,
|
||||||
|
wantHorizon: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad-height",
|
||||||
|
args: args{modelPath: "/tmp/model_categorical_120xh10.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeUnknown,
|
||||||
|
wantImgWidth: 0,
|
||||||
|
wantImgHeight: 0,
|
||||||
|
wantHorizon: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad-horizon",
|
||||||
|
args: args{modelPath: "/tmp/model_categorical_120x160h.tflite"},
|
||||||
|
wantModelType: tools.ModelTypeUnknown,
|
||||||
|
wantImgWidth: 0,
|
||||||
|
wantImgHeight: 0,
|
||||||
|
wantHorizon: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gotModelType != tt.wantModelType {
|
||||||
|
t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType)
|
||||||
|
}
|
||||||
|
if gotImgWidth != tt.wantImgWidth {
|
||||||
|
t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth)
|
||||||
|
}
|
||||||
|
if gotImgHeight != tt.wantImgHeight {
|
||||||
|
t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight)
|
||||||
|
}
|
||||||
|
if gotHorizon != tt.wantHorizon {
|
||||||
|
t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
4
go.mod
4
go.mod
@ -1,13 +1,13 @@
|
|||||||
module github.com/cyrilix/robocar-steering-tflite-edgetpu
|
module github.com/cyrilix/robocar-steering-tflite-edgetpu
|
||||||
|
|
||||||
go 1.18
|
go 1.19
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/cyrilix/robocar-base v0.1.7
|
github.com/cyrilix/robocar-base v0.1.7
|
||||||
github.com/cyrilix/robocar-protobuf/go v1.0.5
|
github.com/cyrilix/robocar-protobuf/go v1.0.5
|
||||||
github.com/disintegration/imaging v1.6.2
|
github.com/disintegration/imaging v1.6.2
|
||||||
github.com/eclipse/paho.mqtt.golang v1.4.1
|
github.com/eclipse/paho.mqtt.golang v1.4.1
|
||||||
github.com/mattn/go-tflite v1.0.2
|
github.com/mattn/go-tflite v1.0.4
|
||||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v0.30.0
|
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v0.30.0
|
||||||
go.opentelemetry.io/otel/metric v0.30.0
|
go.opentelemetry.io/otel/metric v0.30.0
|
||||||
go.opentelemetry.io/otel/sdk/metric v0.30.0
|
go.opentelemetry.io/otel/sdk/metric v0.30.0
|
||||||
|
5
go.sum
5
go.sum
@ -1,7 +1,6 @@
|
|||||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||||
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||||
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
|
|
||||||
github.com/cyrilix/robocar-base v0.1.7 h1:EVzZ0KjigSFpke5f3A/PybEH3WFUEIrYSc3z/dhOZ48=
|
github.com/cyrilix/robocar-base v0.1.7 h1:EVzZ0KjigSFpke5f3A/PybEH3WFUEIrYSc3z/dhOZ48=
|
||||||
github.com/cyrilix/robocar-base v0.1.7/go.mod h1:4E11HQSNy2NT8e7MW188y6ST9C0RzarKyn7sK/3V/Lk=
|
github.com/cyrilix/robocar-base v0.1.7/go.mod h1:4E11HQSNy2NT8e7MW188y6ST9C0RzarKyn7sK/3V/Lk=
|
||||||
github.com/cyrilix/robocar-protobuf/go v1.0.5 h1:PX1At+pf6G7gJwT4LzJLQu3/LPFTTNNlZmZSYtnSELY=
|
github.com/cyrilix/robocar-protobuf/go v1.0.5 h1:PX1At+pf6G7gJwT4LzJLQu3/LPFTTNNlZmZSYtnSELY=
|
||||||
@ -31,8 +30,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
|
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
|
||||||
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
|
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
|
||||||
github.com/mattn/go-tflite v1.0.2 h1:P9CKqjyRSRM31SfL65WklD8U5B/iPD4CJQiRkB8K02g=
|
github.com/mattn/go-tflite v1.0.4 h1:wpfNKjMr3IJz4xI+oUeHE70RU6Q5dZc0FK/X8vCWLAo=
|
||||||
github.com/mattn/go-tflite v1.0.2/go.mod h1:2NwhEYXoP8vxRIpu95DElqMkZoV39ABRPF3AETN7N1w=
|
github.com/mattn/go-tflite v1.0.4/go.mod h1:j7bVlVHgKURK0p7AQOw3OqlGE2SVXqck7JsJo4wI+bc=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
@ -19,9 +19,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
||||||
return &Part{
|
return &Part{
|
||||||
client: client,
|
client: client,
|
||||||
|
modelType: modelType,
|
||||||
modelPath: modelPath,
|
modelPath: modelPath,
|
||||||
steeringTopic: steeringTopic,
|
steeringTopic: steeringTopic,
|
||||||
cameraTopic: cameraTopic,
|
cameraTopic: cameraTopic,
|
||||||
@ -42,6 +43,7 @@ type Part struct {
|
|||||||
|
|
||||||
options *tflite.InterpreterOptions
|
options *tflite.InterpreterOptions
|
||||||
interpreter *tflite.Interpreter
|
interpreter *tflite.Interpreter
|
||||||
|
modelType tools.ModelType
|
||||||
modelPath string
|
modelPath string
|
||||||
model *tflite.Model
|
model *tflite.Model
|
||||||
edgeVebosity int
|
edgeVebosity int
|
||||||
@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
|
|||||||
output := p.interpreter.GetOutputTensor(0).UInt8s()
|
output := p.interpreter.GetOutputTensor(0).UInt8s()
|
||||||
zap.L().Debug("raw steering", zap.Uint8s("result", output))
|
zap.L().Debug("raw steering", zap.Uint8s("result", output))
|
||||||
|
|
||||||
steering, score := tools.LinearBin(output, 15, -1, 2.0)
|
var steering, score float64
|
||||||
|
switch p.modelType {
|
||||||
|
case tools.ModelTypeCategorical:
|
||||||
|
steering, score = tools.LinearBin(output, 15, -1, 2.0)
|
||||||
|
case tools.ModelTypeLinear:
|
||||||
|
steering = 2*(float64(output[0])/255.) - 1.
|
||||||
|
score = 0.6
|
||||||
|
}
|
||||||
zap.L().Debug("found steering",
|
zap.L().Debug("found steering",
|
||||||
zap.Float64("steering", steering),
|
zap.Float64("steering", steering),
|
||||||
zap.Float64("score", score),
|
zap.Float64("score", score),
|
||||||
|
@ -1,9 +1,39 @@
|
|||||||
package tools
|
package tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelType int
|
||||||
|
|
||||||
|
func ParseModelType(s string) ModelType {
|
||||||
|
switch strings.ToLower(s) {
|
||||||
|
case "categorical":
|
||||||
|
return ModelTypeCategorical
|
||||||
|
case "linear":
|
||||||
|
return ModelTypeLinear
|
||||||
|
default:
|
||||||
|
return ModelTypeUnknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ModelType) String() string {
|
||||||
|
switch m {
|
||||||
|
case ModelTypeCategorical:
|
||||||
|
return "categorical"
|
||||||
|
case ModelTypeLinear:
|
||||||
|
return "linear"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelTypeUnknown ModelType = iota
|
||||||
|
ModelTypeCategorical
|
||||||
|
ModelTypeLinear
|
||||||
)
|
)
|
||||||
|
|
||||||
// LinearBin perform inverse linear_bin, taking
|
// LinearBin perform inverse linear_bin, taking
|
||||||
@ -15,19 +45,11 @@ func LinearBin(arr []uint8, n int, offset int, r float64) (float64, float64) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var results []result
|
var results []result
|
||||||
minScore := 0.2
|
|
||||||
for i := 0; i < outputSize; i++ {
|
for i := 0; i < outputSize; i++ {
|
||||||
score := float64(int(arr[i])) / 255.0
|
score := float64(int(arr[i])) / 255.0
|
||||||
if score < minScore {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
results = append(results, result{score: score, index: i})
|
results = append(results, result{score: score, index: i})
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(results) == 0 {
|
|
||||||
zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
|
|
||||||
return 0., 0.
|
|
||||||
}
|
|
||||||
zap.S().Debugf("raw result: %v", results)
|
zap.S().Debugf("raw result: %v", results)
|
||||||
|
|
||||||
sort.Slice(results, func(i, j int) bool {
|
sort.Slice(results, func(i, j int) bool {
|
||||||
|
@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseModelType(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
s string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want ModelType
|
||||||
|
}{
|
||||||
|
{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical},
|
||||||
|
{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical},
|
||||||
|
{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear},
|
||||||
|
{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear},
|
||||||
|
{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown},
|
||||||
|
{name: "empty", args: args{s: ""}, want: ModelTypeUnknown},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := ParseModelType(tt.args.s); got != tt.want {
|
||||||
|
t.Errorf("ParseModelType() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
2
vendor/github.com/mattn/go-tflite/README.md
generated
vendored
2
vendor/github.com/mattn/go-tflite/README.md
generated
vendored
@ -92,5 +92,5 @@ rm -Rf edgetpu
|
|||||||
MIT
|
MIT
|
||||||
|
|
||||||
## Author
|
## Author
|
||||||
Yasuhrio Matsumoto (a.k.a. mattn)
|
Yasuhiro Matsumoto (a.k.a. mattn)
|
||||||
|
|
||||||
|
1
vendor/github.com/mattn/go-tflite/tflite_experimental.go
generated
vendored
1
vendor/github.com/mattn/go-tflite/tflite_experimental.go
generated
vendored
@ -32,7 +32,6 @@ _make_registration(void* o_init, void* o_free, void* o_prepare, void* o_invoke,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void look_context(TfLiteContext *context) {
|
static void look_context(TfLiteContext *context) {
|
||||||
context->tensors;
|
|
||||||
TfLiteIntArray *plan = NULL;
|
TfLiteIntArray *plan = NULL;
|
||||||
context->GetExecutionPlan(context, &plan);
|
context->GetExecutionPlan(context, &plan);
|
||||||
if (plan == NULL) return;
|
if (plan == NULL) return;
|
||||||
|
2
vendor/modules.txt
vendored
2
vendor/modules.txt
vendored
@ -28,7 +28,7 @@ github.com/gorilla/websocket
|
|||||||
# github.com/mattn/go-pointer v0.0.1
|
# github.com/mattn/go-pointer v0.0.1
|
||||||
## explicit
|
## explicit
|
||||||
github.com/mattn/go-pointer
|
github.com/mattn/go-pointer
|
||||||
# github.com/mattn/go-tflite v1.0.2
|
# github.com/mattn/go-tflite v1.0.4
|
||||||
## explicit; go 1.13
|
## explicit; go 1.13
|
||||||
github.com/mattn/go-tflite
|
github.com/mattn/go-tflite
|
||||||
github.com/mattn/go-tflite/delegates
|
github.com/mattn/go-tflite/delegates
|
||||||
|
Reference in New Issue
Block a user