Merge branch 'main' into feature/add-multitask-dit

This commit is contained in:
Steven Palma
2026-03-06 14:25:15 +01:00
committed by GitHub
33 changed files with 108 additions and 107 deletions
+7
View File
@@ -143,12 +143,18 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
if policy_name == "vqbet" and DEVICE == "mps":
pytest.skip("VQBet does not support MPS backend")
if policy_name == "act" and "aloha" in ds_repo_id and DEVICE == "mps":
pytest.skip("ACT with aloha has batch mutation issues on MPS")
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
env=make_env_config(env_name, **env_kwargs),
)
train_cfg.policy.device = DEVICE
train_cfg.validate()
# Check that we can make the policy object.
@@ -227,6 +233,7 @@ def test_act_backbone_lr():
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
)
cfg.policy.device = DEVICE
cfg.validate() # Needed for auto-setting some parameters
assert cfg.policy.optimizer_lr == 0.01
+1 -3
View File
@@ -1870,9 +1870,7 @@ class NonCallableStep(ProcessorStep):
def test_construction_rejects_step_without_call():
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
with pytest.raises(
TypeError, match=r"Can't instantiate abstract class NonCallableStep with abstract method __call_"
):
with pytest.raises(TypeError, match=r"Can't instantiate abstract class NonCallableStep"):
DataProcessorPipeline([NonCallableStep()])
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
+2 -1
View File
@@ -22,8 +22,9 @@ import torch
from lerobot import available_cameras, available_motors, available_robots
from lerobot.utils.import_utils import is_package_available
from lerobot.utils.utils import auto_select_torch_device
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device()))
TEST_ROBOT_TYPES = []
for robot_type in available_robots: