feat(object detection): implement object detection
see https://github.com/luxonis/depthai-experiments/tree/master/gen2-mobile-object-localizer
@@ -2,8 +2,10 @@
 Publish data from oak-lite device
 
 Usage: rc-oak-camera [-u USERNAME | --mqtt-username=USERNAME] [--mqtt-password=PASSWORD] [--mqtt-broker=HOSTNAME] \
-    [--mqtt-topic-robocar-oak-camera="TOPIC_CAMERA"] [--mqtt-client-id=CLIENT_ID] \
-    [-H IMG_HEIGHT | --image-height=IMG_HEIGHT] [-W IMG_WIDTH | --image-width=IMG_width]
+    [--mqtt-topic-robocar-oak-camera="TOPIC_CAMERA"] [--mqtt-topic-robocar-objects="TOPIC_OBJECTS"] \
+    [--mqtt-client-id=CLIENT_ID] \
+    [-H IMG_HEIGHT | --image-height=IMG_HEIGHT] [-W IMG_WIDTH | --image-width=IMG_width] \
+    [-t OBJECTS_THRESHOLD | --objects-threshold=OBJECTS_THRESHOLD]
 
 Options:
 -h --help                                               Show this screen.
@@ -12,8 +14,10 @@ Options:
 -b HOSTNAME --mqtt-broker=HOSTNAME                      MQTT broker host
 -C CLIENT_ID --mqtt-client-id=CLIENT_ID                 MQTT client id
 -c TOPIC_CAMERA --mqtt-topic-robocar-oak-camera=TOPIC_CAMERA        MQTT topic where to publish robocar-oak-camera frames
+-o TOPIC_OBJECTS --mqtt-topic-robocar-objects=TOPIC_OBJECTS         MQTT topic where to publish objects detection results
 -H IMG_HEIGHT --image-height=IMG_HEIGHT                 IMG_HEIGHT image height
 -W IMG_WIDTH --image-width=IMG_width                    IMG_WIDTH image width
+-t OBJECTS_THRESHOLD --objects-threshold=OBJECTS_THRESHOLD    OBJECTS_THRESHOLD threshold to filter objects detected
 """
 import logging
 import os
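
The two new flags are ordinary docopt options, so their values come back in the parsed argument dictionary under their long names. A quick check of how they parse, as a minimal sketch (it assumes the module docstring above is what gets passed to docopt, here via __doc__):

from docopt import docopt

# Parse a sample command line against the usage string above (assumed to be
# this module's docstring, hence __doc__).
args = docopt(__doc__, argv=["--mqtt-topic-robocar-objects=/objects", "-t", "0.4"])

print(args["--mqtt-topic-robocar-objects"])  # "/objects"
print(args["--objects-threshold"])           # "0.4" (docopt returns strings, hence the float(...) below)
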
@@ -50,9 +54,13 @@ def execute_from_command_line():
                                                            default_client_id),
                                )
     frame_topic = get_default_value(args["--mqtt-topic-robocar-oak-camera"], "MQTT_TOPIC_CAMERA", "/oak/camera_rgb")
+    objects_topic = get_default_value(args["--mqtt-topic-robocar-objects"], "MQTT_TOPIC_OBJECTS", "/objects")
 
     frame_processor = cam.FramePublisher(mqtt_client=client,
                                           frame_topic=frame_topic,
+                                         objects_topic=objects_topic,
+                                         objects_threshold=float(get_default_value(args["--objects-threshold"],
+                                                                                   "OBJECTS_THRESHOLD", 0.2)),
                                           img_width=int(get_default_value(args["--image-width"], "IMAGE_WIDTH", 160)),
                                           img_height=int(get_default_value(args["--image-height"], "IMAGE_HEIGHT", 120)))
     frame_processor.run()
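
get_default_value itself is not part of this diff; judging from the call sites above it resolves a CLI value first, then an environment variable, then a hard-coded default. A minimal sketch of that assumed behaviour:

import os


def get_default_value(cli_value, env_var: str, default):
    # Assumed behaviour: prefer the docopt value, fall back to the environment,
    # then to the hard-coded default.
    if cli_value is not None:
        return cli_value
    return os.environ.get(env_var, default)


# e.g. float(get_default_value(args["--objects-threshold"], "OBJECTS_THRESHOLD", 0.2))
# returns 0.2 when neither the flag nor the environment variable is set.
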
@@ -6,14 +6,22 @@ import events.events_pb2
 
 import depthai as dai
 import cv2
+import numpy as np
 
 logger = logging.getLogger(__name__)
 
+NN_PATH = "/models/mobile_object_localizer_192x192_openvino_2021.4_6shave.blob"
+NN_WIDTH = 192
+NN_HEIGHT = 192
+
 
 class FramePublisher:
-    def __init__(self, mqtt_client: mqtt.Client, frame_topic: str, img_width: int, img_height: int):
+    def __init__(self, mqtt_client: mqtt.Client, frame_topic: str, objects_topic: str, objects_threshold: float,
+                 img_width: int, img_height: int):
         self._mqtt_client = mqtt_client
         self._frame_topic = frame_topic
+        self._objects_topic = objects_topic
+        self._objects_threshold = objects_threshold
         self._img_width = img_width
         self._img_height = img_height
         self._pipeline = self._configure_pipeline()
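
NN_PATH points at a compiled MyriadX blob that is not included in this diff. A hedged sketch of one way such a blob can be fetched (the blobconverter package, the mobile_object_localizer_192x192 zoo name and the target path are assumptions based on the model referenced above, not something this commit defines):

import shutil

import blobconverter  # pip install blobconverter

# Assumed zoo name for the mobile object localizer; 6 shaves to match the
# "_6shave" suffix used in NN_PATH.
blob = blobconverter.from_zoo(name="mobile_object_localizer_192x192",
                              zoo_type="depthai",
                              shaves=6)
shutil.copy(blob, "/models/mobile_object_localizer_192x192_openvino_2021.4_6shave.blob")
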
@@ -22,11 +30,30 @@ class FramePublisher:
         logger.info("configure pipeline")
         pipeline = dai.Pipeline()
 
+        pipeline.setOpenVINOVersion(version=dai.OpenVINO.VERSION_2021_4)
+
+        # Define a neural network that will make predictions based on the source frames
+        detection_nn = pipeline.create(dai.node.NeuralNetwork)
+        detection_nn.setBlobPath(NN_PATH)
+        detection_nn.setNumPoolFrames(4)
+        detection_nn.input.setBlocking(False)
+        detection_nn.setNumInferenceThreads(2)
+
+        xout_nn = pipeline.create(dai.node.XLinkOut)
+        xout_nn.setStreamName("nn")
+        xout_nn.input.setBlocking(False)
+
+        # Resize image
+        manip = pipeline.create(dai.node.ImageManip)
+        manip.initialConfig.setResize(NN_WIDTH, NN_HEIGHT)
+        manip.initialConfig.setFrameType(dai.ImgFrame.Type.RGB888p)
+        manip.initialConfig.setKeepAspectRatio(False)
+
         cam_rgb = pipeline.create(dai.node.ColorCamera)
         xout_rgb = pipeline.create(dai.node.XLinkOut)
 
         xout_rgb.setStreamName("rgb")
 
         # Properties
         cam_rgb.setBoardSocket(dai.CameraBoardSocket.RGB)
         cam_rgb.setPreviewSize(width=self._img_width, height=self._img_height)
@@ -34,8 +61,14 @@ class FramePublisher:
         cam_rgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.RGB)
         cam_rgb.setFps(30)
 
         # Linking
+        # Link preview to manip and manip to nn
+        cam_rgb.preview.link(manip.inputImage)
+        manip.out.link(detection_nn.input)
+
+        # Linking to output
         cam_rgb.preview.link(xout_rgb.input)
+        detection_nn.out.link(xout_nn.input)
 
         logger.info("pipeline configured")
         return pipeline
@@ -51,7 +84,8 @@ class FramePublisher:
             device.startPipeline()
             # Queues
             queue_size = 4
-            q_rgb = device.getOutputQueue("rgb", maxSize=queue_size, blocking=False)
+            q_rgb = device.getOutputQueue(name="rgb", maxSize=queue_size, blocking=False)
+            q_nn = device.getOutputQueue(name="nn", maxSize=queue_size, blocking=False)
 
             while True:
                 try:
@@ -76,5 +110,42 @@ class FramePublisher:
                                               qos=0,
                                               retain=False)
 
+                    in_nn = q_nn.get()
+
+                    # get outputs
+                    detection_boxes = np.array(in_nn.getLayerFp16("ExpandDims")).reshape((100, 4))
+                    detection_scores = np.array(in_nn.getLayerFp16("ExpandDims_2")).reshape((100,))
+
+                    # keep only boxes whose score is above the threshold
+                    mask = detection_scores >= self._objects_threshold
+                    boxes = detection_boxes[mask]
+                    scores = detection_scores[mask]
+
+                    if boxes.shape[0] > 0:
+                        objects_msg = events.events_pb2.ObjectsMessage()
+                        objs = []
+                        for i in range(boxes.shape[0]):
+                            bbox = boxes[i]
+                            logger.debug("new object detected: %s", str(bbox))
+                            o = events.events_pb2.Object()
+                            o.type = events.events_pb2.TypeObject.ANY
+                            o.top = bbox[0].astype(float)
+                            o.right = bbox[1].astype(float)
+                            o.bottom = bbox[2].astype(float)
+                            o.left = bbox[3].astype(float)
+                            o.confidence = scores[i].astype(float)
+                            objs.append(o)
+                        objects_msg.objects.extend(objs)
+
+                        objects_msg.frame_ref.name = frame_msg.id.name
+                        objects_msg.frame_ref.id = frame_msg.id.id
+                        objects_msg.frame_ref.created_at.FromDatetime(now)
+
+                        logger.debug("publish object event to %s", self._objects_topic)
+                        self._mqtt_client.publish(topic=self._objects_topic,
+                                                  payload=objects_msg.SerializeToString(),
+                                                  qos=0,
+                                                  retain=False)
+
                 except Exception as e:
                     logger.exception("unexpected error: %s", str(e))
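
On the consuming side, the published ObjectsMessage can be decoded straight from the MQTT payload. A minimal subscriber sketch, assuming a paho-mqtt 1.x client, a broker on localhost and the same generated events.events_pb2 module used above; the box coordinates are published as the raw network outputs, which for this model are normalized to [0, 1], so they are scaled back to pixels here (160x120 matches the publisher defaults above):

import paho.mqtt.client as mqtt

import events.events_pb2

IMG_WIDTH, IMG_HEIGHT = 160, 120  # assumed to match the publisher's --image-width/--image-height


def on_message(client, userdata, msg):
    objects_msg = events.events_pb2.ObjectsMessage()
    objects_msg.ParseFromString(msg.payload)
    for obj in objects_msg.objects:
        # Scale the normalized coordinates back to pixel values.
        top, bottom = obj.top * IMG_HEIGHT, obj.bottom * IMG_HEIGHT
        left, right = obj.left * IMG_WIDTH, obj.right * IMG_WIDTH
        print(f"object ({obj.confidence:.2f}): ({left:.0f}, {top:.0f}) -> ({right:.0f}, {bottom:.0f})")


client = mqtt.Client()
client.on_message = on_message
client.connect("localhost", 1883)
client.subscribe("/objects")
client.loop_forever()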