Compare commits

..

2 Commits

Author SHA1 Message Date
Caroline Pascal a518548a17 fix(docstrings): improving docstrings
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-06-12 20:02:59 +02:00
CarolinePascal 0344874a28 fix(images/videos): fixing aggregate_pipeline_dataset_features to avoid unwanted images features deletion when videos are not used 2026-06-12 16:17:26 +02:00
9 changed files with 32 additions and 131 deletions
+6 -21
View File
@@ -286,8 +286,6 @@ def aggregate_datasets(
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
chunk_size: int | None = None,
concatenate_videos: bool = True,
concatenate_data: bool = True,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -305,8 +303,6 @@ def aggregate_datasets(
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
"""
logging.info("Start aggregate_datasets")
@@ -355,12 +351,8 @@ def aggregate_datasets(
dst_meta.episodes = {}
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos
)
data_idx = aggregate_data(
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data
)
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
@@ -375,9 +367,7 @@ def aggregate_datasets(
logging.info("Aggregation complete.")
def aggregate_videos(
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True
):
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
"""Aggregates video chunks from a source dataset into the destination dataset.
Handles video file concatenation and rotation based on file size limits.
@@ -389,7 +379,6 @@ def aggregate_videos(
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
@@ -450,7 +439,7 @@ def aggregate_videos(
src_size = get_file_size_in_mb(src_path)
dst_size = get_file_size_in_mb(dst_path)
if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb:
if dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new file - offset is 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_key = (chunk_idx, file_idx)
@@ -488,7 +477,7 @@ def aggregate_videos(
return videos_idx
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True):
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
"""Aggregates data chunks from a source dataset into the destination dataset.
Reads source data files, updates indices to match the aggregated dataset,
@@ -504,7 +493,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
Returns:
dict: Updated data_idx with current chunk and file indices.
@@ -550,7 +538,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
contains_images=contains_images,
aggr_root=dst_meta.root,
hf_features=hf_features,
concatenate=concatenate_data,
)
# Record the mapping from source to actual destination
@@ -627,7 +614,6 @@ def append_or_create_parquet_file(
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
concatenate: bool = True,
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -644,7 +630,6 @@ def append_or_create_parquet_file(
contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one.
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -664,7 +649,7 @@ def append_or_create_parquet_file(
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if not concatenate or dst_size + src_size >= max_mb:
if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
-2
View File
@@ -59,8 +59,6 @@ class RunningQuantileStats:
batch: An array where all dimensions except the last are batch dimensions.
"""
batch = batch.reshape(-1, batch.shape[-1])
# Promote integer and low-precision inputs before computing squared statistics.
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
num_elements, vector_length = batch.shape
if self._count == 0:
-6
View File
@@ -261,8 +261,6 @@ def merge_datasets(
datasets: list[LeRobotDataset],
output_repo_id: str,
output_dir: str | Path | None = None,
concatenate_videos: bool = True,
concatenate_data: bool = True,
) -> LeRobotDataset:
"""Merge multiple LeRobotDatasets into a single dataset.
@@ -272,8 +270,6 @@ def merge_datasets(
datasets: List of LeRobotDatasets to merge.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -288,8 +284,6 @@ def merge_datasets(
aggr_repo_id=output_repo_id,
roots=roots,
aggr_root=output_dir,
concatenate_videos=concatenate_videos,
concatenate_data=concatenate_data,
)
merged_dataset = LeRobotDataset(
+5 -3
View File
@@ -70,19 +70,21 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]],
*,
use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `use_videos` and `patterns`, and finally
(image or state), filters them based on `exclude_images` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset.
Args:
pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: If False, image features are excluded.
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
exclude_images: If True, image features are dropped entirely from the output.
patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter.
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
)
# 2. Apply filtering rules.
if is_image and not use_videos:
if is_image and exclude_images:
continue
if not is_image and not should_keep(key, compiled_patterns):
continue
@@ -94,14 +94,6 @@ Merge multiple datasets from a list of local dataset paths:
--operation.repo_ids "['pusht_train', 'pusht_val']" \
--operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
Merge multiple datasets while keeping one file per source file (no video/data stitching):
lerobot-edit-dataset \
--new_repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" \
--operation.concatenate_videos false \
--operation.concatenate_data false
Remove camera feature:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
@@ -265,9 +257,6 @@ class SplitConfig(OperationConfig):
class MergeConfig(OperationConfig):
repo_ids: list[str] | None = None
roots: list[str] | None = None
# When False, keep one file per source file instead of packing into shards.
concatenate_videos: bool = True
concatenate_data: bool = True
@OperationConfig.register_subclass("remove_feature")
@@ -472,8 +461,6 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
datasets,
output_repo_id=cfg.new_repo_id,
output_dir=output_dir,
concatenate_videos=cfg.operation.concatenate_videos,
concatenate_data=cfg.operation.concatenate_data,
)
logging.info(f"Merged dataset saved to {output_dir}")
-46
View File
@@ -289,52 +289,6 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
assert_dataset_iteration_works(aggr_ds)
def test_aggregate_datasets_without_concatenation(tmp_path, lerobot_dataset_factory):
"""With concatenation disabled, each source file is kept as its own destination file."""
ds_0 = lerobot_dataset_factory(
root=tmp_path / "no_stitch_0",
repo_id=f"{DUMMY_REPO_ID}_no_stitch_0",
total_episodes=3,
total_frames=60,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "no_stitch_1",
repo_id=f"{DUMMY_REPO_ID}_no_stitch_1",
total_episodes=4,
total_frames=80,
)
aggr_root = tmp_path / "no_stitch_aggr"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_no_stitch_aggr",
aggr_root=aggr_root,
concatenate_videos=False,
concatenate_data=False,
)
with (
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(aggr_root)
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_no_stitch_aggr", root=aggr_root)
assert_episode_and_frame_counts(
aggr_ds, ds_0.num_episodes + ds_1.num_episodes, ds_0.num_frames + ds_1.num_frames
)
assert_dataset_iteration_works(aggr_ds)
assert_video_timestamps_within_bounds(aggr_ds)
# Two single-file sources stay as two files each, instead of being packed together.
assert len(list((aggr_root / "data").rglob("*.parquet"))) == 2
assert aggr_ds.meta.video_keys, "Test fixture should produce at least one video feature"
for key in aggr_ds.meta.video_keys:
assert len(list((aggr_root / "videos" / key).rglob("*.mp4"))) == 2
@pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"])
def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
tmp_path, lerobot_dataset_factory, caplog, mutation
-23
View File
@@ -83,29 +83,6 @@ def test_get_feature_stats_images():
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_uint8_images_preserves_std():
data = np.array(
[
[
[[0, 64], [128, 255]],
[[255, 128], [64, 0]],
[[32, 96], [160, 224]],
],
[
[[16, 80], [144, 240]],
[[240, 144], [80, 16]],
[[48, 112], [176, 208]],
],
],
dtype=np.uint8,
)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
expected_std = data.transpose(0, 2, 3, 1).reshape(-1, 3).std(axis=0).reshape(1, 3, 1, 1)
np.testing.assert_allclose(stats["std"], expected_std)
def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = {
"min": np.array([[1, 2, 3]]),
+21 -3
View File
@@ -2370,14 +2370,32 @@ def test_aggregate_images_when_use_videos_false():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=False, # expect "image" dtype
use_videos=False, # images kept, stored as "image" dtype
patterns=None,
)
key = f"{OBS_IMAGES}.back"
key_front = f"{OBS_IMAGES}.front"
assert key not in out
assert key_front not in out
assert key in out
assert key_front in out
assert out[key]["dtype"] == "image"
assert out[key_front]["dtype"] == "image"
assert out[key]["shape"] == initial["back"]
def test_aggregate_images_excluded():
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
initial = {"back": (480, 640, 3)}
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
exclude_images=True,
patterns=None,
)
assert f"{OBS_IMAGES}.back" not in out
assert f"{OBS_IMAGES}.front" not in out
def test_aggregate_images_when_use_videos_true():
@@ -66,20 +66,6 @@ class TestOperationTypeParsing:
with pytest.raises(ValueError, match="--new_repo_id is required for merge"):
_validate_config(cfg)
@pytest.mark.parametrize("flag", ["concatenate_videos", "concatenate_data"])
def test_merge_concatenate_flag_defaults_true(self, flag):
cfg = parse_cfg(["--new_repo_id", "test/merged", "--operation.type", "merge"])
assert isinstance(cfg.operation, MergeConfig)
assert getattr(cfg.operation, flag) is True
@pytest.mark.parametrize("flag", ["concatenate_videos", "concatenate_data"])
def test_merge_concatenate_flag_can_be_disabled(self, flag):
cfg = parse_cfg(
["--new_repo_id", "test/merged", "--operation.type", "merge", f"--operation.{flag}", "false"]
)
assert isinstance(cfg.operation, MergeConfig)
assert getattr(cfg.operation, flag) is False
def test_non_merge_requires_repo_id(self):
cfg = parse_cfg(["--operation.type", "delete_episodes"])
with pytest.raises(ValueError, match="--repo_id is required for delete_episodes"):