refactor: split pipeline code
This commit is contained in:
parent
b50b54be34
commit
9918c8c413
@ -468,7 +468,7 @@ max-returns=6
|
||||
max-statements=50
|
||||
|
||||
# Minimum number of public methods for a class (see R0903).
|
||||
min-public-methods=2
|
||||
min-public-methods=1
|
||||
|
||||
|
||||
[STRING]
|
||||
|
@ -75,13 +75,16 @@ def execute_from_command_line() -> None:
|
||||
password=args.mqtt_password,
|
||||
client_id=args.mqtt_client_id,
|
||||
)
|
||||
frame_processor = cam.FramePublisher(mqtt_client=client,
|
||||
frame_topic=args.mqtt_topic_robocar_oak_camera,
|
||||
objects_topic=args.mqtt_topic_robocar_objects,
|
||||
objects_threshold=args.objects_threshold,
|
||||
img_width=args.image_width,
|
||||
img_height=args.image_height)
|
||||
frame_processor.run()
|
||||
frame_processor = cam.FrameProcessor(mqtt_client=client, frame_topic=args.mqtt_topic_robocar_oak_camera)
|
||||
object_processor = cam.ObjectProcessor(mqtt_client=client,
|
||||
objects_topic=args.mqtt_topic_robocar_objects,
|
||||
objects_threshold=args.objects_threshold)
|
||||
|
||||
pipeline_controller = cam.PipelineController(img_width=args.image_width,
|
||||
img_height=args.image_height,
|
||||
frame_processor=frame_processor,
|
||||
object_processor=object_processor)
|
||||
pipeline_controller.run()
|
||||
|
||||
|
||||
def _get_env_value(env_var: str, default_value: str) -> str:
|
||||
|
@ -3,6 +3,7 @@ Camera event loop
|
||||
"""
|
||||
import datetime
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import cv2
|
||||
import depthai as dai
|
||||
@ -18,20 +19,98 @@ NN_WIDTH = 192
|
||||
NN_HEIGHT = 192
|
||||
|
||||
|
||||
class FramePublisher:
|
||||
class ObjectProcessor:
|
||||
"""
|
||||
Camera controller that publish events from camera
|
||||
Processor for Object detection
|
||||
"""
|
||||
|
||||
def __init__(self, mqtt_client: mqtt.Client, frame_topic: str, objects_topic: str, objects_threshold: float,
|
||||
img_width: int, img_height: int):
|
||||
def __init__(self, mqtt_client: mqtt.Client, objects_topic: str, objects_threshold: float):
|
||||
self._mqtt_client = mqtt_client
|
||||
self._frame_topic = frame_topic
|
||||
self._objects_topic = objects_topic
|
||||
self._objects_threshold = objects_threshold
|
||||
|
||||
def process(self, in_nn: dai.NNData, frame_ref, frame_date: datetime.datetime) -> None:
|
||||
"""
|
||||
Parse and publish result of NeuralNetwork result
|
||||
:param in_nn: NeuralNetwork result read from device
|
||||
:param frame_ref: Id of the frame where objects are been detected
|
||||
:param frame_date: Datetime of the frame used for detection
|
||||
:return:
|
||||
"""
|
||||
detection_boxes = np.array(in_nn.getLayerFp16("ExpandDims")).reshape((100, 4))
|
||||
detection_scores = np.array(in_nn.getLayerFp16("ExpandDims_2")).reshape((100,))
|
||||
# keep boxes bigger than threshold
|
||||
mask = detection_scores >= self._objects_threshold
|
||||
boxes = detection_boxes[mask]
|
||||
scores = detection_scores[mask]
|
||||
|
||||
if boxes.shape[0] > 0:
|
||||
self._publish_objects(boxes, frame_ref, frame_date, scores)
|
||||
|
||||
def _publish_objects(self, boxes: np.array, frame_ref, now: datetime.datetime, scores: np.array) -> None:
|
||||
objects_msg = events.events_pb2.ObjectsMessage()
|
||||
objs = []
|
||||
for i in range(boxes.shape[0]):
|
||||
logger.debug("new object detected: %s", str(boxes[i]))
|
||||
objs.append(_bbox_to_object(boxes[i], scores[i].astype(float)))
|
||||
objects_msg.objects.extend(objs)
|
||||
objects_msg.frame_ref.name = frame_ref.name
|
||||
objects_msg.frame_ref.id = frame_ref.id
|
||||
objects_msg.frame_ref.created_at.FromDatetime(now)
|
||||
logger.debug("publish object event to %s", self._objects_topic)
|
||||
self._mqtt_client.publish(topic=self._objects_topic,
|
||||
payload=objects_msg.SerializeToString(),
|
||||
qos=0,
|
||||
retain=False)
|
||||
|
||||
|
||||
class FrameProcessor:
|
||||
"""
|
||||
Processor for camera frames
|
||||
"""
|
||||
|
||||
def __init__(self, mqtt_client: mqtt.Client, frame_topic: str):
|
||||
self._mqtt_client = mqtt_client
|
||||
self._frame_topic = frame_topic
|
||||
|
||||
def process(self, img: dai.ImgFrame) -> (typing.Any, datetime.datetime):
|
||||
"""
|
||||
Publish camera frames
|
||||
:param img:
|
||||
:return:
|
||||
id frame
|
||||
frame creation datetime
|
||||
"""
|
||||
im_resize = img.getCvFrame()
|
||||
is_success, im_buf_arr = cv2.imencode(".jpg", im_resize)
|
||||
byte_im = im_buf_arr.tobytes()
|
||||
|
||||
now = datetime.datetime.now()
|
||||
frame_msg = events.events_pb2.FrameMessage()
|
||||
frame_msg.id.name = "robocar-oak-camera-oak"
|
||||
frame_msg.id.id = str(int(now.timestamp() * 1000))
|
||||
frame_msg.id.created_at.FromDatetime(now)
|
||||
frame_msg.frame = byte_im
|
||||
logger.debug("publish frame event to %s", self._frame_topic)
|
||||
self._mqtt_client.publish(topic=self._frame_topic,
|
||||
payload=frame_msg.SerializeToString(),
|
||||
qos=0,
|
||||
retain=False)
|
||||
return frame_msg.id, now
|
||||
|
||||
|
||||
class PipelineController:
|
||||
"""
|
||||
Pipeline controller that drive camera device
|
||||
"""
|
||||
|
||||
def __init__(self, img_width: int, img_height: int, frame_processor: FrameProcessor,
|
||||
object_processor: ObjectProcessor):
|
||||
self._img_width = img_width
|
||||
self._img_height = img_height
|
||||
self._pipeline = self._configure_pipeline()
|
||||
self._frame_processor = frame_processor
|
||||
self._object_processor = object_processor
|
||||
self._stop = False
|
||||
|
||||
def _configure_pipeline(self) -> dai.Pipeline:
|
||||
@ -40,16 +119,8 @@ class FramePublisher:
|
||||
|
||||
pipeline.setOpenVINOVersion(version=dai.OpenVINO.VERSION_2021_4)
|
||||
|
||||
# Define a neural network that will make predictions based on the source frames
|
||||
detection_nn = pipeline.create(dai.node.NeuralNetwork)
|
||||
detection_nn.setBlobPath(NN_PATH)
|
||||
detection_nn.setNumPoolFrames(4)
|
||||
detection_nn.input.setBlocking(False)
|
||||
detection_nn.setNumInferenceThreads(2)
|
||||
|
||||
xout_nn = pipeline.create(dai.node.XLinkOut)
|
||||
xout_nn.setStreamName("nn")
|
||||
xout_nn.input.setBlocking(False)
|
||||
detection_nn = self._configure_detection_nn(pipeline)
|
||||
xout_nn = self._configure_xout_nn(pipeline)
|
||||
|
||||
# Resize image
|
||||
manip = pipeline.create(dai.node.ImageManip)
|
||||
@ -79,6 +150,23 @@ class FramePublisher:
|
||||
logger.info("pipeline configured")
|
||||
return pipeline
|
||||
|
||||
@staticmethod
|
||||
def _configure_xout_nn(pipeline: dai.Pipeline) -> dai.node.XLinkOut:
|
||||
xout_nn = pipeline.create(dai.node.XLinkOut)
|
||||
xout_nn.setStreamName("nn")
|
||||
xout_nn.input.setBlocking(False)
|
||||
return xout_nn
|
||||
|
||||
@staticmethod
|
||||
def _configure_detection_nn(pipeline: dai.Pipeline) -> dai.node.NeuralNetwork:
|
||||
# Define a neural network that will make predictions based on the source frames
|
||||
detection_nn = pipeline.create(dai.node.NeuralNetwork)
|
||||
detection_nn.setBlobPath(NN_PATH)
|
||||
detection_nn.setNumPoolFrames(4)
|
||||
detection_nn.input.setBlocking(False)
|
||||
detection_nn.setNumInferenceThreads(2)
|
||||
return detection_nn
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Start event loop
|
||||
@ -89,7 +177,6 @@ class FramePublisher:
|
||||
logger.info('MxId: %s', device.getDeviceInfo().getMxId())
|
||||
logger.info('USB speed: %s', device.getUsbSpeed())
|
||||
logger.info('Connected cameras: %s', device.getConnectedCameras())
|
||||
|
||||
logger.info("output queues found: %s", device.getOutputQueueNames())
|
||||
|
||||
device.startPipeline()
|
||||
@ -113,53 +200,11 @@ class FramePublisher:
|
||||
|
||||
# Wait for frame
|
||||
in_rgb = _get_as_imgframe(q_rgb) # blocking call, will wait until a new data has arrived
|
||||
frame_msg, now = self._read_and_publish_frame(in_rgb)
|
||||
frame_msg, now = self._frame_processor.process(in_rgb)
|
||||
|
||||
# Read NN result
|
||||
in_nn = _get_as_nndata(q_nn)
|
||||
# get outputs
|
||||
detection_boxes = np.array(in_nn.getLayerFp16("ExpandDims")).reshape((100, 4))
|
||||
detection_scores = np.array(in_nn.getLayerFp16("ExpandDims_2")).reshape((100,))
|
||||
# keep boxes bigger than threshold
|
||||
mask = detection_scores >= self._objects_threshold
|
||||
boxes = detection_boxes[mask]
|
||||
scores = detection_scores[mask]
|
||||
|
||||
if boxes.shape[0] > 0:
|
||||
self._publish_objects(boxes, frame_msg, now, scores)
|
||||
|
||||
def _read_and_publish_frame(self, in_rgb: dai.ImgFrame) -> (events.events_pb2.FrameMessage, datetime.datetime):
|
||||
im_resize = in_rgb.getCvFrame()
|
||||
is_success, im_buf_arr = cv2.imencode(".jpg", im_resize)
|
||||
byte_im = im_buf_arr.tobytes()
|
||||
now = datetime.datetime.now()
|
||||
frame_msg = events.events_pb2.FrameMessage()
|
||||
frame_msg.id.name = "robocar-oak-camera-oak"
|
||||
frame_msg.id.id = str(int(now.timestamp() * 1000))
|
||||
frame_msg.id.created_at.FromDatetime(now)
|
||||
frame_msg.frame = byte_im
|
||||
logger.debug("publish frame event to %s", self._frame_topic)
|
||||
self._mqtt_client.publish(topic=self._frame_topic,
|
||||
payload=frame_msg.SerializeToString(),
|
||||
qos=0,
|
||||
retain=False)
|
||||
return frame_msg, now
|
||||
|
||||
def _publish_objects(self, boxes, frame_msg, now, scores):
|
||||
objects_msg = events.events_pb2.ObjectsMessage()
|
||||
objs = []
|
||||
for i in range(boxes.shape[0]):
|
||||
logger.debug("new object detected: %s", str(boxes[i]))
|
||||
objs.append(_bbox_to_object(boxes[i], scores[i].astype(float)))
|
||||
objects_msg.objects.extend(objs)
|
||||
objects_msg.frame_ref.name = frame_msg.id.name
|
||||
objects_msg.frame_ref.id = frame_msg.id.id
|
||||
objects_msg.frame_ref.created_at.FromDatetime(now)
|
||||
logger.debug("publish object event to %s", self._frame_topic)
|
||||
self._mqtt_client.publish(topic=self._objects_topic,
|
||||
payload=objects_msg.SerializeToString(),
|
||||
qos=0,
|
||||
retain=False)
|
||||
self._object_processor.process(in_nn, frame_msg.id, now)
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user