feat: check typing with mypy

This commit is contained in:
Cyrille Nofficial 2022-10-27 09:05:00 +02:00 committed by Cyrille Nofficial
parent 7670b8b01a
commit aed8e9f8c2
7 changed files with 4539 additions and 20 deletions

View File

@ -85,9 +85,7 @@ def execute_from_command_line() -> None:
objects_topic=args.mqtt_topic_robocar_objects, objects_topic=args.mqtt_topic_robocar_objects,
objects_threshold=args.objects_threshold) objects_threshold=args.objects_threshold)
pipeline_controller = cam.PipelineController(img_width=args.image_width, pipeline_controller = cam.PipelineController(frame_processor=frame_processor,
img_height=args.image_height,
frame_processor=frame_processor,
object_processor=object_processor) object_processor=object_processor)
def sigterm_handler(): def sigterm_handler():

View File

@ -11,6 +11,7 @@ import cv2
import depthai as dai import depthai as dai
import events.events_pb2 import events.events_pb2
import numpy as np import numpy as np
import numpy.typing as npt
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,8 +48,7 @@ class ObjectProcessor:
if boxes.shape[0] > 0: if boxes.shape[0] > 0:
self._publish_objects(boxes, frame_ref, scores) self._publish_objects(boxes, frame_ref, scores)
def _publish_objects(self, boxes: np.array, frame_ref, scores: np.array) -> None: def _publish_objects(self, boxes: npt.NDArray[np.float64], frame_ref, scores: npt.NDArray[np.float64]) -> None:
objects_msg = events.events_pb2.ObjectsMessage() objects_msg = events.events_pb2.ObjectsMessage()
objs = [] objs = []
for i in range(boxes.shape[0]): for i in range(boxes.shape[0]):
@ -276,11 +276,11 @@ class MqttSource(Source):
def get_stream_name(self) -> str: def get_stream_name(self) -> str:
return self._img_out.getStreamName() return self._img_out.getStreamName()
def link_preview(self, input_node: dai.Node.Input): def link(self, input_node: dai.Node.Input):
self._img_in.out.link(input_node) self._img_in.out.link(input_node)
def _to_planar(arr: np.ndarray, shape: tuple) -> list: def _to_planar(arr: npt.NDArray[int], shape: tuple[int, int]) -> list[int]:
return [val for channel in cv2.resize(arr, shape).transpose(2, 0, 1) for y_col in channel for val in y_col] return [val for channel in cv2.resize(arr, shape).transpose(2, 0, 1) for y_col in channel for val in y_col]
@ -289,7 +289,7 @@ class PipelineController:
Pipeline controller that drive camera device Pipeline controller that drive camera device
""" """
def __init__(self, img_width: int, img_height: int, frame_processor: FrameProcessor, def __init__(self, frame_processor: FrameProcessor,
object_processor: ObjectProcessor, camera: Source, object_node: ObjectDetectionNN): object_processor: ObjectProcessor, camera: Source, object_node: ObjectDetectionNN):
self._pipeline = self._configure_pipeline() self._pipeline = self._configure_pipeline()
self._frame_processor = frame_processor self._frame_processor = frame_processor
@ -319,8 +319,8 @@ class PipelineController:
with dai.Device(pipeline=self._pipeline) as device: with dai.Device(pipeline=self._pipeline) as device:
logger.info('MxId: %s', device.getDeviceInfo().getMxId()) logger.info('MxId: %s', device.getDeviceInfo().getMxId())
logger.info('USB speed: %s', device.getUsbSpeed()) logger.info('USB speed: %s', device.getUsbSpeed())
logger.info('Connected cameras: %s', device.getConnectedCameras()) logger.info('Connected cameras: %s', str(device.getConnectedCameras()))
logger.info("output queues found: %s", device.getOutputQueueNames()) logger.info("output queues found: %s", str(device.getOutputQueueNames()))
device.startPipeline() device.startPipeline()
# Queues # Queues
@ -361,7 +361,7 @@ class PipelineController:
self._stop = True self._stop = True
def _bbox_to_object(bbox: np.array, score: float) -> events.events_pb2.Object: def _bbox_to_object(bbox: npt.NDArray[np.float64], score: float) -> events.events_pb2.Object:
obj = events.events_pb2.Object() obj = events.events_pb2.Object()
obj.type = events.events_pb2.TypeObject.ANY obj.type = events.events_pb2.TypeObject.ANY
obj.top = bbox[0].astype(float) obj.top = bbox[0].astype(float)

View File

@ -58,12 +58,12 @@ class TestObjectProcessor:
def raw_objects_one(self, mocker: pytest_mock.MockerFixture, object1: Object) -> dai.NNData: def raw_objects_one(self, mocker: pytest_mock.MockerFixture, object1: Object) -> dai.NNData:
def mock_return(name): def mock_return(name):
if name == "ExpandDims": # Detection boxes if name == "ExpandDims": # Detection boxes
boxes = [[0] * 4] * 100 boxes: list[list[float]] = [[0.] * 4] * 100
boxes[0] = [object1["top"], object1["left"], object1["bottom"], object1["right"]] boxes[0] = [object1["top"], object1["left"], object1["bottom"], object1["right"]]
return np.array(boxes) return np.array(boxes)
elif name == "ExpandDims_2": # Detection scores elif name == "ExpandDims_2": # Detection scores
scores = [0] * 100 scores: list[float] = [0.] * 100
scores[0] = object1["score"] scores[0] = object1["score"]
return scores return scores
else: else:

3627
cv2/__init__.pyi Normal file

File diff suppressed because one or more lines are too long

2
mypy.ini Normal file
View File

@ -0,0 +1,2 @@
[mypy]
plugins = numpy.typing.mypy_plugin

894
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,6 @@ authors = ["Cyrille Nofficial <cynoffic@cyrilix.fr>"]
readme = "README.md" readme = "README.md"
packages = [ packages = [
{ include = "camera" }, { include = "camera" },
{ include = "events" },
] ]
[tool.poetry.dependencies] [tool.poetry.dependencies]
@ -18,6 +17,7 @@ google = "^3.0.0"
blobconverter = "^1.3.0" blobconverter = "^1.3.0"
protobuf = "^4.21.8" protobuf = "^4.21.8"
opencv-python-headless = "^4.6.0.66" opencv-python-headless = "^4.6.0.66"
robocar-protobuf = { version = "^1.1.1", source = "robocar" }
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
@ -27,6 +27,16 @@ pytest-mock = "^3.10.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pylint = "^2.15.4" pylint = "^2.15.4"
mypy = "^0.982"
types-paho-mqtt = "^1.6.0.1"
types-protobuf = "^3.20.4.2"
[[tool.poetry.source]]
name = "robocar"
url = "https://git.cyrilix.bzh/api/packages/robocars/pypi/simple"
default = false
secondary = false
[build-system] [build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"] requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"]