test: add unit tests and fix error
This commit is contained in:
		@@ -29,12 +29,11 @@ class ObjectProcessor:
 | 
			
		||||
        self._objects_topic = objects_topic
 | 
			
		||||
        self._objects_threshold = objects_threshold
 | 
			
		||||
 | 
			
		||||
    def process(self, in_nn: dai.NNData, frame_ref, frame_date: datetime.datetime) -> None:
 | 
			
		||||
    def process(self, in_nn: dai.NNData, frame_ref) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Parse and publish result of NeuralNetwork result
 | 
			
		||||
        :param in_nn: NeuralNetwork result read from device
 | 
			
		||||
        :param frame_ref: Id of the frame where objects are been detected
 | 
			
		||||
        :param frame_date: Datetime of the frame used for detection
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        detection_boxes = np.array(in_nn.getLayerFp16("ExpandDims")).reshape((100, 4))
 | 
			
		||||
@@ -45,9 +44,10 @@ class ObjectProcessor:
 | 
			
		||||
        scores = detection_scores[mask]
 | 
			
		||||
 | 
			
		||||
        if boxes.shape[0] > 0:
 | 
			
		||||
            self._publish_objects(boxes, frame_ref, frame_date, scores)
 | 
			
		||||
            self._publish_objects(boxes, frame_ref, scores)
 | 
			
		||||
 | 
			
		||||
    def _publish_objects(self, boxes: np.array, frame_ref, scores: np.array) -> None:
 | 
			
		||||
 | 
			
		||||
    def _publish_objects(self, boxes: np.array, frame_ref, now: datetime.datetime, scores: np.array) -> None:
 | 
			
		||||
        objects_msg = events.events_pb2.ObjectsMessage()
 | 
			
		||||
        objs = []
 | 
			
		||||
        for i in range(boxes.shape[0]):
 | 
			
		||||
@@ -56,7 +56,7 @@ class ObjectProcessor:
 | 
			
		||||
        objects_msg.objects.extend(objs)
 | 
			
		||||
        objects_msg.frame_ref.name = frame_ref.name
 | 
			
		||||
        objects_msg.frame_ref.id = frame_ref.id
 | 
			
		||||
        objects_msg.frame_ref.created_at.FromDatetime(now)
 | 
			
		||||
        objects_msg.frame_ref.created_at.FromDatetime(frame_ref.created_at.ToDatetime())
 | 
			
		||||
        logger.debug("publish object event to %s", self._objects_topic)
 | 
			
		||||
        self._mqtt_client.publish(topic=self._objects_topic,
 | 
			
		||||
                                  payload=objects_msg.SerializeToString(),
 | 
			
		||||
@@ -64,6 +64,21 @@ class ObjectProcessor:
 | 
			
		||||
                                  retain=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FrameProcessError(Exception):
 | 
			
		||||
    """
 | 
			
		||||
    Error base for invalid frame processing
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        message -- explanation of the error
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, message: str):
 | 
			
		||||
        """
 | 
			
		||||
        :param message: explanation of the error
 | 
			
		||||
        """
 | 
			
		||||
        self.message = message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FrameProcessor:
 | 
			
		||||
    """
 | 
			
		||||
    Processor for camera frames
 | 
			
		||||
@@ -73,16 +88,19 @@ class FrameProcessor:
 | 
			
		||||
        self._mqtt_client = mqtt_client
 | 
			
		||||
        self._frame_topic = frame_topic
 | 
			
		||||
 | 
			
		||||
    def process(self, img: dai.ImgFrame) -> (typing.Any, datetime.datetime):
 | 
			
		||||
    def process(self, img: dai.ImgFrame) -> typing.Any:
 | 
			
		||||
        """
 | 
			
		||||
        Publish camera frames
 | 
			
		||||
        :param img:
 | 
			
		||||
        :return:
 | 
			
		||||
            id frame
 | 
			
		||||
            frame creation datetime
 | 
			
		||||
            id frame reference
 | 
			
		||||
        :raise:
 | 
			
		||||
            FrameProcessError if frame can't be processed
 | 
			
		||||
        """
 | 
			
		||||
        im_resize = img.getCvFrame()
 | 
			
		||||
        is_success, im_buf_arr = cv2.imencode(".jpg", im_resize)
 | 
			
		||||
        if not is_success:
 | 
			
		||||
            raise FrameProcessError("unable to process to encode frame to jpg")
 | 
			
		||||
        byte_im = im_buf_arr.tobytes()
 | 
			
		||||
 | 
			
		||||
        now = datetime.datetime.now()
 | 
			
		||||
@@ -96,7 +114,7 @@ class FrameProcessor:
 | 
			
		||||
                                  payload=frame_msg.SerializeToString(),
 | 
			
		||||
                                  qos=0,
 | 
			
		||||
                                  retain=False)
 | 
			
		||||
        return frame_msg.id, now
 | 
			
		||||
        return frame_msg.id
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PipelineController:
 | 
			
		||||
