Merge branch 'main' into feat/robotwin-benchmark

This commit is contained in:
Pepijn
2026-04-16 18:57:39 +02:00
committed by GitHub
42 changed files with 1581 additions and 423 deletions
@@ -147,6 +147,7 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
)
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train()
# Use preprocessor to handle tokenization
@@ -336,6 +337,7 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac
)
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.eval()
policy.reset() # Reset queues before inference
@@ -390,6 +392,7 @@ def test_multi_task_dit_policy_diffusion_objective():
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train()
# Use preprocessor to handle tokenization
@@ -468,6 +471,7 @@ def test_multi_task_dit_policy_flow_matching_objective():
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train()
# Use preprocessor to handle tokenization
@@ -533,16 +537,12 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
)
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.eval()
# Get device before saving
device = next(policy.parameters()).device
policy.save_pretrained(root)
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
# Explicitly move loaded_policy to the same device
loaded_policy.to(device)
loaded_policy.to(config.device)
loaded_policy.eval()
batch = create_train_batch(
@@ -565,10 +565,6 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Move batch to the same device as the policy
for key in processed_batch:
if isinstance(processed_batch[key], torch.Tensor):
processed_batch[key] = processed_batch[key].to(device)
# Collect policy values before saving
loss, _ = policy.forward(processed_batch)
@@ -608,6 +604,7 @@ def test_multi_task_dit_policy_get_optim_params():
)
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
param_groups = policy.get_optim_params()
# Should have 2 parameter groups: non-vision and vision encoder
@@ -18,6 +18,11 @@ from unittest.mock import MagicMock, patch
import pytest
from lerobot.utils.import_utils import is_package_available
if not is_package_available("reachy2_sdk"):
pytest.skip("reachy2_sdk not available", allow_module_level=True)
from lerobot.teleoperators.reachy2_teleoperator import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,