mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives
This commit is contained in:
@@ -28,11 +28,7 @@ import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
|
||||
DiffusionConfig,
|
||||
FlowMatchingConfig,
|
||||
MultiTaskDiTConfig,
|
||||
)
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
@@ -108,13 +104,12 @@ def create_config(
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
# Use smaller model for faster tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
|
||||
# Use smaller model for faster tests
|
||||
config.transformer.hidden_dim = 128
|
||||
config.transformer.num_layers = 2
|
||||
config.transformer.num_heads = 4
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
@@ -189,18 +184,28 @@ def test_multi_task_dit_policy_diffusion_objective():
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
config.objective = DiffusionConfig(
|
||||
# Use diffusion objective
|
||||
objective="diffusion",
|
||||
noise_scheduler_type="DDPM",
|
||||
num_train_timesteps=100,
|
||||
num_inference_steps=10,
|
||||
# Smaller model for tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
@@ -235,18 +240,28 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
config.objective = FlowMatchingConfig(
|
||||
# Use flow matching objective
|
||||
objective="flow_matching",
|
||||
sigma_min=0.0,
|
||||
num_integration_steps=10, # Use fewer steps for faster tests
|
||||
num_integration_steps=10, # Fewer steps for faster tests
|
||||
integration_method="euler",
|
||||
# Smaller model for tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
@@ -373,5 +388,5 @@ def test_multi_task_dit_policy_get_optim_params():
|
||||
# Second group is vision encoder params with different lr
|
||||
assert "params" in param_groups[1]
|
||||
assert "lr" in param_groups[1]
|
||||
expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier
|
||||
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
|
||||
assert param_groups[1]["lr"] == expected_lr
|
||||
|
||||
Reference in New Issue
Block a user