@@ -124,7 +142,7 @@ class PipelineController:
 | 
			
		||||
 | 
			
		||||
        # Resize image
 | 
			
		||||
        manip = pipeline.create(dai.node.ImageManip)
 | 
			
		||||
        manip.initialConfig.setResize(NN_WIDTH, NN_HEIGHT)
 | 
			
		||||
        manip.initialConfig.setResize(_NN_WIDTH, _NN_HEIGHT)
 | 
			
		||||
        manip.initialConfig.setFrameType(dai.ImgFrame.Type.RGB888p)
 | 
			
		||||
        manip.initialConfig.setKeepAspectRatio(False)
 | 
			
		||||
 | 
			
		||||
@@ -161,7 +179,7 @@ class PipelineController:
 | 
			
		||||
    def _configure_detection_nn(pipeline: dai.Pipeline) -> dai.node.NeuralNetwork:
 | 
			
		||||
        # Define a neural network that will make predictions based on the source frames
 | 
			
		||||
        detection_nn = pipeline.create(dai.node.NeuralNetwork)
 | 
			
		||||
        detection_nn.setBlobPath(NN_PATH)
 | 
			
		||||
        detection_nn.setBlobPath(_NN_PATH)
 | 
			
		||||
        detection_nn.setNumPoolFrames(4)
 | 
			
		||||
        detection_nn.input.setBlocking(False)
 | 
			
		||||
        detection_nn.setNumInferenceThreads(2)
 | 
			
		||||
@@ -201,11 +219,13 @@ class PipelineController:
 | 
			
		||||
 | 
			
		||||
        # Wait for frame
 | 
			
		||||
        in_rgb: dai.ImgFrame = q_rgb.get()  # blocking call, will wait until a new data has arrived
 | 
			
		||||
        frame_msg, now = self._frame_processor.process(in_rgb)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            frame_ref = self._frame_processor.process(in_rgb)
 | 
			
		||||
        except FrameProcessError as ex:
 | 
			
		||||
            logger.error("unable to process frame: %s", str(ex))
 | 
			
		||||
        # Read NN result
 | 
			
		||||
        in_nn: dai.NNData = q_nn.get()
 | 
			
		||||
        self._object_processor.process(in_nn, frame_msg.id, now)
 | 
			
		||||
        self._object_processor.process(in_nn, frame_ref)
 | 
			
		||||
 | 
			
		||||
    def stop(self):
 | 
			
		||||
        """
 | 
			
		||||
@@ -214,6 +234,7 @@ class PipelineController:
 | 
			
		||||
        """
 | 
			
		||||
        self._stop = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _bbox_to_object(bbox: np.array, score: float) -> events.events_pb2.Object:
 | 
			
		||||
    obj = events.events_pb2.Object()
 | 
			
		||||
    obj.type = events.events_pb2.TypeObject.ANY
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										0
									
								
								camera/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								camera/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										150
									
								
								camera/tests/test_depthai.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								camera/tests/test_depthai.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,150 @@
 | 
			
		||||
import datetime
 | 
			
		||||
import unittest.mock
 | 
			
		||||
 | 
			
		||||
import depthai as dai
 | 
			
		||||
import numpy as np
 | 
			
		||||
import paho.mqtt.client as mqtt
 | 
			
		||||
import pytest
 | 
			
		||||
import pytest_mock
 | 
			
		||||
 | 
			
		||||
import camera.depthai
 | 
			
		||||
import events.events_pb2
 | 
			
		||||
 | 
			
		||||
