feat(object-detection): implement object detection
see https://github.com/luxonis/depthai-experiments/tree/master/gen2-mobile-object-localizer
@@ -2,8 +2,10 @@
 Publish data from oak-lite device
 
 Usage: rc-oak-camera [-u USERNAME | --mqtt-username=USERNAME] [--mqtt-password=PASSWORD] [--mqtt-broker=HOSTNAME] \
-    [--mqtt-topic-robocar-oak-camera="TOPIC_CAMERA"] [--mqtt-client-id=CLIENT_ID] \
-    [-H IMG_HEIGHT | --image-height=IMG_HEIGHT] [-W IMG_WIDTH | --image-width=IMG_WIDTH]
+    [--mqtt-topic-robocar-oak-camera="TOPIC_CAMERA"] [--mqtt-topic-robocar-objects="TOPIC_OBJECTS"] \
+    [--mqtt-client-id=CLIENT_ID] \
+    [-H IMG_HEIGHT | --image-height=IMG_HEIGHT] [-W IMG_WIDTH | --image-width=IMG_WIDTH] \
+    [-t OBJECTS_THRESHOLD | --objects-threshold=OBJECTS_THRESHOLD]
 
 Options:
 -h --help                                                    Show this screen.
@@ -12,8 +14,10 @@ Options:
 -b HOSTNAME --mqtt-broker=HOSTNAME                           MQTT broker host
 -C CLIENT_ID --mqtt-client-id=CLIENT_ID                      MQTT client id
 -c TOPIC_CAMERA --mqtt-topic-robocar-oak-camera=TOPIC_CAMERA MQTT topic where to publish robocar-oak-camera frames
+-o TOPIC_OBJECTS --mqtt-topic-robocar-objects=TOPIC_OBJECTS  MQTT topic where to publish object detection results
 -H IMG_HEIGHT --image-height=IMG_HEIGHT                      image height
 -W IMG_WIDTH --image-width=IMG_WIDTH                         image width
+-t OBJECTS_THRESHOLD --objects-threshold=OBJECTS_THRESHOLD   threshold used to filter detected objects
 """
 import logging
 import os
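The Usage/Options docstring above follows the docopt format, and the args["--…"] lookups later in this diff match docopt's output dict; a minimal sketch of the assumed parsing step (docopt itself is not shown in this commit):

    from docopt import docopt

    args = docopt(__doc__)
    # docopt returns raw strings, hence the float()/int() conversions applied below, e.g.:
    #   rc-oak-camera --objects-threshold=0.4  ->  args["--objects-threshold"] == "0.4"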
@@ -50,9 +54,13 @@ def execute_from_command_line():
                                          default_client_id),
     )
     frame_topic = get_default_value(args["--mqtt-topic-robocar-oak-camera"], "MQTT_TOPIC_CAMERA", "/oak/camera_rgb")
+    objects_topic = get_default_value(args["--mqtt-topic-robocar-objects"], "MQTT_TOPIC_OBJECTS", "/objects")
 
     frame_processor = cam.FramePublisher(mqtt_client=client,
                                          frame_topic=frame_topic,
+                                         objects_topic=objects_topic,
+                                         objects_threshold=float(get_default_value(args["--objects-threshold"],
+                                                                                   "OBJECTS_THRESHOLD", 0.2)),
                                          img_width=int(get_default_value(args["--image-width"], "IMAGE_WIDTH", 160)),
                                          img_height=int(get_default_value(args["--image-height"], "IMAGE_HEIGHT", 120)))
     frame_processor.run()
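get_default_value is defined elsewhere in this repo; from its call sites the precedence is CLI argument, then environment variable, then hard-coded default. A plausible sketch (an assumption for illustration, not the repo's actual helper):

    import os

    def get_default_value(arg_value, env_var: str, default):
        # CLI argument wins, then the environment variable, then the fallback
        if arg_value:
            return arg_value
        return os.environ.get(env_var, default)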
@@ -6,14 +6,22 @@ import events.events_pb2
 
 import depthai as dai
 import cv2
+import numpy as np
 
 logger = logging.getLogger(__name__)
 
+NN_PATH = "/models/mobile_object_localizer_192x192_openvino_2021.4_6shave.blob"
+NN_WIDTH = 192
+NN_HEIGHT = 192
+
 
 class FramePublisher:
-    def __init__(self, mqtt_client: mqtt.Client, frame_topic: str, img_width: int, img_height: int):
+    def __init__(self, mqtt_client: mqtt.Client, frame_topic: str, objects_topic: str, objects_threshold: float,
+                 img_width: int, img_height: int):
         self._mqtt_client = mqtt_client
         self._frame_topic = frame_topic
+        self._objects_topic = objects_topic
+        self._objects_threshold = objects_threshold
         self._img_width = img_width
         self._img_height = img_height
         self._pipeline = self._configure_pipeline()
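NN_PATH expects the mobile_object_localizer blob to already sit under /models; one way to obtain the same 6-shave blob is blobconverter (an assumption for illustration, the commit itself does not ship a download step):

    import blobconverter

    # fetch the 192x192 mobile object localizer from the DepthAI model zoo
    blob_path = blobconverter.from_zoo(name="mobile_object_localizer_192x192",
                                       zoo_type="depthai",
                                       shaves=6)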
@@ -22,11 +30,30 @@ class FramePublisher:
         logger.info("configure pipeline")
         pipeline = dai.Pipeline()
 
+        pipeline.setOpenVINOVersion(version=dai.OpenVINO.VERSION_2021_4)
+
+        # Define a neural network that will make predictions based on the source frames
+        detection_nn = pipeline.create(dai.node.NeuralNetwork)
+        detection_nn.setBlobPath(NN_PATH)
+        detection_nn.setNumPoolFrames(4)
+        detection_nn.input.setBlocking(False)
+        detection_nn.setNumInferenceThreads(2)
+
+        xout_nn = pipeline.create(dai.node.XLinkOut)
+        xout_nn.setStreamName("nn")
+        xout_nn.input.setBlocking(False)
+
+        # Resize image
+        manip = pipeline.create(dai.node.ImageManip)
+        manip.initialConfig.setResize(NN_WIDTH, NN_HEIGHT)
+        manip.initialConfig.setFrameType(dai.ImgFrame.Type.RGB888p)
+        manip.initialConfig.setKeepAspectRatio(False)
+
         cam_rgb = pipeline.create(dai.node.ColorCamera)
         xout_rgb = pipeline.create(dai.node.XLinkOut)
 
         xout_rgb.setStreamName("rgb")
 
         # Properties
         cam_rgb.setBoardSocket(dai.CameraBoardSocket.RGB)
         cam_rgb.setPreviewSize(width=self._img_width, height=self._img_height)
@@ -34,8 +61,14 @@ class FramePublisher:
         cam_rgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.RGB)
         cam_rgb.setFps(30)
 
         # Linking
+        # Link preview to manip and manip to nn
+        cam_rgb.preview.link(manip.inputImage)
+        manip.out.link(detection_nn.input)
+
+        # Linking to output
         cam_rgb.preview.link(xout_rgb.input)
+        detection_nn.out.link(xout_nn.input)
 
         logger.info("pipeline configured")
         return pipeline
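For orientation, the links above wire the following device-side graph (a sketch, not part of the commit):

    # ColorCamera.preview ──┬──> ImageManip (resize 192x192, RGB888p) ──> NeuralNetwork ──> XLinkOut "nn"
    #                       └──> XLinkOut "rgb"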
@@ -51,7 +84,8 @@ class FramePublisher:
             device.startPipeline()
             # Queues
             queue_size = 4
-            q_rgb = device.getOutputQueue("rgb", maxSize=queue_size, blocking=False)
+            q_rgb = device.getOutputQueue(name="rgb", maxSize=queue_size, blocking=False)
+            q_nn = device.getOutputQueue(name="nn", maxSize=queue_size, blocking=False)
 
             while True:
                 try:
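Note on the queues: maxSize=4 with blocking=False makes the device drop the oldest packets instead of stalling when the host falls behind, so the loop always works on recent frames. A sketch of a fully non-blocking read for comparison (tryGet() and getCvFrame() are part of the DepthAI API; this commit uses the blocking get() instead):

    in_rgb = q_rgb.tryGet()  # returns None immediately when the queue is empty
    if in_rgb is not None:
        frame = in_rgb.getCvFrame()  # BGR numpy array, ready for cv2 processing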
@@ -76,5 +110,42 @@ class FramePublisher:
                                               qos=0,
                                               retain=False)
+
+                    in_nn = q_nn.get()
+
+                    # get outputs: 100 normalized [ymin, xmin, ymax, xmax] boxes and their scores
+                    detection_boxes = np.array(in_nn.getLayerFp16("ExpandDims")).reshape((100, 4))
+                    detection_scores = np.array(in_nn.getLayerFp16("ExpandDims_2")).reshape((100,))
+
+                    # keep only detections with a score above the threshold
+                    mask = detection_scores >= self._objects_threshold
+                    boxes = detection_boxes[mask]
+                    scores = detection_scores[mask]
+
+                    if boxes.shape[0] > 0:
+                        objects_msg = events.events_pb2.ObjectsMessage()
+                        objs = []
+                        for i in range(boxes.shape[0]):
+                            bbox = boxes[i]
+                            logger.debug("new object detected: %s", str(bbox))
+                            o = events.events_pb2.Object()
+                            o.type = events.events_pb2.TypeObject.ANY
+                            o.top = float(bbox[0])
+                            o.left = float(bbox[1])
+                            o.bottom = float(bbox[2])
+                            o.right = float(bbox[3])
+                            o.confidence = float(scores[i])
+                            objs.append(o)
+                        objects_msg.objects.extend(objs)
+
+                        objects_msg.frame_ref.name = frame_msg.id.name
+                        objects_msg.frame_ref.id = frame_msg.id.id
+                        objects_msg.frame_ref.created_at.FromDatetime(now)
+
+                        logger.debug("publish object event to %s", self._objects_topic)
+                        self._mqtt_client.publish(topic=self._objects_topic,
+                                                  payload=objects_msg.SerializeToString(),
+                                                  qos=0,
+                                                  retain=False)
+
                 except Exception as e:
                     logger.exception("unexpected error: %s", str(e))
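To check the new objects stream end to end, a minimal subscriber might look like this (a sketch, assuming paho-mqtt and this repo's events protobuf package are installed; broker host and topic are the defaults used above):

    import paho.mqtt.client as mqtt

    import events.events_pb2


    def on_message(client, userdata, msg):
        objects_msg = events.events_pb2.ObjectsMessage()
        objects_msg.ParseFromString(msg.payload)
        for obj in objects_msg.objects:
            print(obj.confidence, obj.top, obj.left, obj.bottom, obj.right)


    client = mqtt.Client()
    client.on_message = on_message
    client.connect("localhost", 1883)  # assumption: broker reachable on localhost
    client.subscribe("/objects")       # default TOPIC_OBJECTS from this commit
    client.loop_forever()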