diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py index 8fd31b372..8286ab8e6 100644 --- a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py @@ -79,6 +79,10 @@ class MultiTaskDiTConfig(PreTrainedConfig): # Text Encoder (CLIP) text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model + tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77) + tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest" + tokenizer_padding_side: str = "right" # Padding side: "left" or "right" + tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length # Normalization normalization_mapping: dict[str, NormalizationMode] = field( diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index fe9a535fe..e646fd6fb 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -36,12 +36,18 @@ import torchvision from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from torch import Tensor -from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel +from transformers import CLIPTextModel, CLIPVisionModel from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import populate_queues -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGES, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) # -- Policy -- @@ -127,36 +133,34 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): if self.config.image_features: self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) - self._queues["task"] = deque(maxlen=self.config.n_obs_steps) - @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations""" self.eval() - original_batch_keys = set(batch.keys()) - new_batch = {} - for k in self._queues: - if k in original_batch_keys: - if self._queues[k] and isinstance(self._queues[k][-1][0], str): - new_batch[k] = self._queues[k][-1] - else: - queue_values = list(self._queues[k]) - new_batch[k] = torch.stack(queue_values, dim=1) - batch = new_batch + for k in batch: + if k in self._queues: + batch[k] = torch.stack(list(self._queues[k]), dim=1) actions = self._generate_actions(batch) return actions + def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Prepare batch by stacking image features if needed.""" + if self.config.image_features: + batch = dict(batch) # shallow copy to avoid modifying original + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + + return batch + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations""" if ACTION in batch: + batch = dict(batch) # shallow copy to avoid modifying original batch.pop(ACTION) - if self.config.image_features: - batch = dict(batch) - batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch = self._prepare_batch(batch) self._queues = populate_queues(self._queues, batch) @@ -169,9 +173,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: """Run the batch through the model and compute the loss for training""" - if self.config.image_features: - batch = dict(batch) - batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch = self._prepare_batch(batch) conditioning_vec = self.observation_encoder.encode(batch) loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec) @@ -204,13 +206,16 @@ class CLIPVisionEncoder(nn.Module): class CLIPTextEncoder(nn.Module): - """CLIP text encoder with frozen weights and a learnable projection layer.""" + """CLIP text encoder with frozen weights and a learnable projection layer. + + Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor + pipeline to see how the tokenization is handled. + """ def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512): super().__init__() self.model_name = model_name self.projection_dim = projection_dim - self.tokenizer = CLIPTokenizer.from_pretrained(model_name) self.text_encoder = CLIPTextModel.from_pretrained(model_name) for param in self.text_encoder.parameters(): @@ -219,16 +224,15 @@ class CLIPTextEncoder(nn.Module): self.text_embed_dim = self.text_encoder.config.hidden_size self.projection = nn.Linear(self.text_embed_dim, projection_dim) - def forward(self, text: str | list[str]) -> Tensor: - """Encode text to feature vectors.""" - if isinstance(text, str): - text = [text] - - text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt") - text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()} + def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: + """Encode pre-tokenized text to feature vectors.""" + # Ensure inputs are on the same device as the model + device = next(self.parameters()).device + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) with torch.no_grad(): - outputs = self.text_encoder(**text_inputs) + outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask) clip_features = outputs.pooler_output return self.projection(clip_features) @@ -349,8 +353,12 @@ class ObservationEncoder(nn.Module): ) conditioning_feats.append(img_features) - if self.text_encoder is not None and "task" in batch: - text_features = self.text_encoder(batch["task"]) + if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch: + input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length] + attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length] + + text_features = self.text_encoder(input_ids, attention_mask) + text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) conditioning_feats.append(text_features) diff --git a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py index b634f5438..fc94599c2 100644 --- a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py @@ -26,6 +26,7 @@ from lerobot.processor import ( PolicyAction, PolicyProcessorPipeline, RenameObservationsProcessorStep, + TokenizerProcessorStep, UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action @@ -45,8 +46,9 @@ def make_multi_task_dit_pre_post_processors( The pre-processing pipeline prepares the input data for the model by: 1. Renaming features. 2. Adding a batch dimension. - 3. Moving the data to the specified device. - 4. Normalizing the input and output features based on dataset statistics. + 3. Tokenizing the language task description (if present). + 4. Moving the data to the specified device. + 5. Normalizing the input and output features based on dataset statistics. The post-processing pipeline handles the model's output by: 1. Unnormalizing the output features to their original scale. @@ -65,6 +67,13 @@ def make_multi_task_dit_pre_post_processors( input_steps = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), + TokenizerProcessorStep( + tokenizer_name=config.text_encoder_name, + padding=config.tokenizer_padding, + padding_side=config.tokenizer_padding_side, + max_length=config.tokenizer_max_length, + truncation=config.tokenizer_truncation, + ), DeviceProcessorStep(device=config.device), NormalizerProcessorStep( features={**config.input_features, **config.output_features}, diff --git a/tests/policies/multi_task_dit/test_multi_task_dit.py b/tests/policies/multi_task_dit/test_multi_task_dit.py index c12dcb2fb..a371d0bc8 100644 --- a/tests/policies/multi_task_dit/test_multi_task_dit.py +++ b/tests/policies/multi_task_dit/test_multi_task_dit.py @@ -30,7 +30,13 @@ from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiT from lerobot.policies.multi_task_dit.processor_multi_task_dit import ( make_multi_task_dit_pre_post_processors, ) -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGES, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) from lerobot.utils.random_utils import seeded_context, set_seed @@ -132,6 +138,14 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d policy = MultiTaskDiTPolicy(config=config) policy.train() + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + batch = create_train_batch( batch_size=batch_size, n_obs_steps=n_obs_steps, @@ -140,8 +154,11 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d action_dim=action_dim, ) + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + # Test forward pass - loss, _ = policy.forward(batch) + loss, _ = policy.forward(processed_batch) assert loss is not None assert loss.item() is not None assert loss.shape == () @@ -204,7 +221,11 @@ def test_multi_task_dit_pre_post_processors(): assert processed_batch["observation.state"].shape == (1, state_dim) assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224) assert processed_batch[ACTION].shape == (1, action_dim) - assert "task" in processed_batch + # Check that task text was tokenized + assert OBS_LANGUAGE_TOKENS in processed_batch + assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch + assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension + assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension # Check that data is on correct device assert processed_batch["observation.state"].device.type == "cpu" @@ -360,6 +381,14 @@ def test_multi_task_dit_policy_diffusion_objective(): policy = MultiTaskDiTPolicy(config=config) policy.train() + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + batch = create_train_batch( batch_size=batch_size, n_obs_steps=n_obs_steps, @@ -368,8 +397,11 @@ def test_multi_task_dit_policy_diffusion_objective(): action_dim=action_dim, ) + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + # Test forward pass - loss, _ = policy.forward(batch) + loss, _ = policy.forward(processed_batch) assert loss is not None assert loss.item() is not None @@ -427,6 +459,14 @@ def test_multi_task_dit_policy_flow_matching_objective(): policy = MultiTaskDiTPolicy(config=config) policy.train() + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + batch = create_train_batch( batch_size=batch_size, n_obs_steps=n_obs_steps, @@ -435,8 +475,11 @@ def test_multi_task_dit_policy_flow_matching_objective(): action_dim=action_dim, ) + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + # Test forward pass - loss, _ = policy.forward(batch) + loss, _ = policy.forward(processed_batch) assert loss is not None assert loss.item() is not None @@ -499,33 +542,39 @@ def test_multi_task_dit_policy_save_and_load(tmp_path): action_dim=action_dim, ) - # Move batch to the same device as the policy - for key in batch: - if isinstance(batch[key], torch.Tensor): - batch[key] = batch[key].to(device) + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) with torch.no_grad(): with seeded_context(12): + # Process batch through preprocessor + processed_batch = preprocessor(batch) + # Move batch to the same device as the policy + for key in processed_batch: + if isinstance(processed_batch[key], torch.Tensor): + processed_batch[key] = processed_batch[key].to(device) # Collect policy values before saving - loss, _ = policy.forward(batch) + loss, _ = policy.forward(processed_batch) observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - # Move observation batch to device - for key in observation_batch: - if isinstance(observation_batch[key], torch.Tensor): - observation_batch[key] = observation_batch[key].to(device) - actions = policy.select_action(observation_batch) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + actions = policy.select_action(processed_obs) with seeded_context(12): + # Process batch through preprocessor + processed_batch = preprocessor(batch) # Collect policy values after loading - loaded_loss, _ = loaded_policy.forward(batch) + loaded_loss, _ = loaded_policy.forward(processed_batch) loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - # Move observation batch to device - for key in loaded_observation_batch: - if isinstance(loaded_observation_batch[key], torch.Tensor): - loaded_observation_batch[key] = loaded_observation_batch[key].to(device) - loaded_actions = loaded_policy.select_action(loaded_observation_batch) + processed_obs = preprocessor(loaded_observation_batch) + loaded_actions = loaded_policy.select_action(processed_obs) # Compare state dicts assert policy.state_dict().keys() == loaded_policy.state_dict().keys()