fix: bad typing declaration

This commit is contained in:
Cyrille Nofficial 2023-04-19 19:56:45 +02:00
parent 013a5135ae
commit 415d39c5bb

View File

@ -187,7 +187,7 @@ def train(model_type: str, record_field: str, batch_size: int, slide_size: int,
# Save model for tensorflow using
model.save(f'/opt/ml/model/model_{record_field.replace("_", "")}_{model_type}_{str(img_width)}x{str(img_height)}h{str(horizon)}')
def representative_dataset() -> typing.Generator[typing.List[tf.float32], typing.Any, None]:
def representative_dataset() -> typing.Generator[typing.List[float], typing.Any, None]:
for d in tf.data.Dataset.from_tensor_slices(images).batch(1).take(100):
yield [tf.dtypes.cast(d, tf.float32)]