Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-16 17:20:05 +00:00
Merge branch 'pr-1451' into danaaubakirova/25_06_2025
@@ -394,37 +394,58 @@ def test_factory(env_name, repo_id, policy_name):
 # TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
-@pytest.mark.skip("TODO after fix multidataset")
+# @pytest.mark.skip("TODO after fix multidataset")
 def test_multidataset_frames():
-    """Check that all dataset frames are incorporated."""
-    # Note: use the image variants of the dataset to make the test approx 3x faster.
-    # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
-    # logic that wouldn't be caught with two repo IDs.
+    """Check that all dataset frames are incorporated and aligned correctly."""
     repo_ids = [
         "lerobot/aloha_sim_insertion_human_image",
         "lerobot/aloha_sim_transfer_cube_human_image",
         "lerobot/aloha_sim_insertion_scripted_image",
     ]
 
+    # dummy padding dimensions (simulate training setup)
+    MAX_ACTION_DIM = 14
+    MAX_STATE_DIM = 30
+    MAX_NUM_IMAGES = 3
+    MAX_IMAGE_DIM = 224
+
     sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
-    dataset = MultiLeRobotDataset(repo_ids)
+    dataset = MultiLeRobotDataset(
+        repo_ids,
+        max_action_dim=MAX_ACTION_DIM,
+        max_state_dim=MAX_STATE_DIM,
+        max_num_images=MAX_NUM_IMAGES,
+        max_image_dim=MAX_IMAGE_DIM,
+    )
 
     assert len(dataset) == sum(len(d) for d in sub_datasets)
     assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
     assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
 
     # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
     # check they match.
     expected_dataset_indices = []
     for i, sub_dataset in enumerate(sub_datasets):
         expected_dataset_indices.extend([i] * len(sub_dataset))
 
-    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
+    for expected_dataset_index, sub_item, multi_item in zip(
         expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
     ):
-        dataset_index = dataset_item.pop("dataset_index")
+        dataset_index = multi_item.pop("dataset_index")
         assert dataset_index == expected_dataset_index
-        assert sub_dataset_item.keys() == dataset_item.keys()
-        for k in sub_dataset_item:
-            assert torch.equal(sub_dataset_item[k], dataset_item[k])
+
+        # we ignore padding_mask and dataset_index keys in multi_item
+        extra_keys = {k for k in multi_item if "padding_mask" in k}
+        filtered_multi_keys = set(multi_item.keys()) - extra_keys
+        assert set(sub_item.keys()) == filtered_multi_keys, f"mismatch in keys"
+
+        for k in sub_item:
+            if k not in multi_item:
+                continue
+            v1, v2 = sub_item[k], multi_item[k]
+            if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
+                assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
+            else:
+                assert v1 == v2, f"value mismatch on key: {k}"
 
 
 # TODO(aliberts): Move to more appropriate location
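For context, here is a minimal standalone sketch of the key-filtering comparison introduced in the diff above. The dict keys, tensor shapes, and the action_padding_mask name are invented for illustration; the real items come from LeRobotDataset and MultiLeRobotDataset, and this is not the repository's own code.

import torch

# Toy stand-ins for one aligned pair of items; names and shapes are assumptions.
sub_item = {"action": torch.zeros(14), "observation.state": torch.zeros(30)}
multi_item = {
    "action": torch.zeros(14),
    "observation.state": torch.zeros(30),
    "action_padding_mask": torch.ones(14, dtype=torch.bool),
}

# Same pattern as the updated test: ignore padding-mask keys when comparing key sets...
extra_keys = {k for k in multi_item if "padding_mask" in k}
assert set(sub_item.keys()) == set(multi_item.keys()) - extra_keys

# ...and require exact equality on the shared keys.
for k in sub_item:
    v1, v2 = sub_item[k], multi_item[k]
    if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
        assert torch.equal(v1, v2)
    else:
        assert v1 == v2

With the skip decorator commented out, the test should be runnable with something like "pytest -k test_multidataset_frames"; the exact test file path depends on the repository layout at this commit.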