refactor: split pipeline code

This commit is contained in:
Cyrille Nofficial 2022-10-20 16:57:33 +02:00 committed by Cyrille Nofficial
parent b50b54be34
commit 9918c8c413
3 changed files with 116 additions and 68 deletions

View File

@ -468,7 +468,7 @@ max-returns=6
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
min-public-methods=1
[STRING]

View File

@ -75,13 +75,16 @@ def execute_from_command_line() -> None:
password=args.mqtt_password,
client_id=args.mqtt_client_id,
)
frame_processor = cam.FramePublisher(mqtt_client=client,
frame_topic=args.mqtt_topic_robocar_oak_camera,
frame_processor = cam.FrameProcessor(mqtt_client=client, frame_topic=args.mqtt_topic_robocar_oak_camera)
object_processor = cam.ObjectProcessor(mqtt_client=client,
objects_topic=args.mqtt_topic_robocar_objects,
objects_threshold=args.objects_threshold,
img_width=args.image_width,
img_height=args.image_height)
frame_processor.run()
objects_threshold=args.objects_threshold)
pipeline_controller = cam.PipelineController(img_width=args.image_width,
img_height=args.image_height,
frame_processor=frame_processor,
object_processor=object_processor)
pipeline_controller.run()
def _get_env_value(env_var: str, default_value: str) -> str:

View File

