refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit

This commit is contained in:
Bryson Jones
2025-12-11 12:09:54 -08:00
parent 51dfee43f4
commit 71f359ca6e
4 changed files with 125 additions and 55 deletions
@@ -79,6 +79,10 @@ class MultiTaskDiTConfig(PreTrainedConfig):
# Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
@@ -36,12 +36,18 @@ import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from transformers import CLIPTextModel, CLIPVisionModel
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
# -- Policy --
@@ -127,36 +133,34 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations"""
self.eval()
original_batch_keys = set(batch.keys())
new_batch = {}
for k in self._queues:
if k in original_batch_keys:
if self._queues[k] and isinstance(self._queues[k][-1][0], str):
new_batch[k] = self._queues[k][-1]
else:
queue_values = list(self._queues[k])
new_batch[k] = torch.stack(queue_values, dim=1)
batch = new_batch
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
actions = self._generate_actions(batch)
return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Prepare batch by stacking image features if needed."""
if self.config.image_features:
batch = dict(batch) # shallow copy to avoid modifying original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
return batch
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations"""
if ACTION in batch:
batch = dict(batch) # shallow copy to avoid modifying original
batch.pop(ACTION)
if self.config.image_features:
batch = dict(batch)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch)
@@ -169,9 +173,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training"""
if self.config.image_features:
batch = dict(batch)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self._prepare_batch(batch)
conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
@@ -204,13 +206,16 @@ class CLIPVisionEncoder(nn.Module):
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer."""
"""CLIP text encoder with frozen weights and a learnable projection layer.
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
pipeline to see how the tokenization is handled.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
for param in self.text_encoder.parameters():
@@ -219,16 +224,15 @@ class CLIPTextEncoder(nn.Module):
self.text_embed_dim = self.text_encoder.config.hidden_size
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, text: str | list[str]) -> Tensor:
"""Encode text to feature vectors."""
if isinstance(text, str):
text = [text]
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Encode pre-tokenized text to feature vectors."""
# Ensure inputs are on the same device as the model
device = next(self.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
with torch.no_grad():
outputs = self.text_encoder(**text_inputs)
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
clip_features = outputs.pooler_output
return self.projection(clip_features)
@@ -349,8 +353,12 @@ class ObservationEncoder(nn.Module):
)
conditioning_feats.append(img_features)
if self.text_encoder is not None and "task" in batch:
text_features = self.text_encoder(batch["task"])
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
text_features = self.text_encoder(input_ids, attention_mask)
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
conditioning_feats.append(text_features)
@@ -26,6 +26,7 @@ from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
@@ -45,8 +46,9 @@ def make_multi_task_dit_pre_post_processors(
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features.
2. Adding a batch dimension.
3. Moving the data to the specified device.
4. Normalizing the input and output features based on dataset statistics.
3. Tokenizing the language task description (if present).
4. Moving the data to the specified device.
5. Normalizing the input and output features based on dataset statistics.
The post-processing pipeline handles the model's output by:
1. Unnormalizing the output features to their original scale.
@@ -65,6 +67,13 @@ def make_multi_task_dit_pre_post_processors(
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.text_encoder_name,
padding=config.tokenizer_padding,
padding_side=config.tokenizer_padding_side,
max_length=config.tokenizer_max_length,
truncation=config.tokenizer_truncation,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
@@ -30,7 +30,13 @@ from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiT
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
from lerobot.utils.random_utils import seeded_context, set_seed
@@ -132,6 +138,14 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
@@ -140,8 +154,11 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(batch)
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
assert loss.shape == ()
@@ -204,7 +221,11 @@ def test_multi_task_dit_pre_post_processors():
assert processed_batch["observation.state"].shape == (1, state_dim)
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
assert processed_batch[ACTION].shape == (1, action_dim)
assert "task" in processed_batch
# Check that task text was tokenized
assert OBS_LANGUAGE_TOKENS in processed_batch
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
# Check that data is on correct device
assert processed_batch["observation.state"].device.type == "cpu"
@@ -360,6 +381,14 @@ def test_multi_task_dit_policy_diffusion_objective():
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
@@ -368,8 +397,11 @@ def test_multi_task_dit_policy_diffusion_objective():
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(batch)
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
@@ -427,6 +459,14 @@ def test_multi_task_dit_policy_flow_matching_objective():
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
@@ -435,8 +475,11 @@ def test_multi_task_dit_policy_flow_matching_objective():
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(batch)
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
@@ -499,33 +542,39 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
action_dim=action_dim,
)
# Move batch to the same device as the policy
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device)
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Move batch to the same device as the policy
for key in processed_batch:
if isinstance(processed_batch[key], torch.Tensor):
processed_batch[key] = processed_batch[key].to(device)
# Collect policy values before saving
loss, _ = policy.forward(batch)
loss, _ = policy.forward(processed_batch)
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Move observation batch to device
for key in observation_batch:
if isinstance(observation_batch[key], torch.Tensor):
observation_batch[key] = observation_batch[key].to(device)
actions = policy.select_action(observation_batch)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
actions = policy.select_action(processed_obs)
with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Collect policy values after loading
loaded_loss, _ = loaded_policy.forward(batch)
loaded_loss, _ = loaded_policy.forward(processed_batch)
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Move observation batch to device
for key in loaded_observation_batch:
if isinstance(loaded_observation_batch[key], torch.Tensor):
loaded_observation_batch[key] = loaded_observation_batch[key].to(device)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
processed_obs = preprocessor(loaded_observation_batch)
loaded_actions = loaded_policy.select_action(processed_obs)
# Compare state dicts
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()