test: add unit tests and fix error
parent 9b0b772786
commit 0c5e8e93ac
@@ -53,7 +53,7 @@ ignore-paths=
 # Files or directories matching the regular expression patterns are skipped.
 # The regex matches against base names, not paths. The default value ignores
 # Emacs file locks
-ignore-patterns=^\.#
+ignore-patterns=^\.#,test_.*?py
 
 # List of module names for which member attributes should not be checked
 # (useful for modules/projects where namespaces are manipulated during runtime
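The added pattern is matched against module base names (see the comment lines above), so the new test modules are skipped by pylint. A quick sanity check of the pattern (hypothetical snippet, not part of the commit):

    import re

    # "^test_.*?py" matches the base name of the new test module added below,
    # but not the production module camera/depthai.py.
    assert re.match(r"^test_.*?py", "test_depthai.py") is not None
    assert re.match(r"^test_.*?py", "depthai.py") is None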
|
camera/depthai.py
@@ -29,12 +29,11 @@ class ObjectProcessor:
         self._objects_topic = objects_topic
         self._objects_threshold = objects_threshold
 
-    def process(self, in_nn: dai.NNData, frame_ref, frame_date: datetime.datetime) -> None:
+    def process(self, in_nn: dai.NNData, frame_ref) -> None:
         """
         Parse and publish result of NeuralNetwork result
         :param in_nn: NeuralNetwork result read from device
         :param frame_ref: Id of the frame where objects are been detected
-        :param frame_date: Datetime of the frame used for detection
         :return:
         """
         detection_boxes = np.array(in_nn.getLayerFp16("ExpandDims")).reshape((100, 4))
@@ -45,9 +44,10 @@ class ObjectProcessor:
         scores = detection_scores[mask]
 
         if boxes.shape[0] > 0:
-            self._publish_objects(boxes, frame_ref, frame_date, scores)
+            self._publish_objects(boxes, frame_ref, scores)
+
+    def _publish_objects(self, boxes: np.array, frame_ref, scores: np.array) -> None:
 
-    def _publish_objects(self, boxes: np.array, frame_ref, now: datetime.datetime, scores: np.array) -> None:
         objects_msg = events.events_pb2.ObjectsMessage()
         objs = []
         for i in range(boxes.shape[0]):
@@ -56,7 +56,7 @@ class ObjectProcessor:
         objects_msg.objects.extend(objs)
         objects_msg.frame_ref.name = frame_ref.name
         objects_msg.frame_ref.id = frame_ref.id
-        objects_msg.frame_ref.created_at.FromDatetime(now)
+        objects_msg.frame_ref.created_at.FromDatetime(frame_ref.created_at.ToDatetime())
         logger.debug("publish object event to %s", self._objects_topic)
         self._mqtt_client.publish(topic=self._objects_topic,
                                   payload=objects_msg.SerializeToString(),
@@ -64,6 +64,21 @@ class ObjectProcessor:
                                   retain=False)
 
 
+class FrameProcessError(Exception):
+    """
+    Error base for invalid frame processing
+
+    Attributes:
+        message -- explanation of the error
+    """
+
+    def __init__(self, message: str):
+        """
+        :param message: explanation of the error
+        """
+        self.message = message
+
+
 class FrameProcessor:
     """
     Processor for camera frames
@@ -73,16 +88,19 @@ class FrameProcessor:
         self._mqtt_client = mqtt_client
         self._frame_topic = frame_topic
 
-    def process(self, img: dai.ImgFrame) -> (typing.Any, datetime.datetime):
+    def process(self, img: dai.ImgFrame) -> typing.Any:
         """
         Publish camera frames
         :param img:
         :return:
-            id frame
-            frame creation datetime
+            id frame reference
+        :raise:
+            FrameProcessError if frame can't be processed
         """
         im_resize = img.getCvFrame()
         is_success, im_buf_arr = cv2.imencode(".jpg", im_resize)
+        if not is_success:
+            raise FrameProcessError("unable to process to encode frame to jpg")
         byte_im = im_buf_arr.tobytes()
 
         now = datetime.datetime.now()
@@ -96,7 +114,7 @@ class FrameProcessor:
                                   payload=frame_msg.SerializeToString(),
                                   qos=0,
                                   retain=False)
-        return frame_msg.id, now
+        return frame_msg.id
 
 
 class PipelineController:
@@ -124,7 +142,7 @@ class PipelineController:
 
         # Resize image
         manip = pipeline.create(dai.node.ImageManip)
-        manip.initialConfig.setResize(NN_WIDTH, NN_HEIGHT)
+        manip.initialConfig.setResize(_NN_WIDTH, _NN_HEIGHT)
         manip.initialConfig.setFrameType(dai.ImgFrame.Type.RGB888p)
         manip.initialConfig.setKeepAspectRatio(False)
 
@@ -161,7 +179,7 @@ class PipelineController:
     def _configure_detection_nn(pipeline: dai.Pipeline) -> dai.node.NeuralNetwork:
         # Define a neural network that will make predictions based on the source frames
        detection_nn = pipeline.create(dai.node.NeuralNetwork)
-        detection_nn.setBlobPath(NN_PATH)
+        detection_nn.setBlobPath(_NN_PATH)
         detection_nn.setNumPoolFrames(4)
         detection_nn.input.setBlocking(False)
         detection_nn.setNumInferenceThreads(2)
@@ -201,11 +219,13 @@ class PipelineController:
 
             # Wait for frame
             in_rgb: dai.ImgFrame = q_rgb.get()  # blocking call, will wait until a new data has arrived
-            frame_msg, now = self._frame_processor.process(in_rgb)
-
+            try:
+                frame_ref = self._frame_processor.process(in_rgb)
+            except FrameProcessError as ex:
+                logger.error("unable to process frame: %s", str(ex))
             # Read NN result
             in_nn: dai.NNData = q_nn.get()
-            self._object_processor.process(in_nn, frame_msg.id, now)
+            self._object_processor.process(in_nn, frame_ref)
 
     def stop(self):
         """
@@ -214,6 +234,7 @@ class PipelineController:
         """
         self._stop = True
 
+
 def _bbox_to_object(bbox: np.array, score: float) -> events.events_pb2.Object:
     obj = events.events_pb2.Object()
     obj.type = events.events_pb2.TypeObject.ANY
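Taken together, FrameProcessor.process now returns only the frame reference and signals encoding failures through FrameProcessError, while ObjectProcessor takes its timestamp from frame_ref.created_at. A minimal sketch of the resulting call flow, assuming frame_processor, object_processor, logger, in_rgb and in_nn are the objects wired up by PipelineController:

    try:
        frame_ref = frame_processor.process(in_rgb)
        # No separate datetime argument anymore: the object processor reuses
        # the timestamp carried by frame_ref.created_at.
        object_processor.process(in_nn, frame_ref)
    except FrameProcessError as ex:
        logger.error("unable to process frame: %s", str(ex))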
|
camera/tests/__init__.py (new file, 0 lines)

camera/tests/test_depthai.py (new file, 150 lines)
@@ -0,0 +1,150 @@
+import datetime
+import unittest.mock
+
+import depthai as dai
+import numpy as np
+import paho.mqtt.client as mqtt
+import pytest
+import pytest_mock
+
+import camera.depthai
+import events.events_pb2
+
+Object = dict[str, float]
+
+
+@pytest.fixture
+def mqtt_client(mocker: pytest_mock.MockerFixture) -> mqtt.Client:
+    return mocker.MagicMock()
+
+
+class TestObjectProcessor:
+    @pytest.fixture
+    def frame_ref(self):
+        now = datetime.datetime.now()
+        frame_msg = events.events_pb2.FrameMessage()
+        frame_msg.id.name = "robocar-oak-camera-oak"
+        frame_msg.id.id = str(int(now.timestamp() * 1000))
+        frame_msg.id.created_at.FromDatetime(now)
+        return frame_msg.id
+
+    @pytest.fixture
+    def object1(self) -> Object:
+        return {
+            "left": 0.3,
+            "right": 0.7,
+            "top": 0.1,
+            "bottom": 0.6,
+            "score": 0.8,
+        }
+
+    @pytest.fixture
+    def raw_objects_empty(self, mocker: pytest_mock.MockerFixture) -> dai.NNData:
+        raw_objects = mocker.MagicMock()
+
+        def mock_return(name):
+            if name == "ExpandDims":
+                return [[0] * 4] * 100
+            elif name == "ExpandDims_2":
+                return [0] * 100
+            else:
+                raise ValueError(f"{name} is not a valid arg")
+
+        m = mocker.patch(target='depthai.NNData.getLayerFp16', autospec=True)
+        m.getLayerFp16 = mock_return
+        return m
+
+    @pytest.fixture
+    def raw_objects_one(self, mocker: pytest_mock.MockerFixture, object1: Object) -> dai.NNData:
+        def mock_return(name):
+            if name == "ExpandDims":  # Detection boxes
+                boxes = [[0] * 4] * 100
+                boxes[0] = [object1["top"], object1["left"], object1["bottom"], object1["right"]]
+                return np.array(boxes)
+
+            elif name == "ExpandDims_2":  # Detection scores
+                scores = [0] * 100
+                scores[0] = object1["score"]
+                return scores
+            else:
+                raise ValueError(f"{name} is not a valid arg")
+
+        m = mocker.patch(target='depthai.NNData.getLayerFp16', autospec=True)
+        m.getLayerFp16 = mock_return
+        return m
+
+    @pytest.fixture
+    def object_processor(self, mqtt_client: mqtt.Client) -> camera.depthai.ObjectProcessor:
+        return camera.depthai.ObjectProcessor(mqtt_client, "topic/object", 0.2)
+
+    def test_process_without_object(self, object_processor: camera.depthai.ObjectProcessor, mqtt_client,
+                                    raw_objects_empty, frame_ref):
+        object_processor.process(raw_objects_empty, frame_ref)
+        mqtt_client.publish.assert_not_called()
+
+    def test_process_with_object_with_low_score(self, object_processor: camera.depthai.ObjectProcessor, mqtt_client,
+                                                raw_objects_one, frame_ref):
+        object_processor._objects_threshold = 0.9
+        object_processor.process(raw_objects_one, frame_ref)
+        mqtt_client.publish.assert_not_called()
+
+    def test_process_with_one_object(self,
+                                     object_processor: camera.depthai.ObjectProcessor, mqtt_client,
+                                     raw_objects_one, frame_ref, object1: Object):
+        object_processor.process(raw_objects_one, frame_ref)
+        left = object1["left"]
+        right = object1["right"]
+        top = object1["top"]
+        bottom = object1["bottom"]
+        score = object1["score"]
+
+        pub_mock: unittest.mock.MagicMock = mqtt_client.publish
+        pub_mock.assert_called_once_with(payload=unittest.mock.ANY, qos=0, retain=False, topic="topic/object")
+        payload = pub_mock.call_args.kwargs['payload']
+        objects_msg = events.events_pb2.ObjectsMessage()
+        objects_msg.ParseFromString(payload)
+        assert len(objects_msg.objects) == 1
+        assert left - 0.0001 < objects_msg.objects[0].left < left + 0.0001
+        assert right - 0.0001 < objects_msg.objects[0].right < right + 0.0001
+        assert top - 0.0001 < objects_msg.objects[0].top < top + 0.0001
+        assert bottom - 0.0001 < objects_msg.objects[0].bottom < bottom + 0.0001
+        assert score - 0.0001 < objects_msg.objects[0].confidence < score + 0.0001
+        assert objects_msg.frame_ref == frame_ref
+
+
+class TestFrameProcessor:
+    @pytest.fixture
+    def frame_processor(self, mqtt_client: mqtt.Client) -> camera.depthai.FrameProcessor:
+        return camera.depthai.FrameProcessor(mqtt_client, "topic/frame")
+
+    def test_process(self, frame_processor: camera.depthai.FrameProcessor, mocker: pytest_mock.MockerFixture,
+                     mqtt_client: mqtt.Client):
+        img: dai.ImgFrame = mocker.MagicMock()
+        mocker.patch(target="cv2.imencode").return_value = (True, np.array(b"img content"))
+
+        frame_ref = frame_processor.process(img)
+
+        pub_mock: unittest.mock.MagicMock = mqtt_client.publish
+        pub_mock.assert_called_once_with(payload=unittest.mock.ANY, qos=0, retain=False, topic="topic/frame")
+        payload = pub_mock.call_args.kwargs['payload']
+        frame_msg = events.events_pb2.FrameMessage()
+        frame_msg.ParseFromString(payload)
+
+        assert frame_msg.id == frame_ref
+        assert frame_msg.frame == b"img content"
+
+        assert frame_msg.id.name == "robocar-oak-camera-oak"
+        assert len(frame_msg.id.id) is 13
+        now = datetime.datetime.now()
+        assert now - datetime.timedelta(
+            milliseconds=10) < frame_msg.id.created_at.ToDatetime() < now + datetime.timedelta(milliseconds=10)
+
+    def test_process_error(self, frame_processor: camera.depthai.FrameProcessor, mocker: pytest_mock.MockerFixture,
+                           mqtt_client: mqtt.Client):
+        img: dai.ImgFrame = mocker.MagicMock()
+        mocker.patch(target="cv2.imencode").return_value = (False, None)
+
+        with pytest.raises(camera.depthai.FrameProcessError) as ex:
+            _ = frame_processor.process(img)
+        exception_raised = ex.value
+        assert exception_raised.message == "unable to process to encode frame to jpg"
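The new suite needs only the fixtures above plus the pytest-mock dependency added below. One way to run just this module from Python, as a sketch (the usual route is simply invoking pytest on camera/tests):

    import sys

    import pytest

    # Run only the new test module; pytest.main returns an exit code.
    sys.exit(pytest.main(["camera/tests/test_depthai.py", "-v"]))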
poetry.lock (generated, 20 lines)
@@ -345,6 +345,20 @@ tomli = ">=1.0.0"
 [package.extras]
 testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
 
+[[package]]
+name = "pytest-mock"
+version = "3.10.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+pytest = ">=5.0"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
 [[package]]
 name = "python-dateutil"
 version = "2.8.2"
@@ -452,7 +466,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.10"
-content-hash = "c241f1088945e1b451684386b6aac16ba85a85394c94c8a19099b5ebea05b53f"
+content-hash = "d062fb11b00c63a20b69560322be0723850a7fe3a4363bfe0339f1f75ffd0e2e"
 
 [metadata.files]
 astroid = [
@@ -679,6 +693,10 @@ pytest = [
     {file = "pytest-7.1.3-py3-none-any.whl", hash = "sha256:1377bda3466d70b55e3f5cecfa55bb7cfcf219c7964629b967c37cf0bda818b7"},
     {file = "pytest-7.1.3.tar.gz", hash = "sha256:4f365fec2dff9c1162f834d9f18af1ba13062db0c708bf7b946f8a5c76180c39"},
 ]
+pytest-mock = [
+    {file = "pytest-mock-3.10.0.tar.gz", hash = "sha256:fbbdb085ef7c252a326fd8cdcac0aa3b1333d8811f131bdcc701002e1be7ed4f"},
+    {file = "pytest_mock-3.10.0-py3-none-any.whl", hash = "sha256:f4c973eeae0282963eb293eb173ce91b091a79c1334455acfac9ddee8a1c784b"},
+]
 python-dateutil = [
     {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
     {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
pyproject.toml
@@ -22,6 +22,7 @@ protobuf = "^4.21.8"
 
 [tool.poetry.group.test.dependencies]
 pytest = "^7.1.3"
+pytest-mock = "^3.10.0"
 
 
 [tool.poetry.group.dev.dependencies]
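Once the test dependency group is installed in the project's virtual environment, the pinned mock helper can be checked from Python; a hypothetical verification snippet, not part of the commit:

    from importlib.metadata import version

    # poetry.lock above pins pytest-mock to 3.10.0.
    assert version("pytest-mock") == "3.10.0"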