mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit
This commit is contained in:
@@ -30,7 +30,13 @@ from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiT
|
||||
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
||||
make_multi_task_dit_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
|
||||
@@ -132,6 +138,14 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
@@ -140,8 +154,11 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
assert loss.shape == ()
|
||||
@@ -204,7 +221,11 @@ def test_multi_task_dit_pre_post_processors():
|
||||
assert processed_batch["observation.state"].shape == (1, state_dim)
|
||||
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
|
||||
assert processed_batch[ACTION].shape == (1, action_dim)
|
||||
assert "task" in processed_batch
|
||||
# Check that task text was tokenized
|
||||
assert OBS_LANGUAGE_TOKENS in processed_batch
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
|
||||
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
|
||||
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
|
||||
|
||||
# Check that data is on correct device
|
||||
assert processed_batch["observation.state"].device.type == "cpu"
|
||||
@@ -360,6 +381,14 @@ def test_multi_task_dit_policy_diffusion_objective():
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
@@ -368,8 +397,11 @@ def test_multi_task_dit_policy_diffusion_objective():
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
@@ -427,6 +459,14 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
@@ -435,8 +475,11 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
@@ -499,33 +542,39 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Move batch to the same device as the policy
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device)
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Move batch to the same device as the policy
|
||||
for key in processed_batch:
|
||||
if isinstance(processed_batch[key], torch.Tensor):
|
||||
processed_batch[key] = processed_batch[key].to(device)
|
||||
# Collect policy values before saving
|
||||
loss, _ = policy.forward(batch)
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Move observation batch to device
|
||||
for key in observation_batch:
|
||||
if isinstance(observation_batch[key], torch.Tensor):
|
||||
observation_batch[key] = observation_batch[key].to(device)
|
||||
actions = policy.select_action(observation_batch)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
actions = policy.select_action(processed_obs)
|
||||
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Collect policy values after loading
|
||||
loaded_loss, _ = loaded_policy.forward(batch)
|
||||
loaded_loss, _ = loaded_policy.forward(processed_batch)
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Move observation batch to device
|
||||
for key in loaded_observation_batch:
|
||||
if isinstance(loaded_observation_batch[key], torch.Tensor):
|
||||
loaded_observation_batch[key] = loaded_observation_batch[key].to(device)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
processed_obs = preprocessor(loaded_observation_batch)
|
||||
loaded_actions = loaded_policy.select_action(processed_obs)
|
||||
|
||||
# Compare state dicts
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
|
||||
Reference in New Issue
Block a user