Merge branch 'main' into feature/add-multitask-dit

This commit is contained in:
Bryson Jones
2026-01-20 07:43:57 -08:00
committed by GitHub
18 changed files with 967 additions and 407 deletions
+8 -2
View File
@@ -144,12 +144,18 @@ def test_async_inference_e2e(monkeypatch):
client = RobotClient(client_config)
assert client.start(), "Client failed initial handshake with the server"
# Track action chunks received without modifying RobotClient
action_chunks_received = {"count": 0}
# Track action chunks received and verify device type
action_chunks_received = {"count": 0, "actions_on_cpu": True}
original_aggregate = client._aggregate_action_queues
def counting_aggregate(*args, **kwargs):
    """Record one received action chunk, note any non-CPU action, then delegate.

    Wraps the client's original ``_aggregate_action_queues`` so the test can
    observe calls without modifying ``RobotClient`` itself. The call's
    arguments are forwarded untouched and the original return value is
    returned.
    """
    action_chunks_received["count"] += 1

    # When called positionally, the first argument is the list of TimedAction.
    incoming_actions = args[0] if args else ()
    for timed_action in incoming_actions:
        if timed_action.get_action().device.type != "cpu":
            action_chunks_received["actions_on_cpu"] = False

    return original_aggregate(*args, **kwargs)
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
+145
View File
@@ -16,6 +16,7 @@
from unittest.mock import patch
import datasets
import torch
from lerobot.datasets.aggregate import aggregate_datasets
@@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
for key in aggr_ds.meta.video_keys:
assert key in item, f"Video key {key} missing from item {i}"
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
def assert_image_schema_preserved(aggr_ds):
    """Assert that aggregated parquet files keep the HuggingFace Image() feature type.

    Guards against a regression where image columns were written with a generic
    struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
    the proper Image() feature type, which made the HuggingFace Hub viewer show
    raw dict objects instead of image thumbnails.

    Does nothing when the aggregated dataset reports no image keys.
    """
    image_keys = aggr_ds.meta.image_keys
    if image_keys:
        # Every parquet file under <root>/data must expose the Image schema.
        parquet_paths = [*(aggr_ds.root / "data").rglob("*.parquet")]
        assert len(parquet_paths) > 0, "No parquet files found in aggregated dataset"

        for parquet_path in parquet_paths:
            # Load through HuggingFace datasets so the schema is reported as features.
            features = datasets.Dataset.from_parquet(str(parquet_path)).features
            for image_key in image_keys:
                feature = features.get(image_key)
                assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
                assert isinstance(feature, datasets.Image), (
                    f"Image key '{image_key}' should have Image() feature type, "
                    f"but got {type(feature).__name__}: {feature}. "
                    "This indicates image schema was not preserved during aggregation."
                )