Object = dict[str, float]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def mqtt_client(mocker: pytest_mock.MockerFixture) -> mqtt.Client:
 | 
			
		||||
    return mocker.MagicMock()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestObjectProcessor:
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def frame_ref(self):
 | 
			
		||||
        now = datetime.datetime.now()
 | 
			
		||||
        frame_msg = events.events_pb2.FrameMessage()
 | 
			
		||||
        frame_msg.id.name = "robocar-oak-camera-oak"
 | 
			
		||||
        frame_msg.id.id = str(int(now.timestamp() * 1000))
 | 
			
		||||
        frame_msg.id.created_at.FromDatetime(now)
 | 
			
		||||
        return frame_msg.id
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def object1(self) -> Object:
 | 
			
		||||
        return {
 | 
			
		||||
            "left": 0.3,
 | 
			
		||||
            "right": 0.7,
 | 
			
		||||
            "top": 0.1,
 | 
			
		||||
            "bottom": 0.6,
 | 
			
		||||
            "score": 0.8,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def raw_objects_empty(self, mocker: pytest_mock.MockerFixture) -> dai.NNData:
 | 
			
		||||
        raw_objects = mocker.MagicMock()
 | 
			
		||||
 | 
			
		||||
        def mock_return(name):
 | 
			
		||||
            if name == "ExpandDims":
 | 
			
		||||
                return [[0] * 4] * 100
 | 
			
		||||
            elif name == "ExpandDims_2":
 | 
			
		||||
                return [0] * 100
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f"{name} is not a valid arg")
 | 
			
		||||
 | 
			
		||||
        m = mocker.patch(target='depthai.NNData.getLayerFp16', autospec=True)
 | 
			
		||||
        m.getLayerFp16 = mock_return
 | 
			
		||||
        return m
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def raw_objects_one(self, mocker: pytest_mock.MockerFixture, object1: Object) -> dai.NNData:
 | 
			
		||||
        def mock_return(name):
 | 
			
		||||
            if name == "ExpandDims":  # Detection boxes
 | 
			
		||||
                boxes = [[0] * 4] * 100
 | 
			
		||||
                boxes[0] = [object1["top"], object1["left"], object1["bottom"], object1["right"]]
 | 
			
		||||
                return np.array(boxes)
 | 
			
		||||
 | 
			
		||||
            elif name == "ExpandDims_2":  # Detection scores
 | 
			
		||||
                scores = [0] * 100
 | 
			
		||||
                scores[0] = object1["score"]
 | 
			
		||||
                return scores
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f"{name} is not a valid arg")
 | 
			
		||||
 | 
			
		||||
        m = mocker.patch(target='depthai.NNData.getLayerFp16', autospec=True)
 | 
			
		||||
        m.getLayerFp16 = mock_return
 | 
			
		||||
        return m
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def object_processor(self, mqtt_client: mqtt.Client) -> camera.depthai.ObjectProcessor:
 | 
			
		||||
        return camera.depthai.ObjectProcessor(mqtt_client, "topic/object", 0.2)
 | 
			
		||||
 | 
			
		||||
    def test_process_without_object(self, object_processor: camera.depthai.ObjectProcessor, mqtt_client,
 | 
			
		||||
                                    raw_objects_empty, frame_ref):
 | 
			
		||||
        object_processor.process(raw_objects_empty, frame_ref)
 | 
			
		||||
        mqtt_client.publish.assert_not_called()
 | 
			
		||||
 | 
			
		||||
    def test_process_with_object_with_low_score(self, object_processor: camera.depthai.ObjectProcessor, mqtt_client,
 | 
			
		||||
                                                raw_objects_one, frame_ref):
 | 
			
		||||
        object_processor._objects_threshold = 0.9
 | 
			
		||||
        object_processor.process(raw_objects_one, frame_ref)
 | 
			
		||||
        mqtt_client.publish.assert_not_called()
 | 
			
		||||
 | 
			
		||||
    def test_process_with_one_object(self,
 | 
			
		||||
                                     object_processor: camera.depthai.ObjectProcessor, mqtt_client,
 | 
			
		||||
                                     raw_objects_one, frame_ref, object1: Object):
 | 
			
		||||
        object_processor.process(raw_objects_one, frame_ref)
 | 
			
		||||
        left = object1["left"]
 | 
			
		||||
        right = object1["right"]
 | 
			
		||||
        top = object1["top"]
 | 
			
		||||
        bottom = object1["bottom"]
 | 
			
		||||
        score = object1["score"]
 | 
			
		||||
 | 
			
		||||
        pub_mock: unittest.mock.MagicMock = mqtt_client.publish
 | 
			
		||||
        pub_mock.assert_called_once_with(payload=unittest.mock.ANY, qos=0, retain=False, topic="topic/object")
 | 
			
		||||
        payload = pub_mock.call_args.kwargs['payload']
 | 
			
		||||
        objects_msg = events.events_pb2.ObjectsMessage()
 | 
			
		||||
        objects_msg.ParseFromString(payload)
 | 
			
		||||
        assert len(objects_msg.objects) == 1
 | 
			
		||||
        assert left - 0.0001 < objects_msg.objects[0].left < left + 0.0001
 | 
			
		||||
        assert right - 0.0001 < objects_msg.objects[0].right < right + 0.0001
 | 
			
		||||
        assert top - 0.0001 < objects_msg.objects[0].top < top + 0.0001
 | 
			
		||||
        assert bottom - 0.0001 < objects_msg.objects[0].bottom < bottom + 0.0001
 | 
			
		||||
        assert score - 0.0001 < objects_msg.objects[0].confidence < score + 0.0001
 | 
			
		||||
        assert objects_msg.frame_ref == frame_ref
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestFrameProcessor:
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def frame_processor(self, mqtt_client: mqtt.Client) -> camera.depthai.FrameProcessor:
 | 
			
		||||
        return camera.depthai.FrameProcessor(mqtt_client, "topic/frame")
 | 
			
		||||
 | 
			
		||||
    def test_process(self, frame_processor: camera.depthai.FrameProcessor, mocker: pytest_mock.MockerFixture,
 | 
			
		||||
                     mqtt_client: mqtt.Client):
 | 
			
		||||
        img: dai.ImgFrame = mocker.MagicMock()
 | 
			
		||||
        mocker.patch(target="cv2.imencode").return_value = (True, np.array(b"img content"))
 | 
			
		||||
 | 
			
		||||
        frame_ref = frame_processor.process(img)
 | 
			
		||||
 | 
			
		||||
        pub_mock: unittest.mock.MagicMock = mqtt_client.publish
 | 
			
		||||
        pub_mock.assert_called_once_with(payload=unittest.mock.ANY, qos=0, retain=False, topic="topic/frame")
 | 
			
		||||
        payload = pub_mock.call_args.kwargs['payload']
 | 
			
		||||
        frame_msg = events.events_pb2.FrameMessage()
 | 
			
		||||
        frame_msg.ParseFromString(payload)
 | 
			
		||||
 | 
			
		||||
        assert frame_msg.id == frame_ref
 | 
			
		||||
        assert frame_msg.frame == b"img content"
 | 
			
		||||
 | 
			
		||||
        assert frame_msg.id.name == "robocar-oak-camera-oak"
 | 
			
		||||
        assert len(frame_msg.id.id) is 13
 | 
			
		||||
        now = datetime.datetime.now()
 | 
			
		||||
        assert now - datetime.timedelta(
 | 
			
		||||
            milliseconds=10) < frame_msg.id.created_at.ToDatetime() < now + datetime.timedelta(milliseconds=10)
 | 
			
		||||
 | 
			
		||||
    def test_process_error(self, frame_processor: camera.depthai.FrameProcessor, mocker: pytest_mock.MockerFixture,
 | 
			
		||||
                           mqtt_client: mqtt.Client):
 | 
			
		||||
        img: dai.ImgFrame = mocker.MagicMock()
 | 
			
		||||
        mocker.patch(target="cv2.imencode").return_value = (False, None)
 | 
			
		||||
 | 
			
		||||
        with pytest.raises(camera.depthai.FrameProcessError) as ex:
 | 
			
		||||
            _ = frame_processor.process(img)
 | 
			
		||||
        exception_raised = ex.value
 | 
			
		||||
        assert exception_raised.message == "unable to process to encode frame to jpg"
 | 
			
		||||
		Reference in New Issue
	
	Block a user