From d4fbf6ef39e4473e0dbe2403f8ada06f3e9bef05 Mon Sep 17 00:00:00 2001 From: fracapuano Date: Sat, 7 Jun 2025 00:47:11 +0200 Subject: [PATCH] add: support for videos generation in datasets --- tests/fixtures/dataset_factories.py | 81 ++++++++++++++++++++++------- 1 file changed, 63 insertions(+), 18 deletions(-) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 08bfe10a4..ac79139fb 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -69,15 +69,49 @@ def img_tensor_factory(): @pytest.fixture(scope="session") def img_array_factory(): - def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray: - if np.issubdtype(dtype, np.unsignedinteger): - # Int array in [0, 255] range - img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype) - elif np.issubdtype(dtype, np.floating): - # Float array in [0, 1] range - img_array = np.random.rand(height, width, channels).astype(dtype) + def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8, content=None) -> np.ndarray: + if content is None: + # Original random noise behavior + if np.issubdtype(dtype, np.unsignedinteger): + # Int array in [0, 255] range + img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype) + elif np.issubdtype(dtype, np.floating): + # Float array in [0, 1] range + img_array = np.random.rand(height, width, channels).astype(dtype) + else: + raise ValueError(dtype) else: - raise ValueError(dtype) + # Create image with text content using OpenCV + import cv2 + + # Create white background + img_array = np.ones((height, width, channels), dtype=np.uint8) * 255 + + # Font settings + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = max(0.5, height / 200) # Scale font with image size + font_color = (0, 0, 0) # Black text + thickness = max(1, int(height / 100)) + + # Get text size to center it + text_size = cv2.getTextSize(content, font, font_scale, thickness)[0] + text_x = (width - text_size[0]) // 2 + text_y = (height + text_size[1]) // 2 + + # Put text on image + cv2.putText(img_array, content, (text_x, text_y), font, font_scale, font_color, thickness) + + # Handle single channel case + if channels == 1: + img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY) + img_array = img_array[:, :, np.newaxis] + + # Convert to target dtype + if np.issubdtype(dtype, np.floating): + img_array = img_array.astype(dtype) / 255.0 + else: + img_array = img_array.astype(dtype) + return img_array return _create_img_array @@ -97,7 +131,7 @@ def features_factory(): def _create_features( motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, - use_videos: bool = False, + use_videos: bool = True, ) -> dict: if use_videos: camera_ft = { @@ -131,7 +165,7 @@ def info_factory(features_factory): video_path: str = DEFAULT_VIDEO_PATH, motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, - use_videos: bool = False, + use_videos: bool = True, ) -> dict: features = features_factory(motor_features, camera_features, use_videos) return { @@ -320,17 +354,20 @@ def create_videos(info_factory, img_array_factory): video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"} for key, ft in video_feats.items(): - # create and save images + # create and save images with identifiable content tmp_dir = root / "tmp_images" tmp_dir.mkdir(parents=True, exist_ok=True) for frame_index in range(info["total_frames"]): - img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0]) + content = f"{key}-{frame_index}" + img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1], content=content) pil_img = PIL.Image.fromarray(img) path = tmp_dir / f"frame-{frame_index:06d}.png" pil_img.save(path) video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0) - encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"]) + video_path.parent.mkdir(parents=True, exist_ok=True) + # Use the global fps from info, not video-specific fps which might not exist + encode_video_frames(tmp_dir, video_path, fps=info["fps"]) shutil.rmtree(tmp_dir) return _create_video_directory @@ -372,8 +409,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar for key, ft in features.items(): if ft["dtype"] == "image": robot_cols[key] = [ - img_array_factory(height=ft["shape"][1], width=ft["shape"][0]) - for _ in range(len(index_col)) + img_array_factory(height=ft["shape"][1], width=ft["shape"][0], content=f"{key}-{i}") + for i in range(len(index_col)) ] elif ft["shape"][0] > 1 and ft["dtype"] != "video": robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"]) @@ -464,6 +501,7 @@ def lerobot_dataset_factory( total_frames: int = 150, total_tasks: int = 1, multi_task: bool = False, + use_videos: bool = True, info: dict | None = None, stats: dict | None = None, tasks: pd.DataFrame | None = None, @@ -474,23 +512,30 @@ def lerobot_dataset_factory( # Instantiate objects if info is None: info = info_factory( - total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks + total_episodes=total_episodes, + total_frames=total_frames, + total_tasks=total_tasks, + use_videos=use_videos, ) if stats is None: stats = stats_factory(features=info["features"]) if tasks is None: tasks = tasks_factory(total_tasks=info["total_tasks"]) if episodes_metadata is None: + video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"] episodes_metadata = episodes_factory( features=info["features"], fps=info["fps"], total_episodes=info["total_episodes"], total_frames=info["total_frames"], + video_keys=video_keys, tasks=tasks, multi_task=multi_task, ) - if not hf_dataset: - hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"]) + if hf_dataset is None: + hf_dataset = hf_dataset_factory( + features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"] + ) # Write data on disk mock_snapshot_download = mock_snapshot_download_factory(