@ -3,6 +3,7 @@ Camera event loop
"""
import datetime
import logging
import typing
import cv2
import depthai as dai
@ -18,20 +19,98 @@ NN_WIDTH = 192
NN_HEIGHT = 192
class FramePublisher:
class ObjectProcessor:
"""
Camera controller that publish events from camera
Processor for Object detection
"""
def __init__(self, mqtt_client: mqtt.Client, frame_topic: str, objects_topic: str, objects_threshold: float,
img_width: int, img_height: int):
def __init__(self, mqtt_client: mqtt.Client, objects_topic: str, objects_threshold: float):
self._mqtt_client = mqtt_client
self._frame_topic = frame_topic
self._objects_topic = objects_topic
self._objects_threshold = objects_threshold
def process(self, in_nn: dai.NNData, frame_ref, frame_date: datetime.datetime) -> None:
"""
Parse and publish result of NeuralNetwork result
:param in_nn: NeuralNetwork result read from device
:param frame_ref: Id of the frame where objects are been detected
:param frame_date: Datetime of the frame used for detection
:return:
"""
detection_boxes = np.array(in_nn.getLayerFp16("ExpandDims")).reshape((100, 4))
detection_scores = np.array(in_nn.getLayerFp16("ExpandDims_2")).reshape((100,))
# keep boxes bigger than threshold
mask = detection_scores >= self._objects_threshold
boxes = detection_boxes[mask]
scores = detection_scores[mask]
if boxes.shape[0] > 0:
self._publish_objects(boxes, frame_ref, frame_date, scores)
def _publish_objects(self, boxes: np.array, frame_ref, now: datetime.datetime, scores: np.array) -> None:
objects_msg = events.events_pb2.ObjectsMessage()
objs = []
for i in range(boxes.shape[0]):
logger.debug("new object detected: %s", str(boxes[i]))
objs.append(_bbox_to_object(boxes[i], scores[i].astype(float)))
objects_msg.objects.extend(objs)
objects_msg.frame_ref.name = frame_ref.name
objects_msg.frame_ref.id = frame_ref.id
objects_msg.frame_ref.created_at.FromDatetime(now)
logger.debug("publish object event to %s", self._objects_topic)
self._mqtt_client.publish(topic=self._objects_topic,
payload=objects_msg.SerializeToString(),
qos=0,
retain=False)
class FrameProcessor:
"""
Processor for camera frames
"""
def __init__(self, mqtt_client: mqtt.Client, frame_topic: str):
self._mqtt_client = mqtt_client
self._frame_topic = frame_topic
def process(self, img: dai.ImgFrame) -> (typing.Any, datetime.datetime):
"""
Publish camera frames
:param img:
:return:
id frame
frame creation datetime
"""
im_resize = img.getCvFrame()
is_success, im_buf_arr = cv2.imencode(".jpg", im_resize)
byte_im = im_buf_arr.tobytes()
now = datetime.datetime.now()
frame_msg = events.events_pb2.FrameMessage()
frame_msg.id.name = "robocar-oak-camera-oak"
frame_msg.id.id = str(int(now.timestamp() * 1000))
frame_msg.id.created_at.FromDatetime(now)
frame_msg.frame = byte_im
logger.debug("publish frame event to %s", self._frame_topic)
self._mqtt_client.publish(topic=self._frame_topic,
payload=frame_msg.SerializeToString(),
qos=0,
retain=False)
return frame_msg.id, now
class PipelineController:
"""
Pipeline controller that drive camera device
"""
def __init__(self, img_width: int, img_height: int, frame_processor: FrameProcessor,
object_processor: ObjectProcessor):
self._img_width = img_width
self._img_height = img_height
self._pipeline = self._configure_pipeline()
self._frame_processor = frame_processor
self._object_processor = object_processor
self._stop = False
def _configure_pipeline(self) -> dai.Pipeline:
@ -40,16 +119,8 @@ class FramePublisher:
pipeline.setOpenVINOVersion(version=dai.OpenVINO.VERSION_2021_4)
# Define a neural network that will make predictions based on the source frames
detection_nn = pipeline.create(dai.node.NeuralNetwork)
detection_nn.setBlobPath(NN_PATH)
detection_nn.setNumPoolFrames(4)
detection_nn.input.setBlocking(False)
detection_nn.setNumInferenceThreads(2)
xout_nn = pipeline.create(dai.node.XLinkOut)
xout_nn.setStreamName("nn")
xout_nn.input.setBlocking(False)
detection_nn = self._configure_detection_nn(pipeline)
xout_nn = self._configure_xout_nn(pipeline)
# Resize image
manip = pipeline.create(dai.node.ImageManip)
@ -79,6 +150,23 @@ class FramePublisher:
logger.info("pipeline configured")
return pipeline
@staticmethod
def _configure_xout_nn(pipeline: dai.Pipeline) -> dai.node.XLinkOut:
xout_nn = pipeline.create(dai.node.XLinkOut)
xout_nn.setStreamName("nn")
xout_nn.input.setBlocking(False)
return xout_nn
@staticmethod
def _configure_detection_nn(pipeline: dai.Pipeline) -> dai.node.NeuralNetwork:
# Define a neural network that will make predictions based on the source frames
detection_nn = pipeline.create(dai.node.NeuralNetwork)
detection_nn.setBlobPath(NN_PATH)
detection_nn.setNumPoolFrames(4)
detection_nn.input.setBlocking(False)
detection_nn.setNumInferenceThreads(2)
return detection_nn
def run(self) -> None:
"""
Start event loop
@ -89,7 +177,6 @@ class FramePublisher:
logger.info('MxId: %s', device.getDeviceInfo().getMxId())
logger.info('USB speed: %s', device.getUsbSpeed())
logger.info('Connected cameras: %s', device.getConnectedCameras())
logger.info("output queues found: %s", device.getOutputQueueNames())
device.startPipeline()
@ -113,53 +200,11 @@ class FramePublisher:
# Wait for frame
in_rgb = _get_as_imgframe(q_rgb) # blocking call, will wait until a new data has arrived
frame_msg, now = self._read_and_publish_frame(in_rgb)
frame_msg, now = self._frame_processor.process(in_rgb)
# Read NN result
in_nn = _get_as_nndata(q_nn)
# get outputs
detection_boxes = np.array(in_nn.getLayerFp16("ExpandDims")).reshape((100, 4))
detection_scores = np.array(in_nn.getLayerFp16("ExpandDims_2")).reshape((100,))
# keep boxes bigger than threshold
mask = detection_scores >= self._objects_threshold
boxes = detection_boxes[mask]
scores = detection_scores[mask]
if boxes.shape[0] > 0:
self._publish_objects(boxes, frame_msg, now, scores)
def _read_and_publish_frame(self, in_rgb: dai.ImgFrame) -> (events.events_pb2.FrameMessage, datetime.datetime):
im_resize = in_rgb.getCvFrame()
is_success, im_buf_arr = cv2.imencode(".jpg", im_resize)
byte_im = im_buf_arr.tobytes()
now = datetime.datetime.now()
frame_msg = events.events_pb2.FrameMessage()
frame_msg.id.name = "robocar-oak-camera-oak"
frame_msg.id.id = str(int(now.timestamp() * 1000))
frame_msg.id.created_at.FromDatetime(now)
frame_msg.frame = byte_im
logger.debug("publish frame event to %s", self._frame_topic)
self._mqtt_client.publish(topic=self._frame_topic,
payload=frame_msg.SerializeToString(),
qos=0,
retain=False)
return frame_msg, now
def _publish_objects(self, boxes, frame_msg, now, scores):
objects_msg = events.events_pb2.ObjectsMessage()
objs = []
for i in range(boxes.shape[0]):
logger.debug("new object detected: %s", str(boxes[i]))
objs.append(_bbox_to_object(boxes[i], scores[i].astype(float)))
objects_msg.objects.extend(objs)
objects_msg.frame_ref.name = frame_msg.id.name
objects_msg.frame_ref.id = frame_msg.id.id
objects_msg.frame_ref.created_at.FromDatetime(now)
logger.debug("publish object event to %s", self._frame_topic)
self._mqtt_client.publish(topic=self._objects_topic,
payload=objects_msg.SerializeToString(),
qos=0,
retain=False)
self._object_processor.process(in_nn, frame_msg.id, now)
def stop(self):
"""