mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
add first commit
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||
|
||||
REPO_A = "lerobot/pusht"
|
||||
REPO_B = "lerobot/aloha_mobile_cabinet" # replace with the actual repo id
|
||||
|
||||
feature_keys_mapping = {
|
||||
REPO_A: { # pusht (1 camera, 2-dim)
|
||||
"action": "actions",
|
||||
"observation.state": "obs_state",
|
||||
"observation.image": "obs_image.cam_high",
|
||||
},
|
||||
REPO_B: { # dual arm (3 cameras, 14-dim)
|
||||
"action": "actions",
|
||||
"observation.state": "obs_state",
|
||||
"observation.images.cam_high": "obs_image.cam_high",
|
||||
"observation.images.cam_left_wrist": "obs_image.cam_left_wrist",
|
||||
"observation.images.cam_right_wrist": "obs_image.cam_right_wrist",
|
||||
},
|
||||
}
|
||||
|
||||
from torchvision.transforms.v2 import Compose, ToImage, Resize
|
||||
image_tf = Compose([
|
||||
ToImage(), # converts to tensor if needed
|
||||
Resize((224, 224)), # unify sizes across datasets (96x96 vs 480x640)
|
||||
])
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
dataset = MultiLeRobotDataset(
|
||||
repo_ids=[REPO_A, REPO_B],
|
||||
image_transforms=image_tf, # ensures same HxW
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
train_on_all_features=True, # keep union of cameras; zero-fill missing
|
||||
# optional: override if you want fixed maxima; else inferred:
|
||||
# max_action_dim=14,
|
||||
# max_state_dim=14,
|
||||
max_action_dim=14,
|
||||
max_state_dim=14,
|
||||
max_image_dim=224,
|
||||
ignore_keys=[
|
||||
"next.*", # drop reward/done/success
|
||||
"index",
|
||||
"timestamp",
|
||||
"videos/*", # drop all video metadata
|
||||
"observation.effort", # 👈 drop effort everywhere
|
||||
],
|
||||
)
|
||||
breakpoint()
|
||||
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
|
||||
for _ in range(100):
|
||||
batch = next(iter(loader))
|
||||
|
||||
breakpoint()
|
||||
# vectors padded to maxima (pusht:2 -> 14; dual-arm:14 -> 14)
|
||||
assert batch["actions"].shape[-1] == 14
|
||||
assert batch["obs_state"].shape[-1] == 14
|
||||
assert batch["actions_padding_mask"].shape[-1] == 14
|
||||
assert batch["obs_state_padding_mask"].shape[-1] == 14
|
||||
|
||||
# cameras: all canonical keys exist; pusht will have wrists zero-filled
|
||||
for cam in ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]:
|
||||
assert cam in batch
|
||||
assert f"{cam}_is_pad" in batch
|
||||
# images should all be 3x224x224 (or your transform’s size)
|
||||
img = batch[cam]
|
||||
assert img.ndim in (4, 5) # (B,C,H,W) or (B,T,C,H,W) depending on your loader
|
||||
@@ -0,0 +1,16 @@
|
||||
# storage / caches
|
||||
RAID=/raid/jade
|
||||
export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers
|
||||
export HF_HOME=$RAID/.cache/huggingface
|
||||
export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets
|
||||
export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot
|
||||
export WANDB_CACHE_DIR=$RAID/.cache/wandb
|
||||
export TMPDIR=$RAID/.cache/tmp
|
||||
mkdir -p $TMPDIR
|
||||
export WANDB_MODE=offline
|
||||
# export HF_DATASETS_OFFLINE=1
|
||||
# export HF_HUB_OFFLINE=1
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
export MUJOCO_GL=egl
|
||||
|
||||
python examples/tester.py
|
||||
Reference in New Issue
Block a user