# Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-16 17:20:05 +00:00; 378 lines, 12 KiB, Python)
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Test script for Multi-Task DiT policy.

To run tests with GPU on Modal (temporary script):
    modal run run_tests_modal.py

To run tests locally:
    python -m pytest tests/policies/test_multi_task_dit_policy.py -v
"""
|
|
|
|
import pytest
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
|
|
DiffusionConfig,
|
|
FlowMatchingConfig,
|
|
MultiTaskDiTConfig,
|
|
)
|
|
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
|
from lerobot.utils.random_utils import seeded_context, set_seed
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def set_random_seed():
    """Call ``set_seed`` with a fixed seed before every test in this module.

    The fixture is ``autouse``, so tests get it without requesting it explicitly.
    """
    set_seed(17)
|
|
|
|
|
|
def create_train_batch(
    batch_size: int = 2,
    n_obs_steps: int = 2,
    horizon: int = 16,
    state_dim: int = 10,
    action_dim: int = 10,
    height: int = 224,
    width: int = 224,
) -> dict[str, Tensor]:
    """Create a random training batch with state, one camera image, actions and a task prompt.

    Args:
        batch_size: Number of samples in the batch.
        n_obs_steps: Number of observation steps per sample.
        horizon: Number of action steps per sample.
        state_dim: Size of the last dimension of the state tensor.
        action_dim: Size of the last dimension of the action tensor.
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        A dict holding a ``(batch_size, n_obs_steps, state_dim)`` state tensor, a
        ``(batch_size, n_obs_steps, 3, height, width)`` image tensor under the ``laptop`` camera key,
        a ``(batch_size, horizon, action_dim)`` action tensor and a list of ``batch_size`` task strings.
    """
    return {
        # Use the shared constant (rather than a hard-coded string) so the key always matches the
        # state feature name registered by `create_config`.
        OBS_STATE: torch.randn(batch_size, n_obs_steps, state_dim),
        f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width),
        ACTION: torch.randn(batch_size, horizon, action_dim),
        "task": ["pick up the cube"] * batch_size,
    }
|
|
|
|
|
|
def create_observation_batch(
    batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224
) -> dict:
    """Create a random single-timestep observation batch for inference.

    Args:
        batch_size: Number of samples in the batch.
        state_dim: Size of the last dimension of the state tensor.
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        A dict holding a ``(batch_size, state_dim)`` state tensor, a ``(batch_size, 3, height, width)``
        image tensor under the ``laptop`` camera key and a list of ``batch_size`` task strings.
    """
    return {
        # Use the shared constant (rather than a hard-coded string) so the key always matches the
        # state feature name registered by `create_config`.
        OBS_STATE: torch.randn(batch_size, state_dim),
        f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width),
        "task": ["pick up the red cube"] * batch_size,
    }
|
|
|
|
|
|
def create_config(
    state_dim: int = 10,
    action_dim: int = 10,
    n_obs_steps: int = 2,
    horizon: int = 16,
    n_action_steps: int = 8,
    with_visual: bool = True,
    height: int = 224,
    width: int = 224,
) -> MultiTaskDiTConfig:
    """Build a small MultiTaskDiT config suitable for fast tests.

    Args:
        state_dim: Dimension of state observations
        action_dim: Dimension of actions
        n_obs_steps: Number of observation steps
        horizon: Action prediction horizon
        n_action_steps: Number of action steps to execute
        with_visual: Whether to register the ``laptop`` camera as a visual input (default: True)
        height: Image height (only used if with_visual=True)
        width: Image width (only used if with_visual=True)

    Returns:
        The config, with a shrunken transformer and after ``validate_features()`` has been called.
    """
    state_feature = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
    input_features = {OBS_STATE: state_feature}
    if with_visual:
        camera_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, height, width))
        input_features[f"{OBS_IMAGES}.laptop"] = camera_feature

    action_feature = PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))
    config = MultiTaskDiTConfig(
        input_features=input_features,
        output_features={ACTION: action_feature},
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        n_action_steps=n_action_steps,
    )

    # Shrink the transformer so that tests run quickly.
    small_transformer_overrides = (("hidden_dim", 128), ("num_layers", 2), ("num_heads", 4))
    for attribute, value in small_transformer_overrides:
        setattr(config.transformer, attribute, value)

    config.validate_features()
    return config
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int):
    """Test the forward pass in training mode, followed by a backward pass.

    Checks that the loss is a finite scalar and that back-propagation leaves finite gradients on
    at least one policy parameter.
    """
    n_obs_steps = 2
    horizon = 16
    n_action_steps = 8

    config = create_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        n_action_steps=n_action_steps,
    )

    policy = MultiTaskDiTPolicy(config=config)
    policy.train()

    batch = create_train_batch(
        batch_size=batch_size,
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        state_dim=state_dim,
        action_dim=action_dim,
    )

    # Test forward pass
    loss, _ = policy.forward(batch)
    assert loss is not None
    assert loss.shape == ()
    # `loss.item()` never returns None, so checking it against None could not fail.
    # Checking finiteness instead catches NaN/inf losses.
    assert torch.isfinite(loss)

    # Test backward pass
    loss.backward()
    gradients = [param.grad for param in policy.parameters() if param.grad is not None]
    assert gradients, "backward() did not populate any parameter gradient"
    assert all(torch.isfinite(grad).all() for grad in gradients)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
    """Check that ``select_action`` in eval mode returns one action per batch element."""
    config = create_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_obs_steps=2,
        horizon=16,
        n_action_steps=8,
    )

    policy = MultiTaskDiTPolicy(config=config)
    policy.eval()
    # Reset queues before inference
    policy.reset()

    observation = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
    with torch.no_grad():
        action = policy.select_action(observation)

    expected_shape = (batch_size, action_dim)
    assert action.shape == expected_shape
|
|
|
|
|
|
def test_multi_task_dit_policy_diffusion_objective():
    """Test the training loss and action selection of a policy using the diffusion objective."""
    batch_size = 2
    state_dim = 10
    action_dim = 10
    n_obs_steps = 2
    horizon = 16
    n_action_steps = 8

    config = create_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        n_action_steps=n_action_steps,
    )
    config.objective = DiffusionConfig(
        noise_scheduler_type="DDPM",
        num_train_timesteps=100,
        num_inference_steps=10,
    )

    policy = MultiTaskDiTPolicy(config=config)
    policy.train()

    batch = create_train_batch(
        batch_size=batch_size,
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        state_dim=state_dim,
        action_dim=action_dim,
    )

    # Test forward pass
    loss, _ = policy.forward(batch)
    assert loss is not None
    # `loss.item()` never returns None, so checking it against None could not fail.
    # Checking finiteness instead catches NaN/inf losses.
    assert torch.isfinite(loss).all()

    # Test inference
    policy.eval()
    policy.reset()  # Reset queues before inference, as the select_action test does
    with torch.no_grad():
        observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
        selected_action = policy.select_action(observation_batch)
        assert selected_action.shape == (batch_size, action_dim)
|
|
|
|
|
|
def test_multi_task_dit_policy_flow_matching_objective():
    """Test the training loss and action selection of a policy using the flow matching objective."""
    batch_size = 2
    state_dim = 10
    action_dim = 10
    n_obs_steps = 2
    horizon = 16
    n_action_steps = 8

    config = create_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        n_action_steps=n_action_steps,
    )
    config.objective = FlowMatchingConfig(
        sigma_min=0.0,
        num_integration_steps=10,  # Use fewer steps for faster tests
        integration_method="euler",
    )

    policy = MultiTaskDiTPolicy(config=config)
    policy.train()

    batch = create_train_batch(
        batch_size=batch_size,
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        state_dim=state_dim,
        action_dim=action_dim,
    )

    # Test forward pass
    loss, _ = policy.forward(batch)
    assert loss is not None
    # `loss.item()` never returns None, so checking it against None could not fail.
    # Checking finiteness instead catches NaN/inf losses.
    assert torch.isfinite(loss).all()

    # Test inference
    policy.eval()
    policy.reset()  # Reset queues before inference, as the select_action test does
    with torch.no_grad():
        observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
        selected_action = policy.select_action(observation_batch)
        assert selected_action.shape == (batch_size, action_dim)
|
|
|
|
|
|
def _move_tensors_to_device(batch: dict, device: torch.device) -> dict:
    """Return a copy of ``batch`` with every tensor value moved to ``device``; other values are kept."""
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


def test_multi_task_dit_policy_save_and_load(tmp_path):
    """Test that the policy can be saved and loaded correctly.

    The policy is saved to a temporary directory and loaded back. The test then checks that both
    policies hold the same state dict and, under the same random seed, produce the same loss and
    the same selected action.
    """
    root = tmp_path / "test_multi_task_dit_save_and_load"

    state_dim = 10
    action_dim = 10
    batch_size = 2
    n_obs_steps = 2
    horizon = 16
    n_action_steps = 8

    config = create_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        n_action_steps=n_action_steps,
    )

    policy = MultiTaskDiTPolicy(config=config)
    policy.eval()

    # Get device before saving
    device = next(policy.parameters()).device

    policy.save_pretrained(root)
    loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)

    # Explicitly move loaded_policy to the same device
    loaded_policy.to(device)
    loaded_policy.eval()

    # Build the training batch on the same device as the policy
    batch = _move_tensors_to_device(
        create_train_batch(
            batch_size=batch_size,
            n_obs_steps=n_obs_steps,
            horizon=horizon,
            state_dim=state_dim,
            action_dim=action_dim,
        ),
        device,
    )

    with torch.no_grad():
        # Both policies run the same steps in the same order under the same seed:
        # forward, then observation sampling, then select_action.
        with seeded_context(12):
            # Collect values from the original policy
            loss, _ = policy.forward(batch)

            observation_batch = _move_tensors_to_device(
                create_observation_batch(batch_size=batch_size, state_dim=state_dim), device
            )
            actions = policy.select_action(observation_batch)

        with seeded_context(12):
            # Collect values from the loaded policy
            loaded_loss, _ = loaded_policy.forward(batch)

            loaded_observation_batch = _move_tensors_to_device(
                create_observation_batch(batch_size=batch_size, state_dim=state_dim), device
            )
            loaded_actions = loaded_policy.select_action(loaded_observation_batch)

    # Compare state dicts. Each one is materialised once instead of once per key.
    original_state = policy.state_dict()
    loaded_state = loaded_policy.state_dict()
    assert original_state.keys() == loaded_state.keys()
    for name, tensor in original_state.items():
        assert torch.allclose(tensor, loaded_state[name], atol=1e-6)

    # Compare values before and after saving and loading
    assert torch.allclose(loss, loaded_loss)
    assert torch.allclose(actions, loaded_actions)
|
|
|
|
|
|
def test_multi_task_dit_policy_get_optim_params():
    """Check the structure of the optimizer parameter groups returned by the policy."""
    config = create_config(state_dim=10, action_dim=10, n_obs_steps=2, horizon=16, n_action_steps=8)
    policy = MultiTaskDiTPolicy(config=config)

    param_groups = policy.get_optim_params()

    # Should have 2 parameter groups: non-vision and vision encoder
    assert len(param_groups) == 2
    non_vision_group, vision_group = param_groups

    # First group is non-vision params (no lr specified, will use default)
    assert "params" in non_vision_group
    assert len(non_vision_group["params"]) > 0

    # Second group is vision encoder params with different lr
    assert "params" in vision_group
    assert "lr" in vision_group
    vision_lr_multiplier = config.observation_encoder.vision.lr_multiplier
    expected_lr = config.optimizer_lr * vision_lr_multiplier
    assert vision_group["lr"] == expected_lr
|