mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
Merge branch 'main' into feature/add-multitask-dit
This commit is contained in:
@@ -144,12 +144,18 @@ def test_async_inference_e2e(monkeypatch):
|
||||
client = RobotClient(client_config)
|
||||
assert client.start(), "Client failed initial handshake with the server"
|
||||
|
||||
# Track action chunks received without modifying RobotClient
|
||||
action_chunks_received = {"count": 0}
|
||||
# Track action chunks received and verify device type
|
||||
action_chunks_received = {"count": 0, "actions_on_cpu": True}
|
||||
original_aggregate = client._aggregate_action_queues
|
||||
|
||||
def counting_aggregate(*args, **kwargs):
|
||||
action_chunks_received["count"] += 1
|
||||
# Check that all received actions are on CPU
|
||||
if args:
|
||||
for timed_action in args[0]: # args[0] is the list of TimedAction
|
||||
action_tensor = timed_action.get_action()
|
||||
if action_tensor.device.type != "cpu":
|
||||
action_chunks_received["actions_on_cpu"] = False
|
||||
return original_aggregate(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
@@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert key in item, f"Video key {key} missing from item {i}"
|
||||
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||
|
||||
|
||||
def assert_image_schema_preserved(aggr_ds):
|
||||
"""Test that HuggingFace Image feature schema is preserved in aggregated parquet files.
|
||||
|
||||
This verifies the fix for a bug where image columns were written with a generic
|
||||
struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
|
||||
the proper Image() feature type, causing HuggingFace Hub viewer to display
|
||||
raw dict objects instead of image thumbnails.
|
||||
"""
|
||||
image_keys = aggr_ds.meta.image_keys
|
||||
if not image_keys:
|
||||
return
|
||||
|
||||
# Check that parquet files have proper Image schema
|
||||
data_dir = aggr_ds.root / "data"
|
||||
parquet_files = list(data_dir.rglob("*.parquet"))
|
||||
assert len(parquet_files) > 0, "No parquet files found in aggregated dataset"
|
||||
|
||||
for parquet_file in parquet_files:
|
||||
# Load with HuggingFace datasets to check schema
|
||||
ds = datasets.Dataset.from_parquet(str(parquet_file))
|
||||
|
||||
for image_key in image_keys:
|
||||
feature = ds.features.get(image_key)
|
||||
assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
|
||||
assert isinstance(feature, datasets.Image), (
|
||||
f"Image key '{image_key}' should have Image() feature type, "
|
||||
f"but got {type(feature).__name__}: {feature}. "
|
||||
"This indicates image schema was not preserved during aggregation."
|
||||
)
|
||||
|
||||
|
||||
def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
"""Test that image frames are correctly preserved after aggregation."""
|
||||
image_keys = aggr_ds.meta.image_keys
|
||||
if not image_keys:
|
||||
return
|
||||
|
||||
def images_equal(img1, img2):
|
||||
return torch.allclose(img1, img2)
|
||||
|
||||
# Test the section corresponding to the first dataset (ds_0)
|
||||
for i in range(len(ds_0)):
|
||||
assert aggr_ds[i]["index"] == i, (
|
||||
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
|
||||
)
|
||||
for key in image_keys:
|
||||
assert images_equal(aggr_ds[i][key], ds_0[i][key]), (
|
||||
f"Image frames at position {i} should be equal between aggregated and ds_0"
|
||||
)
|
||||
|
||||
# Test the section corresponding to the second dataset (ds_1)
|
||||
for i in range(len(ds_0), len(ds_0) + len(ds_1)):
|
||||
assert aggr_ds[i]["index"] == i, (
|
||||
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
|
||||
)
|
||||
for key in image_keys:
|
||||
assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
|
||||
f"Image frames at position {i} should be equal between aggregated and ds_1"
|
||||
)
|
||||
|
||||
|
||||
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||
|
||||
This test specifically verifies that:
|
||||
1. Image-based datasets can be aggregated correctly
|
||||
2. The HuggingFace Image() feature type is preserved in parquet files
|
||||
3. Image data integrity is maintained across aggregation
|
||||
4. Images can be properly decoded after aggregation
|
||||
|
||||
This catches the bug where to_parquet_with_hf_images() was not passing
|
||||
the features schema, causing image columns to be written as generic
|
||||
struct types instead of Image() types.
|
||||
"""
|
||||
ds_0_num_frames = 50
|
||||
ds_1_num_frames = 75
|
||||
ds_0_num_episodes = 2
|
||||
ds_1_num_episodes = 3
|
||||
|
||||
# Create two image-based datasets (use_videos=False)
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
)
|
||||
|
||||
# Verify source datasets have image keys
|
||||
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
|
||||
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
|
||||
|
||||
# Aggregate the datasets
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr",
|
||||
aggr_root=tmp_path / "image_aggr",
|
||||
)
|
||||
|
||||
# Load the aggregated dataset
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr")
|
||||
|
||||
# Verify aggregated dataset has image keys
|
||||
assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
|
||||
assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"
|
||||
|
||||
# Run standard aggregation assertions
|
||||
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
|
||||
expected_total_frames = ds_0_num_frames + ds_1_num_frames
|
||||
|
||||
assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
|
||||
assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Image-specific assertions
|
||||
assert_image_schema_preserved(aggr_ds)
|
||||
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Verify images can be accessed and have correct shape
|
||||
sample_item = aggr_ds[0]
|
||||
for image_key in aggr_ds.meta.image_keys:
|
||||
img = sample_item[image_key]
|
||||
assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor"
|
||||
assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
|
||||
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
|
||||
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
@@ -29,7 +29,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1050,7 +1050,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
assert "reward" in modified_dataset.meta.features
|
||||
|
||||
|
||||
def test_convert_dataset_to_videos(tmp_path):
|
||||
def test_convert_image_to_video_dataset(tmp_path):
|
||||
"""Test converting lerobot/pusht_image dataset to video format."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
@@ -1071,7 +1071,7 @@ def test_convert_dataset_to_videos(tmp_path):
|
||||
assert "observation.image" in source_dataset.meta.features
|
||||
|
||||
# Convert to video dataset (only first 2 episodes for speed)
|
||||
video_dataset = convert_dataset_to_videos(
|
||||
video_dataset = convert_image_to_video_dataset(
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="lerobot/pusht_video",
|
||||
@@ -1113,7 +1113,7 @@ def test_convert_dataset_to_videos(tmp_path):
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
|
||||
def test_convert_dataset_to_videos_subset_episodes(tmp_path):
|
||||
def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
|
||||
"""Test converting only specific episodes from lerobot/pusht_image to video format."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
@@ -1132,7 +1132,7 @@ def test_convert_dataset_to_videos_subset_episodes(tmp_path):
|
||||
# Convert only episode 0 to video (subset of loaded episodes)
|
||||
episode_indices = [0]
|
||||
|
||||
video_dataset = convert_dataset_to_videos(
|
||||
video_dataset = convert_image_to_video_dataset(
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="lerobot/pusht_video_subset",
|
||||
|
||||
Reference in New Issue
Block a user