mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 18:20:08 +00:00
Feature/add multitask diffusion transformer policy implementation (#2545)
* Add multitask diffusion transformer policy Add multitask diffusion transformer policy * expand the observation encoder to support differnt size encoders for vision and text * add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs * update readme and citations for multitask dit policy * remove dino vision encoder and simplify text and vision encoders by removing inheritance structure * adjust factory comment * update docstring for multitask dit policy processor file * simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives * add references to the modeling file comments * merge all modules files into the main modeling file * add torch.no_grad decorators * split up select action return statement * remove redundant asserts * add tutorial to training with multi_task_dit * fix bugs when testing on hardware * remove environment state conditioning * update typo in test instruction comment * add processor tests to multitask dit tests * move policy to top of file * use constants for indexing into batches and remove env state references * remove the base classes since we don't need to be able to extend * fix nit formatting in generate actions fcn * reformat and clean up tutorial for multitask dit policy * add more descriptions and depth to multitask dit tutorial * note origins of each training objective * rename config param for multiple vision encoders * refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit * add multitask dit to toc for docs * add conditional transformers import to match all other policies that use transformers lib * add test handling for multitask dit when transformers isnt available * skip tests without transformers * remove cropping of images smaller than the crop size * add kwargs arg to multitask dit constructor * add wallx dep conflict management for multitask dit policy * use hyphens for cleanliness in pyproject.toml * add conflict management to pyproject toml for pi conflict for mtdp as well * update tests script to not use unnecessary uv sync call which resolves dependencies that do not need to run. This drastically reduces CI run time * revert fast tests edits * update docs and readme files, fixing some typos and adding multitask dit to readme * chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy * chore(dependencies): bump pi0 family to transformers v5 * chore(dependencies): bump wall x to transformers v5 * chore(dependencies): bump gr00t to transformers v5 * chore(style): fix pre-commit * fix(policy): xvla forced_bos_token missing * test(rl): skip ci tests for resnet10 * Fix: full pi models support for transformer v5 (#2967) * fix(pi): remove loss truncation * fix(pi): remove state padding before tokenization * fix(pi): fix image padding value * fix from_pretrain * add transformer v5 changes * remove reference * more fixes * make it work * add support for rest of pi family * add pifast work * more changes * more changes * more cleanup * fix torch params * dtype fix * torch compile * embed mismatch fix * revert groot * more nit fixes * remove unused classes * more fixes * revert * nit * torch dtype warning fix * but back dynamic renaming * add tie embedding --------- Co-authored-by: Yufei Sun <skieyfly@gmail.com> * chore: fix XVLA in transformers v5 (#3006) * test(policies): enable wall x CI testing * style(test): pre-commit check * style(test): pre-commit --------- Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Jade Choghari <chogharijade@gmail.com> Co-authored-by: Yufei Sun <skieyfly@gmail.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: E402
|
||||
|
||||
"""Test script for Multi-Task DiT policy.
|
||||
|
||||
To run tests locally:
|
||||
python -m pytest tests/policies/multi_task_dit/test_multi_task_dit.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local transformers installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
||||
make_multi_task_dit_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed():
|
||||
seed = 17
|
||||
set_seed(seed)
|
||||
|
||||
|
||||
def create_train_batch(
|
||||
batch_size: int = 2,
|
||||
n_obs_steps: int = 2,
|
||||
horizon: int = 16,
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 10,
|
||||
height: int = 224,
|
||||
width: int = 224,
|
||||
) -> dict[str, Tensor]:
|
||||
"""Create a training batch with visual input and text."""
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, n_obs_steps, state_dim),
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width),
|
||||
ACTION: torch.randn(batch_size, horizon, action_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(
|
||||
batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224
|
||||
) -> dict:
|
||||
"""Create observation batch for inference for a single timestep."""
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width),
|
||||
"task": ["pick up the red cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def create_config(
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 10,
|
||||
n_obs_steps: int = 2,
|
||||
horizon: int = 16,
|
||||
n_action_steps: int = 8,
|
||||
with_visual: bool = True,
|
||||
height: int = 224,
|
||||
width: int = 224,
|
||||
) -> MultiTaskDiTConfig:
|
||||
"""Create a MultiTaskDiT config for testing.
|
||||
|
||||
Args:
|
||||
state_dim: Dimension of state observations
|
||||
action_dim: Dimension of actions
|
||||
n_obs_steps: Number of observation steps
|
||||
horizon: Action prediction horizon
|
||||
n_action_steps: Number of action steps to execute
|
||||
with_visual: Whether to include visual input (default: True)
|
||||
height: Image height (only used if with_visual=True)
|
||||
width: Image width (only used if with_visual=True)
|
||||
"""
|
||||
input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}
|
||||
|
||||
if with_visual:
|
||||
input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(3, height, width)
|
||||
)
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
# Use smaller model for faster tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
|
||||
def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
"""Test forward pass (training mode)."""
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
assert loss.shape == ()
|
||||
|
||||
# Test backward pass
|
||||
loss.backward()
|
||||
|
||||
|
||||
def test_multi_task_dit_pre_post_processors():
|
||||
"""Test pre and post processors for Multi-Task DiT policy."""
|
||||
state_dim = 10
|
||||
action_dim = 8
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=8,
|
||||
)
|
||||
config.device = "cpu"
|
||||
|
||||
# Set normalization mode to match the stats we're providing
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Create dataset stats for normalization
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(state_dim),
|
||||
"std": torch.ones(state_dim),
|
||||
},
|
||||
"action": {
|
||||
"min": torch.full((action_dim,), -1.0),
|
||||
"max": torch.ones(action_dim),
|
||||
},
|
||||
}
|
||||
|
||||
# Create processors
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
|
||||
config=config, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
# Test preprocessor with sample data
|
||||
batch = {
|
||||
"observation.state": torch.randn(state_dim),
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
|
||||
ACTION: torch.randn(action_dim),
|
||||
"task": "pick up the cube",
|
||||
}
|
||||
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Check that data is batched
|
||||
assert processed_batch["observation.state"].shape == (1, state_dim)
|
||||
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
|
||||
assert processed_batch[ACTION].shape == (1, action_dim)
|
||||
# Check that task text was tokenized
|
||||
assert OBS_LANGUAGE_TOKENS in processed_batch
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
|
||||
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
|
||||
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
|
||||
|
||||
# Check that data is on correct device
|
||||
assert processed_batch["observation.state"].device.type == "cpu"
|
||||
assert processed_batch[f"{OBS_IMAGES}.laptop"].device.type == "cpu"
|
||||
assert processed_batch[ACTION].device.type == "cpu"
|
||||
|
||||
# Test postprocessor with sample action (PolicyAction is just a torch.Tensor)
|
||||
action = torch.randn(1, action_dim)
|
||||
processed_action = postprocessor(action)
|
||||
|
||||
# Check that action is unnormalized and on CPU
|
||||
assert processed_action.shape == (1, action_dim)
|
||||
assert processed_action.device.type == "cpu"
|
||||
|
||||
|
||||
def test_multi_task_dit_pre_post_processors_normalization():
|
||||
"""Test that normalization and unnormalization work correctly with simple sanity check numbers."""
|
||||
state_dim = 3
|
||||
action_dim = 2
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=2,
|
||||
horizon=16,
|
||||
n_action_steps=8,
|
||||
)
|
||||
config.device = "cpu"
|
||||
|
||||
# Set normalization mode to match the stats we're providing
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Use simple stats that will actually transform the values
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.full((state_dim,), 5.0),
|
||||
"std": torch.full((state_dim,), 2.0),
|
||||
},
|
||||
"action": {
|
||||
"min": torch.zeros(action_dim),
|
||||
"max": torch.full((action_dim,), 2.0),
|
||||
},
|
||||
}
|
||||
|
||||
# Create processors
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
|
||||
config=config, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
# Use simple input values
|
||||
input_state = torch.tensor([7.0, 5.0, 3.0]) # Will normalize to [1.0, 0.0, -1.0]
|
||||
input_action = torch.tensor([1.0, 2.0]) # Will normalize to [0.0, 1.0]
|
||||
|
||||
batch = {
|
||||
"observation.state": input_state,
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
|
||||
ACTION: input_action,
|
||||
"task": "test task",
|
||||
}
|
||||
|
||||
# Process through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# State normalization: (x - mean) / std
|
||||
expected_normalized_state = torch.tensor([1.0, 0.0, -1.0])
|
||||
assert torch.allclose(processed_batch["observation.state"][0], expected_normalized_state, atol=1e-5)
|
||||
|
||||
# Action normalization: (x - min) / (max - min) * 2 - 1
|
||||
expected_normalized_action = torch.tensor([0.0, 1.0])
|
||||
assert torch.allclose(processed_batch[ACTION][0], expected_normalized_action, atol=1e-5)
|
||||
|
||||
# Test unnormalization: should recover original values
|
||||
normalized_action_tensor = processed_batch[ACTION][0:1] # Keep batch dimension
|
||||
unnormalized_action = postprocessor(normalized_action_tensor)
|
||||
|
||||
# Should recover original action values
|
||||
assert torch.allclose(unnormalized_action[0], input_action, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
|
||||
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
"""Test select_action (inference mode)."""
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.reset() # Reset queues before inference
|
||||
|
||||
# Create processors - use IDENTITY normalization when no stats provided
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
selected_action = policy.select_action(processed_obs)
|
||||
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||
processed_action = postprocessor(selected_action)
|
||||
assert processed_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_diffusion_objective():
|
||||
"""Test policy with diffusion objective."""
|
||||
batch_size = 2
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
# Use diffusion objective
|
||||
objective="diffusion",
|
||||
noise_scheduler_type="DDPM",
|
||||
num_train_timesteps=100,
|
||||
num_inference_steps=10,
|
||||
# Smaller model for tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
# Test inference
|
||||
policy.eval()
|
||||
# Use IDENTITY normalization when no stats provided
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
selected_action = policy.select_action(processed_obs)
|
||||
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||
processed_action = postprocessor(selected_action)
|
||||
assert processed_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_flow_matching_objective():
|
||||
"""Test policy with flow matching objective."""
|
||||
batch_size = 2
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
# Use flow matching objective
|
||||
objective="flow_matching",
|
||||
sigma_min=0.0,
|
||||
num_integration_steps=10, # Fewer steps for faster tests
|
||||
integration_method="euler",
|
||||
# Smaller model for tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
# Test inference
|
||||
policy.eval()
|
||||
# Use IDENTITY normalization when no stats provided
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
selected_action = policy.select_action(processed_obs)
|
||||
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||
processed_action = postprocessor(selected_action)
|
||||
assert processed_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded correctly."""
|
||||
root = tmp_path / "test_multi_task_dit_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
batch_size = 2
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
# Get device before saving
|
||||
device = next(policy.parameters()).device
|
||||
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
|
||||
|
||||
# Explicitly move loaded_policy to the same device
|
||||
loaded_policy.to(device)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Move batch to the same device as the policy
|
||||
for key in processed_batch:
|
||||
if isinstance(processed_batch[key], torch.Tensor):
|
||||
processed_batch[key] = processed_batch[key].to(device)
|
||||
# Collect policy values before saving
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
actions = policy.select_action(processed_obs)
|
||||
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Collect policy values after loading
|
||||
loaded_loss, _ = loaded_policy.forward(processed_batch)
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
processed_obs = preprocessor(loaded_observation_batch)
|
||||
loaded_actions = loaded_policy.select_action(processed_obs)
|
||||
|
||||
# Compare state dicts
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
assert torch.allclose(loss, loaded_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_get_optim_params():
|
||||
"""Test that the policy returns correct optimizer parameter groups."""
|
||||
config = create_config(
|
||||
state_dim=10,
|
||||
action_dim=10,
|
||||
n_obs_steps=2,
|
||||
horizon=16,
|
||||
n_action_steps=8,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
param_groups = policy.get_optim_params()
|
||||
|
||||
# Should have 2 parameter groups: non-vision and vision encoder
|
||||
assert len(param_groups) == 2
|
||||
|
||||
# First group is non-vision params (no lr specified, will use default)
|
||||
assert "params" in param_groups[0]
|
||||
assert len(param_groups[0]["params"]) > 0
|
||||
|
||||
# Second group is vision encoder params with different lr
|
||||
assert "params" in param_groups[1]
|
||||
assert "lr" in param_groups[1]
|
||||
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
|
||||
assert param_groups[1]["lr"] == expected_lr
|
||||
Reference in New Issue
Block a user