robocar-steering-tflite-edg.../vendor/github.com/mattn/go-tflite/tflite.go

259 lines
6.1 KiB
Go

package tflite
/*
#ifndef GO_TFLITE_H
#include "tflite.go.h"
#endif
#cgo LDFLAGS: -ltensorflowlite_c
#cgo android LDFLAGS: -ldl
#cgo linux,!android LDFLAGS: -ldl -lrt
*/
import "C"
import (
"reflect"
"unsafe"
"github.com/mattn/go-pointer"
"github.com/mattn/go-tflite/delegates"
)
//go:generate stringer -type TensorType,Status -output type_string.go .
// Model wraps a pointer to a TfLiteModel from the TensorFlow Lite C API.
type Model struct {
m *C.TfLiteModel // handle passed to the TfLiteModel* C functions
}
// NewModel creates a Model from an in-memory model buffer.
//
// model_data is copied into C-allocated memory with C.CBytes and that copy is
// handed to TfLiteModelCreate. NewModel returns nil when the C API fails to
// create the model; in that case the copy is released before returning.
//
// NOTE(review): on success the copy is deliberately kept alive, on the
// assumption that the C model keeps reading from the buffer it was created
// from (confirm against the TfLiteModelCreate contract). Model.Delete only
// calls TfLiteModelDelete, so the copy is not released there.
func NewModel(model_data []byte) *Model {
	buf := C.CBytes(model_data)
	m := C.TfLiteModelCreate(buf, C.size_t(len(model_data)))
	if m == nil {
		// No model was created, so nothing can reference the copy any more;
		// free it here instead of leaking it.
		C.free(buf)
		return nil
	}
	return &Model{m: m}
}
// NewModelFromFile creates a Model by handing model_path to
// TfLiteModelCreateFromFile. It returns nil when the C API returns no model.
func NewModelFromFile(model_path string) *Model {
	cpath := C.CString(model_path)
	// The C string is only needed for the duration of the call below.
	defer C.free(unsafe.Pointer(cpath))

	if handle := C.TfLiteModelCreateFromFile(cpath); handle != nil {
		return &Model{m: handle}
	}
	return nil
}
// Delete releases the underlying TfLiteModel via TfLiteModelDelete.
// Calling it on a nil *Model does nothing.
func (m *Model) Delete() {
	if m == nil {
		return
	}
	C.TfLiteModelDelete(m.m)
}
// InterpreterOptions wraps a pointer to a TfLiteInterpreterOptions from the
// TensorFlow Lite C API.
type InterpreterOptions struct {
o *C.TfLiteInterpreterOptions // handle passed to the TfLiteInterpreterOptions* C functions
}
// NewInterpreterOptions allocates a new set of interpreter options through
// TfLiteInterpreterOptionsCreate. It returns nil when the C API returns no
// options object.
func NewInterpreterOptions() *InterpreterOptions {
	if handle := C.TfLiteInterpreterOptionsCreate(); handle != nil {
		return &InterpreterOptions{o: handle}
	}
	return nil
}
// SetNumThread forwards num_threads, converted to a C int32_t, to
// TfLiteInterpreterOptionsSetNumThreads.
func (o *InterpreterOptions) SetNumThread(num_threads int) {
	n := C.int32_t(num_threads)
	C.TfLiteInterpreterOptionsSetNumThreads(o.o, n)
}
// SetErrorReporter registers f as the error reporter for these options.
//
// f and user_data are bundled into a callbackInfo, the bundle is stored with
// pointer.Save, and the resulting handle is passed to the C-side helper
// _TfLiteInterpreterOptionsSetErrorReporter.
// NOTE(review): presumably f is later invoked with the error message and
// user_data; the callback trampoline is not visible here, so confirm.
func (o *InterpreterOptions) SetErrorReporter(f func(string, interface{}), user_data interface{}) {
	info := &callbackInfo{
		f:         f,
		user_data: user_data,
	}
	C._TfLiteInterpreterOptionsSetErrorReporter(o.o, pointer.Save(info))
}
// AddDelegate passes the delegate pointer returned by d.Ptr(), converted to a
// *C.TfLiteDelegate, to TfLiteInterpreterOptionsAddDelegate.
func (o *InterpreterOptions) AddDelegate(d delegates.Delegater) {
	raw := (*C.TfLiteDelegate)(d.Ptr())
	C.TfLiteInterpreterOptionsAddDelegate(o.o, raw)
}
// Delete releases the underlying options object via
// TfLiteInterpreterOptionsDelete. Calling it on a nil *InterpreterOptions
// does nothing.
func (o *InterpreterOptions) Delete() {
	if o == nil {
		return
	}
	C.TfLiteInterpreterOptionsDelete(o.o)
}
// Interpreter wraps a pointer to a TfLiteInterpreter from the TensorFlow Lite
// C API.
type Interpreter struct {
i *C.TfLiteInterpreter // handle passed to the TfLiteInterpreter* C functions
}
// NewInterpreter creates an Interpreter for model.
//
// options may be nil, in which case a nil options pointer is passed to
// TfLiteInterpreterCreate. NewInterpreter returns nil when model is nil or
// when the C API returns no interpreter.
func NewInterpreter(model *Model, options *InterpreterOptions) *Interpreter {
	if model == nil {
		// NewModel and NewModelFromFile signal failure by returning nil, so a
		// nil model is an expected input; report failure the same way instead
		// of panicking on model.m.
		return nil
	}
	var o *C.TfLiteInterpreterOptions
	if options != nil {
		o = options.o
	}
	i := C.TfLiteInterpreterCreate(model.m, o)
	if i == nil {
		return nil
	}
	return &Interpreter{i: i}
}
// Delete releases the underlying interpreter via TfLiteInterpreterDelete.
// Calling it on a nil *Interpreter does nothing.
func (i *Interpreter) Delete() {
	if i == nil {
		return
	}
	C.TfLiteInterpreterDelete(i.i)
}
// Tensor wraps a pointer to a TfLiteTensor from the TensorFlow Lite C API.
type Tensor struct {
t *C.TfLiteTensor // handle passed to the TfLiteTensor* C functions
}
// GetInputTensorCount returns the value reported by
// TfLiteInterpreterGetInputTensorCount, converted to int.
func (i *Interpreter) GetInputTensorCount() int {
	count := C.TfLiteInterpreterGetInputTensorCount(i.i)
	return int(count)
}
// GetInputTensor returns the input tensor at index, or nil when
// TfLiteInterpreterGetInputTensor returns no tensor for that index.
func (i *Interpreter) GetInputTensor(index int) *Tensor {
	if handle := C.TfLiteInterpreterGetInputTensor(i.i, C.int32_t(index)); handle != nil {
		return &Tensor{t: handle}
	}
	return nil
}
// Status mirrors TfLiteStatus. Interpreter and Tensor methods in this file
// convert the C status code directly to Status.
type Status int

