Merge branch 'main' into feature/add-multitask-dit

This commit is contained in:
Steven Palma
2026-03-06 14:25:15 +01:00
committed by GitHub
33 changed files with 108 additions and 107 deletions
+7
View File
@@ -143,12 +143,18 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
if policy_name == "vqbet" and DEVICE == "mps":
pytest.skip("VQBet does not support MPS backend")
if policy_name == "act" and "aloha" in ds_repo_id and DEVICE == "mps":
pytest.skip("ACT with aloha has batch mutation issues on MPS")
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
env=make_env_config(env_name, **env_kwargs),
)
train_cfg.policy.device = DEVICE
train_cfg.validate()
# Check that we can make the policy object.
@@ -227,6 +233,7 @@ def test_act_backbone_lr():
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
)
cfg.policy.device = DEVICE
cfg.validate() # Needed for auto-setting some parameters
assert cfg.policy.optimizer_lr == 0.01