Refactor multi-task DiT to perform task tokenization in the processor instead of in the modeling code

This commit is contained in:
Bryson Jones
2025-12-11 12:09:54 -08:00
parent 51dfee43f4
commit 71f359ca6e
4 changed files with 125 additions and 55 deletions
@@ -79,6 +79,10 @@ class MultiTaskDiTConfig(PreTrainedConfig):
# Text Encoder (CLIP) # Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
# Normalization # Normalization
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
@@ -36,12 +36,18 @@ import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel from transformers import CLIPTextModel, CLIPVisionModel
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
# -- Policy -- # -- Policy --
@@ -127,36 +133,34 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
if self.config.image_features: if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad() @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations""" """Predict a chunk of actions given environment observations"""
self.eval() self.eval()
original_batch_keys = set(batch.keys()) for k in batch:
new_batch = {} if k in self._queues:
for k in self._queues: batch[k] = torch.stack(list(self._queues[k]), dim=1)
if k in original_batch_keys:
if self._queues[k] and isinstance(self._queues[k][-1][0], str):
new_batch[k] = self._queues[k][-1]
else:
queue_values = list(self._queues[k])
new_batch[k] = torch.stack(queue_values, dim=1)
batch = new_batch
actions = self._generate_actions(batch) actions = self._generate_actions(batch)
return actions return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Prepare batch by stacking image features if needed."""
if self.config.image_features:
batch = dict(batch) # shallow copy to avoid modifying original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
return batch
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations""" """Select a single action given environment observations"""
if ACTION in batch: if ACTION in batch:
batch = dict(batch) # shallow copy to avoid modifying original
batch.pop(ACTION) batch.pop(ACTION)
if self.config.image_features: batch = self._prepare_batch(batch)
batch = dict(batch)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
self._queues = populate_queues(self._queues, batch) self._queues = populate_queues(self._queues, batch)
@@ -169,9 +173,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training""" """Run the batch through the model and compute the loss for training"""
if self.config.image_features: batch = self._prepare_batch(batch)
batch = dict(batch)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
conditioning_vec = self.observation_encoder.encode(batch) conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec) loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
@@ -204,13 +206,16 @@ class CLIPVisionEncoder(nn.Module):
class CLIPTextEncoder(nn.Module): class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer.""" """CLIP text encoder with frozen weights and a learnable projection layer.
Accepts pre-tokenized inputs (input_ids and attention_mask) produced by the processor pipeline.
Tokenization itself is handled by the TokenizerProcessorStep in the pre-processing pipeline.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512): def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__() super().__init__()
self.model_name = model_name self.model_name = model_name
self.projection_dim = projection_dim self.projection_dim = projection_dim
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
self.text_encoder = CLIPTextModel.from_pretrained(model_name) self.text_encoder = CLIPTextModel.from_pretrained(model_name)
for param in self.text_encoder.parameters(): for param in self.text_encoder.parameters():
@@ -219,16 +224,15 @@ class CLIPTextEncoder(nn.Module):
self.text_embed_dim = self.text_encoder.config.hidden_size self.text_embed_dim = self.text_encoder.config.hidden_size
self.projection = nn.Linear(self.text_embed_dim, projection_dim) self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, text: str | list[str]) -> Tensor: def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Encode text to feature vectors.""" """Encode pre-tokenized text to feature vectors."""
if isinstance(text, str): # Ensure inputs are on the same device as the model
text = [text] device = next(self.parameters()).device
input_ids = input_ids.to(device)
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt") attention_mask = attention_mask.to(device)
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
with torch.no_grad(): with torch.no_grad():
outputs = self.text_encoder(**text_inputs) outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
clip_features = outputs.pooler_output clip_features = outputs.pooler_output
return self.projection(clip_features) return self.projection(clip_features)
@@ -349,8 +353,12 @@ class ObservationEncoder(nn.Module):
) )
conditioning_feats.append(img_features) conditioning_feats.append(img_features)
if self.text_encoder is not None and "task" in batch: if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
text_features = self.text_encoder(batch["task"]) input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
text_features = self.text_encoder(input_ids, attention_mask)
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
conditioning_feats.append(text_features) conditioning_feats.append(text_features)
@@ -26,6 +26,7 @@ from lerobot.processor import (
PolicyAction, PolicyAction,
PolicyProcessorPipeline, PolicyProcessorPipeline,
RenameObservationsProcessorStep, RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep, UnnormalizerProcessorStep,
) )
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
@@ -45,8 +46,9 @@ def make_multi_task_dit_pre_post_processors(
The pre-processing pipeline prepares the input data for the model by: The pre-processing pipeline prepares the input data for the model by:
1. Renaming features. 1. Renaming features.
2. Adding a batch dimension. 2. Adding a batch dimension.
3. Moving the data to the specified device. 3. Tokenizing the language task description (if present).
4. Normalizing the input and output features based on dataset statistics. 4. Moving the data to the specified device.
5. Normalizing the input and output features based on dataset statistics.
The post-processing pipeline handles the model's output by: The post-processing pipeline handles the model's output by:
1. Unnormalizing the output features to their original scale. 1. Unnormalizing the output features to their original scale.
@@ -65,6 +67,13 @@ def make_multi_task_dit_pre_post_processors(
input_steps = [ input_steps = [
RenameObservationsProcessorStep(rename_map={}), RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.text_encoder_name,
padding=config.tokenizer_padding,
padding_side=config.tokenizer_padding_side,
max_length=config.tokenizer_max_length,
truncation=config.tokenizer_truncation,
),
DeviceProcessorStep(device=config.device), DeviceProcessorStep(device=config.device),
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
@@ -30,7 +30,13 @@ from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiT
from lerobot.policies.multi_task_dit.processor_multi_task_dit import ( from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors, make_multi_task_dit_pre_post_processors,
) )
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
from lerobot.utils.random_utils import seeded_context, set_seed from lerobot.utils.random_utils import seeded_context, set_seed
@@ -132,6 +138,14 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.train() policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch( batch = create_train_batch(
batch_size=batch_size, batch_size=batch_size,
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
@@ -140,8 +154,11 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
action_dim=action_dim, action_dim=action_dim,
) )
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass # Test forward pass
loss, _ = policy.forward(batch) loss, _ = policy.forward(processed_batch)
assert loss is not None assert loss is not None
assert loss.item() is not None assert loss.item() is not None
assert loss.shape == () assert loss.shape == ()
@@ -204,7 +221,11 @@ def test_multi_task_dit_pre_post_processors():
assert processed_batch["observation.state"].shape == (1, state_dim) assert processed_batch["observation.state"].shape == (1, state_dim)
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224) assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
assert processed_batch[ACTION].shape == (1, action_dim) assert processed_batch[ACTION].shape == (1, action_dim)
assert "task" in processed_batch # Check that task text was tokenized
assert OBS_LANGUAGE_TOKENS in processed_batch
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
# Check that data is on correct device # Check that data is on correct device
assert processed_batch["observation.state"].device.type == "cpu" assert processed_batch["observation.state"].device.type == "cpu"
@@ -360,6 +381,14 @@ def test_multi_task_dit_policy_diffusion_objective():
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.train() policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch( batch = create_train_batch(
batch_size=batch_size, batch_size=batch_size,
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
@@ -368,8 +397,11 @@ def test_multi_task_dit_policy_diffusion_objective():
action_dim=action_dim, action_dim=action_dim,
) )
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass # Test forward pass
loss, _ = policy.forward(batch) loss, _ = policy.forward(processed_batch)
assert loss is not None assert loss is not None
assert loss.item() is not None assert loss.item() is not None
@@ -427,6 +459,14 @@ def test_multi_task_dit_policy_flow_matching_objective():
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.train() policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch( batch = create_train_batch(
batch_size=batch_size, batch_size=batch_size,
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
@@ -435,8 +475,11 @@ def test_multi_task_dit_policy_flow_matching_objective():
action_dim=action_dim, action_dim=action_dim,
) )
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass # Test forward pass
loss, _ = policy.forward(batch) loss, _ = policy.forward(processed_batch)
assert loss is not None assert loss is not None
assert loss.item() is not None assert loss.item() is not None
@@ -499,33 +542,39 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
action_dim=action_dim, action_dim=action_dim,
) )
# Move batch to the same device as the policy # Use preprocessor to handle tokenization
for key in batch: config.normalization_mapping = {
if isinstance(batch[key], torch.Tensor): "VISUAL": NormalizationMode.IDENTITY,
batch[key] = batch[key].to(device) "STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad(): with torch.no_grad():
with seeded_context(12): with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Move batch to the same device as the policy
for key in processed_batch:
if isinstance(processed_batch[key], torch.Tensor):
processed_batch[key] = processed_batch[key].to(device)
# Collect policy values before saving # Collect policy values before saving
loss, _ = policy.forward(batch) loss, _ = policy.forward(processed_batch)
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Move observation batch to device # Process observation through preprocessor
for key in observation_batch: processed_obs = preprocessor(observation_batch)
if isinstance(observation_batch[key], torch.Tensor): actions = policy.select_action(processed_obs)
observation_batch[key] = observation_batch[key].to(device)
actions = policy.select_action(observation_batch)
with seeded_context(12): with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Collect policy values after loading # Collect policy values after loading
loaded_loss, _ = loaded_policy.forward(batch) loaded_loss, _ = loaded_policy.forward(processed_batch)
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Move observation batch to device processed_obs = preprocessor(loaded_observation_batch)
for key in loaded_observation_batch: loaded_actions = loaded_policy.select_action(processed_obs)
if isinstance(loaded_observation_batch[key], torch.Tensor):
loaded_observation_batch[key] = loaded_observation_batch[key].to(device)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
# Compare state dicts # Compare state dicts
assert policy.state_dict().keys() == loaded_policy.state_dict().keys() assert policy.state_dict().keys() == loaded_policy.state_dict().keys()