// Status values, following the TfLiteStatus enum (kTfLiteOk = 0,
// kTfLiteError = 1).
//
// Error carries an explicit value on purpose: a constant declared without an
// expression repeats the previous one, which would make Error equal to OK.
//
// The stringer output named in this file's go:generate directive is derived
// from these values; regenerate it whenever they change.
const (
	OK    Status = 0
	Error Status = 1
)
// ResizeInputTensor asks the interpreter to resize the input tensor at index
// to the shape given by dims and returns the status reported by
// TfLiteInterpreterResizeInputTensor.
//
// An empty or nil dims is forwarded as a nil pointer with length 0 rather
// than indexing dims[0], which would panic.
func (i *Interpreter) ResizeInputTensor(index int, dims []int32) Status {
	var p *C.int32_t
	if len(dims) > 0 {
		// Hand C a pointer to the first element of the slice's backing array.
		p = (*C.int32_t)(unsafe.Pointer(&dims[0]))
	}
	s := C.TfLiteInterpreterResizeInputTensor(i.i, C.int32_t(index), p, C.int32_t(len(dims)))
	return Status(s)
}
// AllocateTensors calls TfLiteInterpreterAllocateTensors and returns its
// status. On a nil *Interpreter it returns Error without calling into C.
func (i *Interpreter) AllocateTensors() Status {
	if i == nil {
		return Error
	}
	return Status(C.TfLiteInterpreterAllocateTensors(i.i))
}
// Invoke runs the interpreter through TfLiteInterpreterInvoke and returns the
// status it reports.
func (i *Interpreter) Invoke() Status {
	return Status(C.TfLiteInterpreterInvoke(i.i))
}
// GetOutputTensorCount returns the value reported by
// TfLiteInterpreterGetOutputTensorCount, converted to int.
func (i *Interpreter) GetOutputTensorCount() int {
	count := C.TfLiteInterpreterGetOutputTensorCount(i.i)
	return int(count)
}
// GetOutputTensor returns the output tensor at index, or nil when
// TfLiteInterpreterGetOutputTensor returns no tensor for that index.
func (i *Interpreter) GetOutputTensor(index int) *Tensor {
	if handle := C.TfLiteInterpreterGetOutputTensor(i.i, C.int32_t(index)); handle != nil {
		return &Tensor{t: handle}
	}
	return nil
}
// TensorType identifies the element type of a tensor. Tensor.Type produces it
// by converting the C tensor type code directly.
type TensorType int

