In tests: Add use_videos=False by default, Create mp4 file if True, then fix test_datasets and test_aggregate (all passing)

This commit is contained in:
Remi Cadene
2025-05-12 15:37:02 +02:00
committed by Michel Aractingi
parent 220997ff47
commit 58795d72c8
5 changed files with 71 additions and 25 deletions
+3 -1
View File
@@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import (
concat_video_files, concat_video_files,
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
get_video_size_in_mb, get_video_size_in_mb,
safe_write_dataframe_to_parquet,
update_chunk_file_indices, update_chunk_file_indices,
write_info, write_info,
write_stats, write_stats,
@@ -97,6 +98,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
fps, robot_type, features = validate_all_metadata(all_metadata) fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"] video_keys = [key for key in features if features[key]["dtype"] == "video"]
image_keys = [key for key in features if features[key]["dtype"] == "image"]
# Create resulting dataset folder # Create resulting dataset folder
aggr_meta = LeRobotDatasetMetadata.create( aggr_meta = LeRobotDatasetMetadata.create(
@@ -259,7 +261,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
# Update the existing parquet file with new rows # Update the existing parquet file with new rows
aggr_df = pd.read_parquet(aggr_path) aggr_df = pd.read_parquet(aggr_path)
df = pd.concat([aggr_df, df], ignore_index=True) df = pd.concat([aggr_df, df], ignore_index=True)
df.to_parquet(aggr_path) safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
num_episodes += meta.total_episodes num_episodes += meta.total_episodes
num_frames += meta.total_frames num_frames += meta.total_frames
+1 -1
View File
@@ -71,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
root_init = tmp_path / "init" root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init) dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
init_attr = set(vars(dataset_init).keys()) init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys()) create_attr = set(vars(dataset_create).keys())
+2 -2
View File
@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
}, },
} }
DUMMY_CAMERA_FEATURES = { DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
} }
DEFAULT_FPS = 30 DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = { DUMMY_VIDEO_INFO = {
+39 -5
View File
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random import random
import shutil
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Protocol from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.datasets.utils import (
get_hf_features_from_features, get_hf_features_from_features,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.datasets.video_utils import encode_video_frames
from tests.fixtures.constants import ( from tests.fixtures.constants import (
DEFAULT_FPS, DEFAULT_FPS,
DUMMY_CAMERA_FEATURES, DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
def _create_features( def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True, use_videos: bool = False,
) -> dict: ) -> dict:
if use_videos: if use_videos:
camera_ft = { camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
video_path: str = DEFAULT_VIDEO_PATH, video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True, use_videos: bool = False,
) -> dict: ) -> dict:
features = features_factory(motor_features, camera_features, use_videos) features = features_factory(motor_features, camera_features, use_videos)
return { return {
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
return _create_episodes return _create_episodes
@pytest.fixture(scope="session")
def create_videos(info_factory, img_array_factory):
    """Provide a factory that writes one dummy mp4 file per video feature of a dataset."""

    def _create_video_directory(
        root: Path,
        info: dict | None = None,
        total_episodes: int = 3,
        total_frames: int = 150,
        total_tasks: int = 1,
    ):
        """Encode a dummy video for every feature of `info` whose dtype is "video".

        Args:
            root: Dataset root directory; videos are written below it following
                `DEFAULT_VIDEO_PATH` (chunk 0, file 0).
            info: Dataset info dict. When None, one is built with `info_factory`
                from the three `total_*` arguments.
            total_episodes: Forwarded to `info_factory` when `info` is None.
            total_frames: Forwarded to `info_factory` when `info` is None.
            total_tasks: Forwarded to `info_factory` when `info` is None.
        """
        if info is None:
            info = info_factory(
                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
            )

        video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
        # Scratch directory for the PNG frames; recreated and removed for each video key.
        tmp_dir = root / "tmp_images"

        for key, ft in video_feats.items():
            # The camera features declare names ["height", "width", "channels"],
            # so height is shape[0] and width is shape[1]. Assumes any
            # caller-provided `info` follows the same layout.
            height, width = ft["shape"][0], ft["shape"][1]

            tmp_dir.mkdir(parents=True, exist_ok=True)
            try:
                # Create and save one image per frame of the dataset.
                for frame_index in range(info["total_frames"]):
                    img = img_array_factory(height=height, width=width)
                    pil_img = PIL.Image.fromarray(img)
                    pil_img.save(tmp_dir / f"frame-{frame_index:06d}.png")

                video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
                encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
            finally:
                # Always drop the temporary frames, even when encoding fails, so
                # no stale PNGs are left inside the dataset root.
                shutil.rmtree(tmp_dir)

    return _create_video_directory
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory): def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset( def _create_hf_dataset(
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for key, ft in features.items(): for key, ft in features.items():
if ft["dtype"] == "image": if ft["dtype"] == "image":
robot_cols[key] = [ robot_cols[key] = [
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0]) img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
for _ in range(len(index_col)) for _ in range(len(index_col))
] ]
elif ft["shape"][0] > 1 and ft["dtype"] != "video": elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -437,6 +471,7 @@ def lerobot_dataset_factory(
hf_dataset: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None,
**kwargs, **kwargs,
) -> LeRobotDataset: ) -> LeRobotDataset:
# Instantiate objects
if info is None: if info is None:
info = info_factory( info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -446,19 +481,18 @@ def lerobot_dataset_factory(
if tasks is None: if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if episodes_metadata is None: if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes_metadata = episodes_factory( episodes_metadata = episodes_factory(
features=info["features"], features=info["features"],
fps=info["fps"], fps=info["fps"],
total_episodes=info["total_episodes"], total_episodes=info["total_episodes"],
total_frames=info["total_frames"], total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks, tasks=tasks,
multi_task=multi_task, multi_task=multi_task,
) )
if not hf_dataset: if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"]) hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
# Write data on disk
mock_snapshot_download = mock_snapshot_download_factory( mock_snapshot_download = mock_snapshot_download_factory(
info=info, info=info,
stats=stats, stats=stats,
+26 -16
View File
@@ -22,6 +22,7 @@ from lerobot.common.datasets.utils import (
DEFAULT_DATA_PATH, DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH, DEFAULT_EPISODES_PATH,
DEFAULT_TASKS_PATH, DEFAULT_TASKS_PATH,
DEFAULT_VIDEO_PATH,
INFO_PATH, INFO_PATH,
STATS_PATH, STATS_PATH,
) )
@@ -40,6 +41,7 @@ def mock_snapshot_download_factory(
create_episodes, create_episodes,
hf_dataset_factory, hf_dataset_factory,
create_hf_dataset, create_hf_dataset,
create_videos,
): ):
""" """
This factory allows to patch snapshot_download such that when called, it will create expected files rather This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -91,40 +93,48 @@ def mock_snapshot_download_factory(
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0), DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
] ]
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
for key in video_keys:
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
allowed_files = filter_repo_objects( allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
) )
has_info = False request_info = False
has_tasks = False request_tasks = False
has_episodes = False request_episodes = False
has_stats = False request_stats = False
has_data = False request_data = False
request_videos = False
for rel_path in allowed_files: for rel_path in allowed_files:
if rel_path.startswith("meta/info.json"): if rel_path.startswith("meta/info.json"):
has_info = True request_info = True
elif rel_path.startswith("meta/stats"): elif rel_path.startswith("meta/stats"):
has_stats = True request_stats = True
elif rel_path.startswith("meta/tasks"): elif rel_path.startswith("meta/tasks"):
has_tasks = True request_tasks = True
elif rel_path.startswith("meta/episodes"): elif rel_path.startswith("meta/episodes"):
has_episodes = True request_episodes = True
elif rel_path.startswith("data/"): elif rel_path.startswith("data/"):
has_data = True request_data = True
elif rel_path.startswith("videos/"):
request_videos = True
else: else:
raise ValueError(f"{rel_path} not supported.") raise ValueError(f"{rel_path} not supported.")
if has_info: if request_info:
create_info(local_dir, info) create_info(local_dir, info)
if has_stats: if request_stats:
create_stats(local_dir, stats) create_stats(local_dir, stats)
if has_tasks: if request_tasks:
create_tasks(local_dir, tasks) create_tasks(local_dir, tasks)
if has_episodes: if request_episodes:
create_episodes(local_dir, episodes) create_episodes(local_dir, episodes)
# TODO(rcadene): create_videos? if request_data:
if has_data:
create_hf_dataset(local_dir, hf_dataset) create_hf_dataset(local_dir, hf_dataset)
if request_videos:
create_videos(root=local_dir, info=info)
return str(local_dir) return str(local_dir)