Simplify the config for the multi-task DiT policy by merging and flattening everything, then add comments to denote which parameters are only used for specific objectives

This commit is contained in:
Bryson Jones
2025-12-10 11:45:59 -08:00
parent cdacc090cd
commit 103230c64c
7 changed files with 242 additions and 454 deletions
@@ -28,11 +28,7 @@ import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
DiffusionConfig,
FlowMatchingConfig,
MultiTaskDiTConfig,
)
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
@@ -108,13 +104,12 @@ def create_config(
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
# Use smaller model for faster tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
# Use smaller model for faster tests
config.transformer.hidden_dim = 128
config.transformer.num_layers = 2
config.transformer.num_heads = 4
config.validate_features()
return config
@@ -189,18 +184,28 @@ def test_multi_task_dit_policy_diffusion_objective():
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
config.objective = DiffusionConfig(
# Use diffusion objective
objective="diffusion",
noise_scheduler_type="DDPM",
num_train_timesteps=100,
num_inference_steps=10,
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.train()
@@ -235,18 +240,28 @@ def test_multi_task_dit_policy_flow_matching_objective():
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
config.objective = FlowMatchingConfig(
# Use flow matching objective
objective="flow_matching",
sigma_min=0.0,
num_integration_steps=10, # Use fewer steps for faster tests
num_integration_steps=10, # Fewer steps for faster tests
integration_method="euler",
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.train()
@@ -373,5 +388,5 @@ def test_multi_task_dit_policy_get_optim_params():
# Second group is vision encoder params with different lr
assert "params" in param_groups[1]
assert "lr" in param_groups[1]
expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
assert param_groups[1]["lr"] == expected_lr