// Tensor element types, numbered consecutively from NoType = 0 through
// Int8 = 9. The order is significant: each constant's value is its position
// in this list.
const (
	NoType TensorType = iota
	Float32
	Int32
	UInt8
	Int64
	String
	Bool
	Int16
	Complex64
	Int8
)
// Type returns the tensor's type code from TfLiteTensorType, converted to
// TensorType.
func (t *Tensor) Type() TensorType {
	code := C.TfLiteTensorType(t.t)
	return TensorType(code)
}
// NumDims returns the value reported by TfLiteTensorNumDims, converted to int.
func (t *Tensor) NumDims() int {
	n := C.TfLiteTensorNumDims(t.t)
	return int(n)
}
// Dim returns the value reported by TfLiteTensorDim for the dimension at
// index, converted to int.
func (t *Tensor) Dim(index int) int {
	size := C.TfLiteTensorDim(t.t, C.int32_t(index))
	return int(size)
}
// Shape returns the tensor's dimensions as a slice with one entry per
// dimension, in index order.
func (t *Tensor) Shape() []int {
	// Query the dimension count once; the slice length then bounds the loop,
	// so the loop condition does not call into C on every iteration.
	shape := make([]int, t.NumDims())
	for i := range shape {
		shape[i] = t.Dim(i)
	}
	return shape
}
// ByteSize returns the value reported by TfLiteTensorByteSize, converted to
// uint.
func (t *Tensor) ByteSize() uint {
	n := C.TfLiteTensorByteSize(t.t)
	return uint(n)
}
// Data returns the pointer reported by TfLiteTensorData.
// NOTE(review): presumably this points at the tensor's C-owned buffer of
// ByteSize() bytes; confirm against the C API before dereferencing.
func (t *Tensor) Data() unsafe.Pointer {
	buf := C.TfLiteTensorData(t.t)
	return buf
}
// Name returns the C string reported by TfLiteTensorName, copied into a Go
// string.
func (t *Tensor) Name() string {
	cname := C.TfLiteTensorName(t.t)
	return C.GoString(cname)
}
// QuantizationParams is the Go form of TfLiteQuantizationParams, as filled in
// by Tensor.QuantizationParams.
type QuantizationParams struct {
Scale float64 // converted from the C struct's scale field
ZeroPoint int // converted from the C struct's zero_point field
}
// QuantizationParams returns the tensor's quantization parameters, copying
// the scale and zero_point fields reported by TfLiteTensorQuantizationParams
// into Go types.
func (t *Tensor) QuantizationParams() QuantizationParams {
	raw := C.TfLiteTensorQuantizationParams(t.t)

	var params QuantizationParams
	params.Scale = float64(raw.scale)
	params.ZeroPoint = int(raw.zero_point)
	return params
}
// CopyFromBuffer passes b's underlying pointer and the tensor's ByteSize to
// TfLiteTensorCopyFromBuffer and returns the status it reports.
//
// The pointer is obtained with reflect.Value.Pointer, so b must be of a kind
// that method accepts (for example a slice or a pointer); reflect panics
// otherwise. The byte count passed is always t.ByteSize(), independent of
// b's own length.
func (t *Tensor) CopyFromBuffer(b interface{}) Status {
	src := reflect.ValueOf(b)
	size := C.size_t(t.ByteSize())
	// Keep the uintptr-to-unsafe.Pointer conversion inside the call
	// expression, as the unsafe package rules require.
	status := C.TfLiteTensorCopyFromBuffer(t.t, unsafe.Pointer(src.Pointer()), size)
	return Status(status)
}
// CopyToBuffer passes b's underlying pointer and the tensor's ByteSize to
// TfLiteTensorCopyToBuffer and returns the status it reports.
//
// The pointer is obtained with reflect.Value.Pointer, so b must be of a kind
// that method accepts (for example a slice or a pointer); reflect panics
// otherwise. The byte count passed is always t.ByteSize(), independent of
// b's own length.
func (t *Tensor) CopyToBuffer(b interface{}) Status {
	dst := reflect.ValueOf(b)
	size := C.size_t(t.ByteSize())
	// Keep the uintptr-to-unsafe.Pointer conversion inside the call
	// expression, as the unsafe package rules require.
	status := C.TfLiteTensorCopyToBuffer(t.t, unsafe.Pointer(dst.Pointer()), size)
	return Status(status)
}