mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
add processor tests to multitask dit tests
This commit is contained in:
@@ -24,9 +24,12 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||||
|
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
||||||
|
make_multi_task_dit_pre_post_processors,
|
||||||
|
)
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||||
|
|
||||||
@@ -147,6 +150,144 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_task_dit_pre_post_processors():
    """Test pre and post processors for Multi-Task DiT policy.

    Builds the processor pair from a CPU config and hand-written dataset stats,
    feeds a single unbatched sample through the preprocessor, and checks that the
    tensors come back with a leading batch dimension of 1 and live on CPU. A
    random action is then sent through the postprocessor and checked for shape,
    device and value.
    """
    state_dim = 10
    action_dim = 8
    n_obs_steps = 2
    horizon = 16

    config = create_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_obs_steps=n_obs_steps,
        horizon=horizon,
        n_action_steps=8,
    )
    config.device = "cpu"

    # Set normalization mode to match the stats we're providing
    config.normalization_mapping = {
        "VISUAL": NormalizationMode.IDENTITY,
        "STATE": NormalizationMode.MEAN_STD,  # Use MEAN_STD since we provide mean/std stats
        "ACTION": NormalizationMode.MIN_MAX,
    }

    # Create dataset stats for normalization
    dataset_stats = {
        "observation.state": {
            "mean": torch.zeros(state_dim),
            "std": torch.ones(state_dim),
        },
        "action": {
            "min": torch.full((action_dim,), -1.0),
            "max": torch.ones(action_dim),
        },
    }

    # Create processors
    preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
        config=config, dataset_stats=dataset_stats
    )

    # Test preprocessor with sample data (no batch dimension on purpose)
    batch = {
        "observation.state": torch.randn(state_dim),
        f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
        ACTION: torch.randn(action_dim),
        "task": "pick up the cube",
    }

    processed_batch = preprocessor(batch)

    # Check that data is batched
    assert processed_batch["observation.state"].shape == (1, state_dim)
    assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
    assert processed_batch[ACTION].shape == (1, action_dim)
    assert "task" in processed_batch

    # Check that data is on correct device
    assert processed_batch["observation.state"].device.type == "cpu"
    assert processed_batch[f"{OBS_IMAGES}.laptop"].device.type == "cpu"
    assert processed_batch[ACTION].device.type == "cpu"

    # Test postprocessor with sample action (PolicyAction is just a torch.Tensor)
    action = torch.randn(1, action_dim)
    processed_action = postprocessor(action)

    # Check that the action keeps its shape and is on CPU
    assert processed_action.shape == (1, action_dim)
    assert processed_action.device.type == "cpu"

    # The action stats span [-1, 1], so undoing the MIN_MAX mapping
    # ((x - min) / (max - min) * 2 - 1) must leave the values unchanged.
    assert torch.allclose(processed_action, action, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_task_dit_pre_post_processors_normalization():
    """Test that normalization and unnormalization work correctly with simple sanity check numbers.

    The stats and inputs are chosen so the expected results can be computed by
    hand: state uses MEAN_STD with mean 5 and std 2, action uses MIN_MAX with
    min 0 and max 2. The normalized action is then passed through the
    postprocessor, which must give back the original action values.
    """
    state_dim = 3
    action_dim = 2

    config = create_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_obs_steps=2,
        horizon=16,
        n_action_steps=8,
    )
    config.device = "cpu"

    # Set normalization mode to match the stats we're providing
    config.normalization_mapping = {
        "VISUAL": NormalizationMode.IDENTITY,
        "STATE": NormalizationMode.MEAN_STD,  # Use MEAN_STD since we provide mean/std stats
        "ACTION": NormalizationMode.MIN_MAX,
    }

    # Use simple stats that will actually transform the values
    dataset_stats = {
        "observation.state": {
            "mean": torch.full((state_dim,), 5.0),
            "std": torch.full((state_dim,), 2.0),
        },
        "action": {
            "min": torch.zeros(action_dim),
            "max": torch.full((action_dim,), 2.0),
        },
    }

    # Create processors
    preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
        config=config, dataset_stats=dataset_stats
    )

    # Use simple input values
    input_state = torch.tensor([7.0, 5.0, 3.0])  # Will normalize to [1.0, 0.0, -1.0]
    input_action = torch.tensor([1.0, 2.0])  # Will normalize to [0.0, 1.0]

    batch = {
        "observation.state": input_state,
        f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
        ACTION: input_action,
        "task": "test task",
    }

    # Process through preprocessor
    processed_batch = preprocessor(batch)

    # State normalization: (x - mean) / std
    #   (7 - 5) / 2 = 1, (5 - 5) / 2 = 0, (3 - 5) / 2 = -1
    expected_normalized_state = torch.tensor([1.0, 0.0, -1.0])
    assert torch.allclose(processed_batch["observation.state"][0], expected_normalized_state, atol=1e-5)

    # Action normalization: (x - min) / (max - min) * 2 - 1
    #   (1 - 0) / 2 * 2 - 1 = 0, (2 - 0) / 2 * 2 - 1 = 1
    expected_normalized_action = torch.tensor([0.0, 1.0])
    assert torch.allclose(processed_batch[ACTION][0], expected_normalized_action, atol=1e-5)

    # Test unnormalization: should recover original values
    normalized_action_tensor = processed_batch[ACTION][0:1]  # Keep batch dimension
    unnormalized_action = postprocessor(normalized_action_tensor)

    # Should recover original action values
    assert torch.allclose(unnormalized_action[0], input_action, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
|
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
|
||||||
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||||
"""Test select_action (inference mode)."""
|
"""Test select_action (inference mode)."""
|
||||||
@@ -166,10 +307,22 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac
|
|||||||
policy.eval()
|
policy.eval()
|
||||||
policy.reset() # Reset queues before inference
|
policy.reset() # Reset queues before inference
|
||||||
|
|
||||||
|
# Create processors - use IDENTITY normalization when no stats provided
|
||||||
|
config.normalization_mapping = {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||||
selected_action = policy.select_action(observation_batch)
|
# Process observation through preprocessor
|
||||||
assert selected_action.shape == (batch_size, action_dim)
|
processed_obs = preprocessor(observation_batch)
|
||||||
|
selected_action = policy.select_action(processed_obs)
|
||||||
|
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||||
|
processed_action = postprocessor(selected_action)
|
||||||
|
assert processed_action.shape == (batch_size, action_dim)
|
||||||
|
|
||||||
|
|
||||||
def test_multi_task_dit_policy_diffusion_objective():
|
def test_multi_task_dit_policy_diffusion_objective():
|
||||||
@@ -222,10 +375,21 @@ def test_multi_task_dit_policy_diffusion_objective():
|
|||||||
|
|
||||||
# Test inference
|
# Test inference
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
# Use IDENTITY normalization when no stats provided
|
||||||
|
config.normalization_mapping = {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||||
selected_action = policy.select_action(observation_batch)
|
# Process observation through preprocessor
|
||||||
assert selected_action.shape == (batch_size, action_dim)
|
processed_obs = preprocessor(observation_batch)
|
||||||
|
selected_action = policy.select_action(processed_obs)
|
||||||
|
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||||
|
processed_action = postprocessor(selected_action)
|
||||||
|
assert processed_action.shape == (batch_size, action_dim)
|
||||||
|
|
||||||
|
|
||||||
def test_multi_task_dit_policy_flow_matching_objective():
|
def test_multi_task_dit_policy_flow_matching_objective():
|
||||||
@@ -278,10 +442,21 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
|||||||
|
|
||||||
# Test inference
|
# Test inference
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
# Use IDENTITY normalization when no stats provided
|
||||||
|
config.normalization_mapping = {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||||
selected_action = policy.select_action(observation_batch)
|
# Process observation through preprocessor
|
||||||
assert selected_action.shape == (batch_size, action_dim)
|
processed_obs = preprocessor(observation_batch)
|
||||||
|
selected_action = policy.select_action(processed_obs)
|
||||||
|
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||||
|
processed_action = postprocessor(selected_action)
|
||||||
|
assert processed_action.shape == (batch_size, action_dim)
|
||||||
|
|
||||||
|
|
||||||
def test_multi_task_dit_policy_save_and_load(tmp_path):
|
def test_multi_task_dit_policy_save_and_load(tmp_path):
|
||||||
|
|||||||
Reference in New Issue
Block a user