diff --git a/tests/policies/multi_task_dit/test_multi_task_dit.py b/tests/policies/multi_task_dit/test_multi_task_dit.py index 42fd1a131..c12dcb2fb 100644 --- a/tests/policies/multi_task_dit/test_multi_task_dit.py +++ b/tests/policies/multi_task_dit/test_multi_task_dit.py @@ -24,9 +24,12 @@ import pytest import torch from torch import Tensor -from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy +from lerobot.policies.multi_task_dit.processor_multi_task_dit import ( + make_multi_task_dit_pre_post_processors, +) from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.random_utils import seeded_context, set_seed @@ -147,6 +150,144 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d loss.backward() +def test_multi_task_dit_pre_post_processors(): + """Test pre and post processors for Multi-Task DiT policy.""" + state_dim = 10 + action_dim = 8 + n_obs_steps = 2 + horizon = 16 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=8, + ) + config.device = "cpu" + + # Set normalization mode to match the stats we're providing + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats + "ACTION": NormalizationMode.MIN_MAX, + } + + # Create dataset stats for normalization + dataset_stats = { + "observation.state": { + "mean": torch.zeros(state_dim), + "std": torch.ones(state_dim), + }, + "action": { + "min": torch.full((action_dim,), -1.0), + "max": torch.ones(action_dim), + }, + } + + # Create processors + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors( + config=config, dataset_stats=dataset_stats + ) + + # Test preprocessor with sample data + batch = { + "observation.state": torch.randn(state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224), + ACTION: torch.randn(action_dim), + "task": "pick up the cube", + } + + processed_batch = preprocessor(batch) + + # Check that data is batched + assert processed_batch["observation.state"].shape == (1, state_dim) + assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224) + assert processed_batch[ACTION].shape == (1, action_dim) + assert "task" in processed_batch + + # Check that data is on correct device + assert processed_batch["observation.state"].device.type == "cpu" + assert processed_batch[f"{OBS_IMAGES}.laptop"].device.type == "cpu" + assert processed_batch[ACTION].device.type == "cpu" + + # Test postprocessor with sample action (PolicyAction is just a torch.Tensor) + action = torch.randn(1, action_dim) + processed_action = postprocessor(action) + + # Check that action is unnormalized and on CPU + assert processed_action.shape == (1, action_dim) + assert processed_action.device.type == "cpu" + + +def test_multi_task_dit_pre_post_processors_normalization(): + """Test that normalization and unnormalization work correctly with simple sanity check numbers.""" + state_dim = 3 + action_dim = 2 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=2, + horizon=16, + n_action_steps=8, + ) + config.device = "cpu" + + # Set normalization mode to match the stats we're providing + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats + "ACTION": NormalizationMode.MIN_MAX, + } + + # Use simple stats that will actually transform the values + dataset_stats = { + "observation.state": { + "mean": torch.full((state_dim,), 5.0), + "std": torch.full((state_dim,), 2.0), + }, + "action": { + "min": torch.zeros(action_dim), + "max": torch.full((action_dim,), 2.0), + }, + } + + # Create processors + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors( + config=config, dataset_stats=dataset_stats + ) + + # Use simple input values + input_state = torch.tensor([7.0, 5.0, 3.0]) # Will normalize to [1.0, 0.0, -1.0] + input_action = torch.tensor([1.0, 2.0]) # Will normalize to [0.0, 1.0] + + batch = { + "observation.state": input_state, + f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224), + ACTION: input_action, + "task": "test task", + } + + # Process through preprocessor + processed_batch = preprocessor(batch) + + # State normalization: (x - mean) / std + expected_normalized_state = torch.tensor([1.0, 0.0, -1.0]) + assert torch.allclose(processed_batch["observation.state"][0], expected_normalized_state, atol=1e-5) + + # Action normalization: (x - min) / (max - min) * 2 - 1 + expected_normalized_action = torch.tensor([0.0, 1.0]) + assert torch.allclose(processed_batch[ACTION][0], expected_normalized_action, atol=1e-5) + + # Test unnormalization: should recover original values + normalized_action_tensor = processed_batch[ACTION][0:1] # Keep batch dimension + unnormalized_action = postprocessor(normalized_action_tensor) + + # Should recover original action values + assert torch.allclose(unnormalized_action[0], input_action, atol=1e-4) + + @pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)]) def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int): """Test select_action (inference mode).""" @@ -166,10 +307,22 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac policy.eval() policy.reset() # Reset queues before inference + # Create processors - use IDENTITY normalization when no stats provided + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + with torch.no_grad(): observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) def test_multi_task_dit_policy_diffusion_objective(): @@ -222,10 +375,21 @@ def test_multi_task_dit_policy_diffusion_objective(): # Test inference policy.eval() + # Use IDENTITY normalization when no stats provided + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) with torch.no_grad(): observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) def test_multi_task_dit_policy_flow_matching_objective(): @@ -278,10 +442,21 @@ def test_multi_task_dit_policy_flow_matching_objective(): # Test inference policy.eval() + # Use IDENTITY normalization when no stats provided + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) with torch.no_grad(): observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) def test_multi_task_dit_policy_save_and_load(tmp_path):