feat(simulator): add simulator source

This commit is contained in:
Cyrille Nofficial 2022-10-25 16:59:18 +02:00 committed by Cyrille Nofficial
parent 24e4410c25
commit 667c6903ef

View File

@ -199,8 +199,59 @@ class CameraSource(Source):
return self._xout_rgb.getStreamName() return self._xout_rgb.getStreamName()
class MqttSource(Source):
"""Image source based onto mqtt stream"""
def __init__(self, device: dai.Device, pipeline: dai.Pipeline, mqtt_host: str, mqtt_topic: str,
mqtt_port: int = 1883, mqtt_qos: int = 0):
self._mqtt_host = mqtt_host
self._mqtt_port = mqtt_port
self._client = mqtt.Client()
self._client.user_data_set({"topic": mqtt_topic, "qos": str(mqtt_qos)})
self._client.on_connect = self._on_connect
self._client.on_message = self._on_message
self._img_in = pipeline.createXLinkIn()
self._img_in.setStreamName("img_input")
self._img_out = pipeline.createXLinkOut()
self._img_out.setStreamName("img_output")
self._img_in.out.link(self._img_out.input)
self._img_in_queue = device.getInputQueue(self._img_in.getStreamName())
def run(self):
self._client.connect(host=self._mqtt_host, port=self._mqtt_port)
self._client.loop_start()
def stop(self):
self._client.loop_stop()
self._client.disconnect()
@staticmethod @staticmethod
def _on_connect(client: mqtt.Client, userdata: dict[str, str], flags, rc):
# if we lose the connection and reconnect then subscriptions will be renewed.
client.subscribe(topic=userdata["topic"], qos=int(userdata["qos"]))
def _on_message(self, _: mqtt.Client, user_data: dict[str, str], msg: mqtt.MQTTMessage):
frame_msg = events.events_pb2.FrameMessage()
frame_msg.ParseFromString(msg.payload)
frame = np.asarray(frame_msg.frame, dtype="uint8")
frame = cv2.imdecode(frame, cv2.IMREAD_COLOR)
nn_data = dai.NNData()
nn_data.setLayer("data", _to_planar(frame, frame.shape()))
self._img_in_queue.send(nn_data)
def get_stream_name(self) -> str:
return self._img_out.getStreamName()
def link_preview(self, input_node: dai.Node.Input):
self._img_in.out.link(input_node)
def _to_planar(arr: np.ndarray, shape: tuple) -> list:
return [val for channel in cv2.resize(arr, shape).transpose(2, 0, 1) for y_col in channel for val in y_col]
class PipelineController: class PipelineController: