Files
lerobot/examples/tester.py
T
2025-09-18 14:12:54 +02:00

67 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset
REPO_A = "lerobot/pusht"
REPO_B = "lerobot/aloha_mobile_cabinet" # replace with the actual repo id
feature_keys_mapping = {
REPO_A: { # pusht (1 camera, 2-dim)
"action": "actions",
"observation.state": "obs_state",
"observation.image": "obs_image.cam_high",
},
REPO_B: { # dual arm (3 cameras, 14-dim)
"action": "actions",
"observation.state": "obs_state",
"observation.images.cam_high": "obs_image.cam_high",
"observation.images.cam_left_wrist": "obs_image.cam_left_wrist",
"observation.images.cam_right_wrist": "obs_image.cam_right_wrist",
},
}
from torchvision.transforms.v2 import Compose, ToImage, Resize
image_tf = Compose([
ToImage(), # converts to tensor if needed
Resize((224, 224)), # unify sizes across datasets (96x96 vs 480x640)
])
from torch.utils.data import DataLoader
dataset = MultiLeRobotDataset(
repo_ids=[REPO_A, REPO_B],
image_transforms=image_tf, # ensures same HxW
feature_keys_mapping=feature_keys_mapping,
train_on_all_features=True, # keep union of cameras; zero-fill missing
# optional: override if you want fixed maxima; else inferred:
# max_action_dim=14,
# max_state_dim=14,
max_action_dim=14,
max_state_dim=14,
max_image_dim=224,
ignore_keys=[
"next.*", # drop reward/done/success
"index",
"timestamp",
"videos/*", # drop all video metadata
"observation.effort", # 👈 drop effort everywhere
],
)
breakpoint()
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
for _ in range(100):
batch = next(iter(loader))
breakpoint()
# vectors padded to maxima (pusht:2 -> 14; dual-arm:14 -> 14)
assert batch["actions"].shape[-1] == 14
assert batch["obs_state"].shape[-1] == 14
assert batch["actions_padding_mask"].shape[-1] == 14
assert batch["obs_state_padding_mask"].shape[-1] == 14
# cameras: all canonical keys exist; pusht will have wrists zero-filled
for cam in ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]:
assert cam in batch
assert f"{cam}_is_pad" in batch
# images should all be 3x224x224 (or your transforms size)
img = batch[cam]
assert img.ndim in (4, 5) # (B,C,H,W) or (B,T,C,H,W) depending on your loader