def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
    """Assert that image frames are preserved, in order, after aggregation.

    The aggregated dataset is expected to contain the frames of ``ds_0``
    followed by the frames of ``ds_1``. For every position, the ``index``
    entry must equal the position and every image key must match the
    corresponding source frame. Does nothing when the aggregated dataset
    reports no image keys.

    Args:
        aggr_ds: Aggregated dataset; read through ``meta.image_keys`` and
            integer indexing.
        ds_0: First source dataset; must support ``len()`` and integer indexing.
        ds_1: Second source dataset; must support ``len()`` and integer indexing.

    Raises:
        AssertionError: If a frame index or an image differs from its source.
    """
    image_keys = aggr_ds.meta.image_keys
    if not image_keys:
        return

    def images_equal(img1, img2):
        # torch.allclose broadcasts its inputs, so compare shapes first:
        # otherwise a broadcast-compatible shape mismatch could pass, and an
        # incompatible one would raise RuntimeError instead of failing the
        # assertion with a readable message.
        return img1.shape == img2.shape and torch.allclose(img1, img2)

    # Sections of the aggregated dataset, in order: ds_0 first, then ds_1.
    offset = 0
    for source_name, source_ds in (("ds_0", ds_0), ("ds_1", ds_1)):
        source_len = len(source_ds)
        for local_idx in range(source_len):
            position = offset + local_idx
            # Fetch each frame once; indexing a dataset may decode images.
            aggr_item = aggr_ds[position]
            source_item = source_ds[local_idx]

            assert aggr_item["index"] == position, (
                f"Frame index at position {position} should be {position}, "
                f"but got {aggr_item['index']}"
            )
            for key in image_keys:
                assert images_equal(aggr_item[key], source_item[key]), (
                    f"Image frames at position {position} should be equal "
                    f"between aggregated and {source_name}"
                )
        offset += source_len
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
    """Aggregate two image-based datasets and verify the Image schema survives.

    Checks that:
    - image-based datasets (``use_videos=False``) aggregate correctly
    - the HuggingFace Image() feature type is preserved in parquet files
    - image data integrity is maintained across aggregation
    - images can be decoded from the aggregated dataset

    Regression test for the bug where to_parquet_with_hf_images() was not
    passing the features schema, causing image columns to be written as
    generic struct types instead of Image() types.
    """
    episodes_per_source = (2, 3)
    frames_per_source = (50, 75)
    aggr_repo_id = f"{DUMMY_REPO_ID}_image_aggr"
    aggr_root = tmp_path / "image_aggr"

    # Two image-based source datasets.
    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "image_0",
        repo_id=f"{DUMMY_REPO_ID}_image_0",
        total_episodes=episodes_per_source[0],
        total_frames=frames_per_source[0],
        use_videos=False,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "image_1",
        repo_id=f"{DUMMY_REPO_ID}_image_1",
        total_episodes=episodes_per_source[1],
        total_frames=frames_per_source[1],
        use_videos=False,
    )

    # The sources must actually carry image features for the test to be meaningful.
    assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
    assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"

    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    # Load the aggregated dataset from disk with the version lookup and
    # snapshot download patched out.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.lerobot_dataset.snapshot_download", return_value=str(aggr_root)),
    ):
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
    assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"

    # Standard aggregation assertions.
    assert_episode_and_frame_counts(aggr_ds, sum(episodes_per_source), sum(frames_per_source))
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)

    # Image-specific assertions.
    assert_image_schema_preserved(aggr_ds)
    assert_image_frames_integrity(aggr_ds, ds_0, ds_1)

    # The first item must expose each image as a 3-channel (C, H, W) tensor.
    first_item = aggr_ds[0]
    for image_key in aggr_ds.meta.image_keys:
        image = first_item[image_key]
        assert isinstance(image, torch.Tensor), f"Image {image_key} should be a tensor"
        assert image.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
        assert image.shape[0] == 3, f"Image {image_key} should have 3 channels"

    assert_dataset_iteration_works(aggr_ds)
+5 -5
View File
@@ -29,7 +29,7 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
@pytest.fixture
@@ -1050,7 +1050,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
assert "reward" in modified_dataset.meta.features
def test_convert_dataset_to_videos(tmp_path):
def test_convert_image_to_video_dataset(tmp_path):
"""Test converting lerobot/pusht_image dataset to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -1071,7 +1071,7 @@ def test_convert_dataset_to_videos(tmp_path):
assert "observation.image" in source_dataset.meta.features
# Convert to video dataset (only first 2 episodes for speed)
video_dataset = convert_dataset_to_videos(
video_dataset = convert_image_to_video_dataset(
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video",
@@ -1113,7 +1113,7 @@ def test_convert_dataset_to_videos(tmp_path):
shutil.rmtree(output_dir)
def test_convert_dataset_to_videos_subset_episodes(tmp_path):
def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
"""Test converting only specific episodes from lerobot/pusht_image to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -1132,7 +1132,7 @@ def test_convert_dataset_to_videos_subset_episodes(tmp_path):
# Convert only episode 0 to video (subset of loaded episodes)
episode_indices = [0]
video_dataset = convert_dataset_to_videos(
video_dataset = convert_image_to_video_dataset(
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video_subset",