mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
67 lines
2.4 KiB
Python
67 lines
2.4 KiB
Python
from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset
|
||
|
||
REPO_A = "lerobot/pusht"
|
||
REPO_B = "lerobot/aloha_mobile_cabinet" # replace with the actual repo id
|
||
|
||
feature_keys_mapping = {
|
||
REPO_A: { # pusht (1 camera, 2-dim)
|
||
"action": "actions",
|
||
"observation.state": "obs_state",
|
||
"observation.image": "obs_image.cam_high",
|
||
},
|
||
REPO_B: { # dual arm (3 cameras, 14-dim)
|
||
"action": "actions",
|
||
"observation.state": "obs_state",
|
||
"observation.images.cam_high": "obs_image.cam_high",
|
||
"observation.images.cam_left_wrist": "obs_image.cam_left_wrist",
|
||
"observation.images.cam_right_wrist": "obs_image.cam_right_wrist",
|
||
},
|
||
}
|
||
|
||
from torchvision.transforms.v2 import Compose, ToImage, Resize
|
||
image_tf = Compose([
|
||
ToImage(), # converts to tensor if needed
|
||
Resize((224, 224)), # unify sizes across datasets (96x96 vs 480x640)
|
||
])
|
||
|
||
from torch.utils.data import DataLoader
|
||
|
||
dataset = MultiLeRobotDataset(
|
||
repo_ids=[REPO_A, REPO_B],
|
||
image_transforms=image_tf, # ensures same HxW
|
||
feature_keys_mapping=feature_keys_mapping,
|
||
train_on_all_features=True, # keep union of cameras; zero-fill missing
|
||
# optional: override if you want fixed maxima; else inferred:
|
||
# max_action_dim=14,
|
||
# max_state_dim=14,
|
||
max_action_dim=14,
|
||
max_state_dim=14,
|
||
max_image_dim=224,
|
||
ignore_keys=[
|
||
"next.*", # drop reward/done/success
|
||
"index",
|
||
"timestamp",
|
||
"videos/*", # drop all video metadata
|
||
"observation.effort", # 👈 drop effort everywhere
|
||
],
|
||
)
|
||
breakpoint()
|
||
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
|
||
for _ in range(100):
|
||
batch = next(iter(loader))
|
||
|
||
breakpoint()
|
||
# vectors padded to maxima (pusht:2 -> 14; dual-arm:14 -> 14)
|
||
assert batch["actions"].shape[-1] == 14
|
||
assert batch["obs_state"].shape[-1] == 14
|
||
assert batch["actions_padding_mask"].shape[-1] == 14
|
||
assert batch["obs_state_padding_mask"].shape[-1] == 14
|
||
|
||
# cameras: all canonical keys exist; pusht will have wrists zero-filled
|
||
for cam in ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]:
|
||
assert cam in batch
|
||
assert f"{cam}_is_pad" in batch
|
||
# images should all be 3x224x224 (or your transform’s size)
|
||
img = batch[cam]
|
||
assert img.ndim in (4, 5) # (B,C,H,W) or (B,T,C,H,W) depending on your loader
|