First implementation
This commit is contained in:
		
							
								
								
									
										211
									
								
								pkg/steering/steering.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										211
									
								
								pkg/steering/steering.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,211 @@
 | 
			
		||||
package steering
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/cyrilix/robocar-base/service"
 | 
			
		||||
	"github.com/cyrilix/robocar-protobuf/go/events"
 | 
			
		||||
	mqtt "github.com/eclipse/paho.mqtt.golang"
 | 
			
		||||
	"github.com/golang/protobuf/proto"
 | 
			
		||||
	"github.com/mattn/go-tflite"
 | 
			
		||||
	"github.com/mattn/go-tflite/delegates/edgetpu"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"image"
 | 
			
		||||
	"sort"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NewPart(client mqtt.Client, steeringTopic, cameraTopic string, edgeVerbosity int) *Part {
 | 
			
		||||
	return &Part{
 | 
			
		||||
		client:        client,
 | 
			
		||||
		steeringTopic: steeringTopic,
 | 
			
		||||
		cameraTopic:   cameraTopic,
 | 
			
		||||
		edgeVebosity:  edgeVerbosity,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Part struct {
 | 
			
		||||
	client        mqtt.Client
 | 
			
		||||
	steeringTopic string
 | 
			
		||||
	cameraTopic   string
 | 
			
		||||
 | 
			
		||||
	cancel chan interface{}
 | 
			
		||||
 | 
			
		||||
	options      *tflite.InterpreterOptions
 | 
			
		||||
	interpreter  *tflite.Interpreter
 | 
			
		||||
	modelPath    string
 | 
			
		||||
	model        *tflite.Model
 | 
			
		||||
	edgeVebosity int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Part) Start() error {
 | 
			
		||||
	p.model = tflite.NewModelFromFile(p.modelPath)
 | 
			
		||||
	if p.model == nil {
 | 
			
		||||
		return fmt.Errorf("cannot load model %v", p.modelPath)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Get the list of devices
 | 
			
		||||
	devices, err := edgetpu.DeviceList()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("could not get EdgeTPU devices: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	if len(devices) == 0 {
 | 
			
		||||
		return fmt.Errorf("no edge TPU devices found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Print the EdgeTPU version
 | 
			
		||||
	edgetpuVersion, err := edgetpu.Version()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("cannot get EdgeTPU version: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	zap.S().Infof("EdgeTPU Version: %s\n", edgetpuVersion)
 | 
			
		||||
	edgetpu.Verbosity(p.edgeVebosity)
 | 
			
		||||
 | 
			
		||||
	p.options = tflite.NewInterpreterOptions()
 | 
			
		||||
	//options.SetNumThread(4)
 | 
			
		||||
	p.options.SetErrorReporter(func(msg string, userData interface{}) {
 | 
			
		||||
		zap.L().Error(msg,
 | 
			
		||||
			zap.Any("userData", userData),
 | 
			
		||||
		)
 | 
			
		||||
	}, nil)
 | 
			
		||||
 | 
			
		||||
	// Add the first EdgeTPU device
 | 
			
		||||
	d := edgetpu.New(devices[0])
 | 
			
		||||
	p.options.AddDelegate(d)
 | 
			
		||||
 | 
			
		||||
	p.interpreter = tflite.NewInterpreter(p.model, p.options)
 | 
			
		||||
	if p.interpreter == nil {
 | 
			
		||||
		return fmt.Errorf("cannot create interpreter")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := registerCallbacks(p); err != nil {
 | 
			
		||||
		zap.S().Errorf("unable to register callbacks: %v", err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.cancel = make(chan interface{})
 | 
			
		||||
	<-p.cancel
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Part) Stop() {
 | 
			
		||||
	close(p.cancel)
 | 
			
		||||
	service.StopService("steering", p.client, p.cameraTopic)
 | 
			
		||||
	if p.interpreter != nil {
 | 
			
		||||
		p.interpreter.Delete()
 | 
			
		||||
	}
 | 
			
		||||
	p.interpreter.Delete()
 | 
			
		||||
	if p.options != nil {
 | 
			
		||||
		p.options.Delete()
 | 
			
		||||
	}
 | 
			
		||||
	if p.model != nil {
 | 
			
		||||
		p.model.Delete()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Part) onFrame(_ mqtt.Client, message mqtt.Message) {
 | 
			
		||||
	var msg events.FrameMessage
 | 
			
		||||
	err := proto.Unmarshal(message.Payload(), &msg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.S().Errorf("unable to unmarshal protobuf %T message: %v", &msg, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	img, _, err := image.Decode(bytes.NewReader(msg.GetFrame()))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.L().Error("unable to decode frame", zap.Error(err))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	steering, confidence, err := p.Value(img)
 | 
			
		||||
	msgSteering := &events.SteeringMessage{
 | 
			
		||||
		Steering:   steering,
 | 
			
		||||
		Confidence: confidence,
 | 
			
		||||
		FrameRef:   msg.Id,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	payload, err := proto.Marshal(msgSteering)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.L().Error("unable to marshal protobuf message", zap.Error(err))
 | 
			
		||||
	}
 | 
			
		||||
	publish(p.client, p.steeringTopic, payload)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Part) Value(img image.Image) (float32, float32, error) {
 | 
			
		||||
	status := p.interpreter.AllocateTensors()
 | 
			
		||||
	if status != tflite.OK {
 | 
			
		||||
		return 0., 0., fmt.Errorf("tensor allocate failed: %v", status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	input := p.interpreter.GetInputTensor(0)
 | 
			
		||||
 | 
			
		||||
	dx := img.Bounds().Dx()
 | 
			
		||||
	dy := img.Bounds().Dy()
 | 
			
		||||
 | 
			
		||||
	bb := make([]byte, dx*dy*3)
 | 
			
		||||
	for y := 0; y < 128; y++ {
 | 
			
		||||
		for x := 0; x < 160; x++ {
 | 
			
		||||
			r, g, b, _ := img.At(x, y).RGBA()
 | 
			
		||||
			bb[(y*dx+x)*3+0] = byte(float64(r) / 255.0)
 | 
			
		||||
			bb[(y*dx+x)*3+1] = byte(float64(g) / 255.0)
 | 
			
		||||
			bb[(y*dx+x)*3+2] = byte(float64(b) / 255.0)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	status = input.CopyFromBuffer(bb)
 | 
			
		||||
	if status != tflite.OK {
 | 
			
		||||
		return 0., 0., fmt.Errorf("input copy from buffer failed: %v", status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	status = p.interpreter.Invoke()
 | 
			
		||||
	if status != tflite.OK {
 | 
			
		||||
		return 0., 0., fmt.Errorf("invoke failed: %v", status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	output := p.interpreter.GetOutputTensor(0)
 | 
			
		||||
	outputSize := output.Dim(output.NumDims() - 1)
 | 
			
		||||
	b := make([]byte, outputSize)
 | 
			
		||||
	type result struct {
 | 
			
		||||
		score float64
 | 
			
		||||
		index int
 | 
			
		||||
	}
 | 
			
		||||
	status = output.CopyToBuffer(&b[0])
 | 
			
		||||
	if status != tflite.OK {
 | 
			
		||||
		return 0., 0., fmt.Errorf("output failed: %v", status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var results []result
 | 
			
		||||
	minScore := 0.2
 | 
			
		||||
	for i := 0; i < outputSize; i++ {
 | 
			
		||||
		score := float64(b[i]) / 255.0
 | 
			
		||||
		if score < minScore {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		results = append(results, result{score: score, index: i})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(results) == 0 {
 | 
			
		||||
		zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
 | 
			
		||||
		return 0., 0., nil
 | 
			
		||||
	}
 | 
			
		||||
	sort.Slice(results, func(i, j int) bool {
 | 
			
		||||
		return results[i].score > results[j].score
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	steering := float64(results[0].index)*(2./float64(outputSize)) - 1
 | 
			
		||||
	zap.L().Debug("found steering",
 | 
			
		||||
		zap.Float64("steering", steering),
 | 
			
		||||
		zap.Float64("score", results[0].score),
 | 
			
		||||
	)
 | 
			
		||||
	return float32(steering), float32(results[0].score), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var registerCallbacks = func(p *Part) error {
 | 
			
		||||
	err := service.RegisterCallback(p.client, p.cameraTopic, p.onFrame)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var publish = func(client mqtt.Client, topic string, payload []byte) {
 | 
			
		||||
	client.Publish(topic, 0, false, payload)
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user