mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
Refactor code to perform task tokenization in the processor instead of in the modeling code for Multi-Task DiT
This commit is contained in:
@@ -79,6 +79,10 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
|
||||
# Text Encoder (CLIP)
|
||||
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
|
||||
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
|
||||
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
|
||||
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
|
||||
@@ -36,12 +36,18 @@ import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
||||
from transformers import CLIPTextModel, CLIPVisionModel
|
||||
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
|
||||
# -- Policy --
|
||||
|
||||
@@ -127,36 +133,34 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
if self.config.image_features:
|
||||
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations"""
|
||||
self.eval()
|
||||
|
||||
original_batch_keys = set(batch.keys())
|
||||
new_batch = {}
|
||||
for k in self._queues:
|
||||
if k in original_batch_keys:
|
||||
if self._queues[k] and isinstance(self._queues[k][-1][0], str):
|
||||
new_batch[k] = self._queues[k][-1]
|
||||
else:
|
||||
queue_values = list(self._queues[k])
|
||||
new_batch[k] = torch.stack(queue_values, dim=1)
|
||||
batch = new_batch
|
||||
for k in batch:
|
||||
if k in self._queues:
|
||||
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||
|
||||
actions = self._generate_actions(batch)
|
||||
return actions
|
||||
|
||||
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Prepare batch by stacking image features if needed."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy to avoid modifying original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations"""
|
||||
if ACTION in batch:
|
||||
batch = dict(batch) # shallow copy to avoid modifying original
|
||||
batch.pop(ACTION)
|
||||
|
||||
if self.config.image_features:
|
||||
batch = dict(batch)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self._prepare_batch(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -169,9 +173,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""Run the batch through the model and compute the loss for training"""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self._prepare_batch(batch)
|
||||
|
||||
conditioning_vec = self.observation_encoder.encode(batch)
|
||||
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
|
||||
@@ -204,13 +206,16 @@ class CLIPVisionEncoder(nn.Module):
|
||||
|
||||
|
||||
class CLIPTextEncoder(nn.Module):
|
||||
"""CLIP text encoder with frozen weights and a learnable projection layer."""
|
||||
"""CLIP text encoder with frozen weights and a learnable projection layer.
|
||||
|
||||
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
|
||||
pipeline for details on how the tokenization is handled.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.projection_dim = projection_dim
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
|
||||
|
||||
for param in self.text_encoder.parameters():
|
||||
@@ -219,16 +224,15 @@ class CLIPTextEncoder(nn.Module):
|
||||
self.text_embed_dim = self.text_encoder.config.hidden_size
|
||||
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
|
||||
|
||||
def forward(self, text: str | list[str]) -> Tensor:
|
||||
"""Encode text to feature vectors."""
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
||||
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
|
||||
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
"""Encode pre-tokenized text to feature vectors."""
|
||||
# Ensure inputs are on the same device as the model
|
||||
device = next(self.parameters()).device
|
||||
input_ids = input_ids.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.text_encoder(**text_inputs)
|
||||
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
|
||||
clip_features = outputs.pooler_output
|
||||
|
||||
return self.projection(clip_features)
|
||||
@@ -349,8 +353,12 @@ class ObservationEncoder(nn.Module):
|
||||
)
|
||||
conditioning_feats.append(img_features)
|
||||
|
||||
if self.text_encoder is not None and "task" in batch:
|
||||
text_features = self.text_encoder(batch["task"])
|
||||
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
|
||||
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
|
||||
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
|
||||
|
||||
text_features = self.text_encoder(input_ids, attention_mask)
|
||||
|
||||
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
|
||||
conditioning_feats.append(text_features)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.processor import (
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
@@ -45,8 +46,9 @@ def make_multi_task_dit_pre_post_processors(
|
||||
The pre-processing pipeline prepares the input data for the model by:
|
||||
1. Renaming features.
|
||||
2. Adding a batch dimension.
|
||||
3. Moving the data to the specified device.
|
||||
4. Normalizing the input and output features based on dataset statistics.
|
||||
3. Tokenizing the language task description (if present).
|
||||
4. Moving the data to the specified device.
|
||||
5. Normalizing the input and output features based on dataset statistics.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Unnormalizing the output features to their original scale.
|
||||
@@ -65,6 +67,13 @@ def make_multi_task_dit_pre_post_processors(
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.text_encoder_name,
|
||||
padding=config.tokenizer_padding,
|
||||
padding_side=config.tokenizer_padding_side,
|
||||
max_length=config.tokenizer_max_length,
|
||||
truncation=config.tokenizer_truncation,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
|
||||
@@ -30,7 +30,13 @@ from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiT
|
||||
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
||||
make_multi_task_dit_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
|
||||
@@ -132,6 +138,14 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
@@ -140,8 +154,11 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
assert loss.shape == ()
|
||||
@@ -204,7 +221,11 @@ def test_multi_task_dit_pre_post_processors():
|
||||
assert processed_batch["observation.state"].shape == (1, state_dim)
|
||||
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
|
||||
assert processed_batch[ACTION].shape == (1, action_dim)
|
||||
assert "task" in processed_batch
|
||||
# Check that task text was tokenized
|
||||
assert OBS_LANGUAGE_TOKENS in processed_batch
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
|
||||
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
|
||||
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
|
||||
|
||||
# Check that data is on correct device
|
||||
assert processed_batch["observation.state"].device.type == "cpu"
|
||||
@@ -360,6 +381,14 @@ def test_multi_task_dit_policy_diffusion_objective():
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
@@ -368,8 +397,11 @@ def test_multi_task_dit_policy_diffusion_objective():
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
@@ -427,6 +459,14 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
@@ -435,8 +475,11 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
@@ -499,33 +542,39 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Move batch to the same device as the policy
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device)
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Move batch to the same device as the policy
|
||||
for key in processed_batch:
|
||||
if isinstance(processed_batch[key], torch.Tensor):
|
||||
processed_batch[key] = processed_batch[key].to(device)
|
||||
# Collect policy values before saving
|
||||
loss, _ = policy.forward(batch)
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Move observation batch to device
|
||||
for key in observation_batch:
|
||||
if isinstance(observation_batch[key], torch.Tensor):
|
||||
observation_batch[key] = observation_batch[key].to(device)
|
||||
actions = policy.select_action(observation_batch)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
actions = policy.select_action(processed_obs)
|
||||
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Collect policy values after loading
|
||||
loaded_loss, _ = loaded_policy.forward(batch)
|
||||
loaded_loss, _ = loaded_policy.forward(processed_batch)
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Move observation batch to device
|
||||
for key in loaded_observation_batch:
|
||||
if isinstance(loaded_observation_batch[key], torch.Tensor):
|
||||
loaded_observation_batch[key] = loaded_observation_batch[key].to(device)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
processed_obs = preprocessor(loaded_observation_batch)
|
||||
loaded_actions = loaded_policy.select_action(processed_obs)
|
||||
|
||||
# Compare state dicts
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
|
||||
Reference in New Issue
Block a user