diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py
index 2edd625af..60ea6be87 100644
--- a/src/lerobot/policies/pi05/configuration_pi05.py
+++ b/src/lerobot/policies/pi05/configuration_pi05.py
@@ -60,8 +60,8 @@ class PI05Config(PreTrainedConfig):
     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
             "VISUAL": NormalizationMode.IDENTITY,
-            "STATE": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for state
-            "ACTION": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for action
+            "STATE": NormalizationMode.MEAN_STD,  # mean/std normalization for state (openpi's Pi0.5 default is quantiles)
+            "ACTION": NormalizationMode.MEAN_STD,  # mean/std normalization for action (openpi's Pi0.5 default is quantiles)
         }
     )
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index 6500ada20..76853f1a0 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -48,6 +48,9 @@ from lerobot.utils.constants import (
     ACTION,
     OBS_LANGUAGE_ATTENTION_MASK,
     OBS_LANGUAGE_TOKENS,
+    OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
+    OBS_LANGUAGE_SUBTASK_ONLY_TOKENS,
+    OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK,
     OPENPI_ATTENTION_MASK_VALUE,
 )
@@ -429,6 +432,8 @@ class PaliGemmaWithExpertModel(
                 adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
             )
             prefix_past_key_values = prefix_output.past_key_values
+            # prefix_output is returned and later fed to the language-model head for subtask prediction
+            # shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048
             prefix_output = prefix_output.last_hidden_state
             suffix_output = None
         elif inputs_embeds[0] is None:
@@ -578,10 +583,13 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
             )
             return func(*args, **kwargs)
 
-    def _prepare_attention_masks_4d(self, att_2d_masks):
+    def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
         """Helper method to prepare 4D attention masks for transformer."""
         att_2d_masks_4d = att_2d_masks[:, None, :, :]
-        return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
+        result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
+        if dtype is not None:
+            result = result.to(dtype=dtype)
+        return result
 
     def sample_noise(self, shape, device):
         return torch.normal(
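For reference, a minimal standalone sketch (outside the patch) of what the new `dtype` argument to `_prepare_attention_masks_4d` is for: the boolean 2D mask is broadcast to 4D and turned into an additive mask, and casting it to the activation dtype avoids an implicit float32 upcast when the model runs in bfloat16. The fill value below is an illustrative stand-in for OPENPI_ATTENTION_MASK_VALUE.

import torch

MASK_FILL_VALUE = -2.0e9  # illustrative stand-in for OPENPI_ATTENTION_MASK_VALUE

def prepare_attention_masks_4d(att_2d_masks: torch.Tensor, dtype: torch.dtype | None = None) -> torch.Tensor:
    # (B, T, T) bool -> (B, 1, T, T) additive mask: 0 where attention is allowed, large negative elsewhere
    att_4d = att_2d_masks[:, None, :, :]
    result = torch.where(att_4d, 0.0, MASK_FILL_VALUE)
    return result.to(dtype=dtype) if dtype is not None else result

mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))[None]  # (1, 4, 4) causal mask
print(prepare_attention_masks_4d(mask, dtype=torch.bfloat16).dtype)  # torch.bfloat16
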
@@ -600,13 +608,28 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         return time.to(dtype=torch.float32, device=device)
 
     def embed_prefix(
-        self, images, img_masks, tokens, masks
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Embed images with SigLIP and language tokens with embedding layer."""
+        self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+        """Embed images with SigLIP, tokens, and optionally subtask tokens with the embedding layer.
+
+        Args:
+            images: List of image tensors
+            img_masks: List of image masks
+            tokens: Language instruction tokens
+            subtask_tokens: Subtask tokens to predict (can be None for inference)
+            masks: Attention masks for tokens
+            subtask_masks: Attention masks for subtask tokens (can be None for inference)
+
+        Returns:
+            embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)]
+            pad_masks: Padding masks
+            att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided)
+            total_T_images: Total number of image tokens
+        """
         embs = []
         pad_masks = []
         att_masks = []
-
+        total_T_images = 0
+
         # Process images
         for img, img_mask in zip(images, img_masks, strict=True):
@@ -618,9 +641,10 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
             embs.append(img_emb)
             pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
 
-            att_masks += [0] * num_img_embs
-
-        # Process language tokens
+            att_masks += [0] * num_img_embs  # Image tokens share the fully attended prefix block
+            total_T_images += num_img_embs
+
+        # Process language instruction tokens
         def lang_embed_func(tokens):
             lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
             lang_emb_dim = lang_emb.shape[-1]
@@ -631,16 +655,34 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         pad_masks.append(masks)
         num_lang_embs = lang_emb.shape[1]
-        att_masks += [0] * num_lang_embs
+        att_masks += [0] * num_lang_embs  # Language tokens share the fully attended prefix block (images + instruction)
+
+        # Process subtask tokens if provided (these are predicted, so use causal masking)
+        if subtask_tokens is not None:
+
+            def subtask_embed_func(subtask_tokens):
+                subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
+                subtask_emb_dim = subtask_emb.shape[-1]
+                return subtask_emb * math.sqrt(subtask_emb_dim)
+
+            subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
+            embs.append(subtask_emb)
+
+            # Subtask pad masks (non-zero tokens are valid)
+            pad_masks.append(subtask_masks)
+
+            num_subtask_embs = subtask_emb.shape[1]
+            # Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens,
+            # and previous subtask tokens
+            att_masks += [1] * num_subtask_embs  # 1 starts a new attention block, giving causal attention over subtask tokens
 
         embs = torch.cat(embs, dim=1)
         pad_masks = torch.cat(pad_masks, dim=1)
         att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
         bsize = pad_masks.shape[0]
-        att_masks = att_masks[None, :].expand(bsize, len(att_masks))
+        att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
 
-        return embs, pad_masks, att_masks
+        return embs, pad_masks, att_masks, total_T_images
 
     def embed_suffix(self, noisy_actions, timestep):
         """Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
@@ -689,7 +731,8 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         return embs, pad_masks, att_masks, adarms_cond
 
-    def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
+    # Called from PI05Policy.forward() as: loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions)
+    def forward(self, images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> dict[str, Tensor]:
         """Do a full training forward pass and compute the loss."""
         if noise is None:
             noise = self.sample_noise(actions.shape, actions.device)
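For reference, a standalone sketch (outside the patch) of the 0/1 `att_masks` convention used in `embed_prefix` above, mirroring what `make_att_2d_masks` does as I read the lerobot pi0 code: tokens marked 0 share one fully attended prefix block, while each token marked 1 opens a new block, which yields causal attention over the subtask tokens. The helper name and toy sizes are made up.

import torch

def block_causal_mask(pad_masks: torch.Tensor, att_masks: torch.Tensor) -> torch.Tensor:
    # att_masks[i] == 0: token i joins the previous attention block (full attention inside the block)
    # att_masks[i] == 1: token i starts a new block; later tokens see it, but it cannot see later tokens
    block_ids = torch.cumsum(att_masks, dim=1)
    att_2d = block_ids[:, None, :] <= block_ids[:, :, None]   # query i sees key j iff block(j) <= block(i)
    pad_2d = pad_masks[:, None, :] & pad_masks[:, :, None]    # drop padded positions
    return att_2d & pad_2d

# 3 prefix tokens (images + instruction, marked 0) followed by 2 subtask tokens (marked 1)
att = torch.tensor([[0, 0, 0, 1, 1]])
pad = torch.ones(1, 5, dtype=torch.bool)
print(block_causal_mask(pad, att).int()[0])
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)
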
@@ -701,9 +744,55 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         x_t = time_expanded * noise + (1 - time_expanded) * actions
         u_t = noise - actions
 
-        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
+        # Embed prefix (images + tokens + subtask_tokens)
+        prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
+            images, img_masks, tokens, subtask_tokens, masks, subtask_masks
+        )
         suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
 
+        # Prepare attention masks for the prefix-only pass (for subtask token prediction)
+        att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+        position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
+        att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
+
+        # Prefix-only transformer run for subtask token prediction
+        (prefix_out, _), _ = self.paligemma_with_expert.forward(
+            attention_mask=att_2d_prefix_4d,
+            position_ids=position_ids_prefix,
+            past_key_values=None,
+            inputs_embeds=[prefix_embs, None],  # no suffix (action expert) inputs for this pass
+            use_cache=False,
+            adarms_cond=[None, None],
+        )
+
+        # Language-model head -> subtask logits
+        # prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_tokens + T_subtask
+        lm_head = self.paligemma_with_expert.paligemma.lm_head
+        logits = lm_head(prefix_out)  # (B, T_prefix, vocab)
+
+        # Extract logits for subtask token prediction.
+        # Subtask tokens start after images and instruction tokens; shift by one position so that
+        # each subtask token is predicted from the hidden state of the token *before* it
+        # (otherwise every position could see its own target through self-attention).
+        T_tokens = tokens.size(1)
+        T_subtask = subtask_tokens.size(1)
+        start_index = total_T_images + T_tokens
+        end_index = start_index + T_subtask
+        logits_subtask = logits[:, start_index - 1 : end_index - 1, :]  # (B, T_subtask, vocab)
+
+        targets = subtask_tokens  # (B, T_subtask)
+        # Compute per-token cross-entropy loss
+        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+        logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))  # (B*T_subtask, vocab)
+        targets_flat = targets.reshape(-1)  # (B*T_subtask)
+
+        loss_per_token = loss_fct(logits_flat, targets_flat)  # (B*T_subtask)
+        loss_per_token = loss_per_token.reshape(targets.shape)  # (B, T_subtask)
+
+        # Apply mask and compute mean loss over valid tokens
+        masked_loss = loss_per_token * subtask_masks.float()
+        subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
+
+        # Convert embeddings to bfloat16 if needed for the model
         if (
             self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
             == torch.bfloat16
@@ -711,13 +800,14 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
             suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
             prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
 
+        # Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks
         pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
         att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
 
+        # Prepare attention masks for the full forward pass (prefix + suffix)
         att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
         position_ids = torch.cumsum(pad_masks, dim=1) - 1
-
-        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
+        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
 
         def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
             (_, suffix_out), _ = self.paligemma_with_expert.forward(
@@ -728,6 +818,7 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
                 use_cache=False,
                 adarms_cond=[None, adarms_cond],
             )
+            # The subtask/language-head loss is computed in the separate prefix-only pass above
             return suffix_out
 
         suffix_out = self._apply_checkpoint(
@@ -742,7 +833,13 @@ class PI05Pytorch(nn.Module):  # see
openpi `PI0Pytorch` v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) - return F.mse_loss(u_t, v_t, reduction="none") + fm_loss = F.mse_loss(u_t, v_t, reduction="none") + + return { + "flow_loss": fm_loss, + "subtask_loss": subtask_loss, + "loss": 10 * fm_loss.mean() + subtask_loss, + } @torch.no_grad() # see openpi `sample_actions` (slightly adapted) def sample_actions( @@ -771,11 +868,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) # Use config max_action_dim for internal processing noise = self.sample_noise(actions_shape, device) - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + # During inference, we don't need subtask_tokens, so pass None + prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix( + images, img_masks, tokens, subtask_tokens=None, masks=masks + ) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype) self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( @@ -852,7 +952,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=suffix_embs.dtype) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 outputs_embeds, _ = self.paligemma_with_expert.forward( @@ -1198,7 +1298,7 @@ class PI05Policy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - + # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05) actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs) @@ -1214,21 +1314,22 @@ class PI05Policy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - + subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK}"] + high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"] actions = self.prepare_action(batch) - + # Compute loss (no separate state needed for PI05) - losses = self.model.forward(images, img_masks, tokens, masks, actions) + # high_level_task = instruction tokens, tokens = subtask tokens to predict + loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions) - # Truncate losses to actual action dimensions - original_action_dim = self.config.output_features[ACTION].shape[0] - losses = losses[:, :, :original_action_dim] - - loss = losses.mean() - - loss_dict = { + # Extract the total loss + loss = loss_dict["loss"] + + # Prepare detailed loss dictionary for logging + detailed_loss_dict = { "loss": loss.item(), - "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), + "flow_loss": 
loss_dict["flow_loss"].mean().item(), + "subtask_loss": loss_dict["subtask_loss"].item(), } - return loss, loss_dict + return loss, detailed_loss_dict diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index e29bc4c23..1b1fcf047 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -47,13 +47,15 @@ from lerobot.utils.constants import ( @ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step") @dataclass -class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): +class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep): """ Processor step to prepare the state and tokenize the language input. """ max_state_dim: int = 32 task_key: str = "task" + high_level_task_key: str = "user_prompt" + subtask_only_key: str = "subtask" def __call__(self, transition: EnvTransition) -> EnvTransition: transition = transition.copy() @@ -64,6 +66,8 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) if tasks is None: raise ValueError("No task found in complementary data") + + high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.high_level_task_key) # TODO: check if this necessary state = deepcopy(state) @@ -76,16 +80,42 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): state_np = state.cpu().numpy() discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 - full_prompts = [] + # Clean high level tasks first (if available) + cleaned_high_level_tasks = [] + if high_level_tasks is not None: + for high_level_task in high_level_tasks: + cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " ")) + + # Process low level tasks with state information + low_level_prompts = [] + subtask_only_prompts = [] # Store only the subtask text for prediction for i, task in enumerate(tasks): cleaned_text = task.strip().replace("_", " ").replace("\n", " ") state_str = " ".join(map(str, discretized_states[i])) - full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " - full_prompts.append(full_prompt) + + # Store only the subtask text (used as prediction target) + subtask_only_prompts.append(cleaned_text) + + if cleaned_high_level_tasks: + cleaned_high_level_task = cleaned_high_level_tasks[i] + full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}" + else: + full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + + low_level_prompts.append(full_prompt) - transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts - # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!) 
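For reference, a small standalone example (outside the patch) of the state discretization and the prompt formats built by this processor step when a high-level task is available; the task strings are made up, and the state is assumed to be already normalized to [-1, 1] by the normalizer step.

import numpy as np

# Discretize a normalized state vector into 256 bins, as in the step above
state = np.array([-1.0, -0.25, 0.0, 0.73])       # assumed already normalized to [-1, 1]
bins = np.linspace(-1, 1, 256 + 1)[:-1]          # 256 left bin edges
discretized = np.digitize(state, bins=bins) - 1  # integers in [0, 255]
state_str = " ".join(map(str, discretized))

subtask = "pick up the sponge"                   # hypothetical subtask text
user_prompt = "clean the table"                  # hypothetical high-level task
print(f"High level task: {user_prompt}; State: {state_str}; Subtask: {subtask}")   # training target prompt
print(f"High level task: {user_prompt}; State: {state_str}; Subtask:")             # inference-style prompt
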
- # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = low_level_prompts + transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_only_key] = subtask_only_prompts + + # Process high level tasks without state information (if available) + if high_level_tasks is not None: + high_level_prompts = [] + for i, cleaned_high_level_task in enumerate(cleaned_high_level_tasks): + state_str = " ".join(map(str, discretized_states[i])) + full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:" + high_level_prompts.append(full_prompt) + + transition[TransitionKey.COMPLEMENTARY_DATA][self.high_level_task_key] = high_level_prompts return transition def transform_features( @@ -133,14 +163,14 @@ def make_pi05_pre_post_processors( input_steps: list[ProcessorStep] = [ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one AddBatchDimensionProcessorStep(), - # NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep + # NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateAndLanguageTokenizerProcessorStep # because the tokenizer step expects normalized state in [-1, 1] range for discretization NormalizerProcessorStep( features={**config.input_features, **config.output_features}, norm_map=config.normalization_mapping, stats=dataset_stats, ), - Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), + Pi05PrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim), TokenizerProcessorStep( tokenizer_name="google/paligemma-3b-pt-224", max_length=config.tokenizer_max_length, diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 6b0b67598..d20526a1d 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -168,10 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: """ pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} task_key = {"task": batch["task"]} if "task" in batch else {} + user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {} + subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} - return {**pad_keys, **task_key, **index_key, **task_index_key} + return {**pad_keys, **task_key, **index_key, **task_index_key, **user_prompt_key, **subtask_key} def create_transition( diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 6cae5921f..1456100fd 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -47,7 +47,6 @@ class RenameObservationsProcessorStep(ObservationProcessorStep): processed_obs[self.rename_map[key]] = value else: processed_obs[key] = value - return processed_obs def get_config(self) -> dict[str, Any]: diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 2ef89c107..abdaf41fc 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -29,7 +29,14 @@ from typing import TYPE_CHECKING, Any import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import 
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.constants import ( + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK, + OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS, + OBS_LANGUAGE_TOKENS, + OBS_LANGUAGE_SUBTASK_ONLY_TOKENS, + OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK, +) from lerobot.utils.import_utils import _transformers_available from .core import EnvTransition, TransitionKey @@ -52,6 +59,9 @@ class TokenizerProcessorStep(ObservationProcessorStep): tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting token IDs and attention mask to the `observation` dictionary. + Optionally, this step can also tokenize a high-level task (e.g., user prompt) and/or + a subtask if present in the complementary data, creating separate tokenized observations. + Requires the `transformers` library to be installed. Attributes: @@ -59,6 +69,8 @@ class TokenizerProcessorStep(ObservationProcessorStep): tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored. max_length: The maximum length to pad or truncate sequences to. task_key: The key in `complementary_data` where the task string is stored. + high_level_task_key: The key in `complementary_data` where the high-level task (user prompt) is stored. + subtask_key: The key in `complementary_data` where the subtask string is stored. padding_side: The side to pad on ('left' or 'right'). padding: The padding strategy ('max_length', 'longest', etc.). truncation: Whether to truncate sequences longer than `max_length`. @@ -69,6 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep): tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency max_length: int = 512 task_key: str = "task" + high_level_task_key: str = "user_prompt" + subtask_key: str = "subtask" padding_side: str = "right" padding: str = "max_length" truncation: bool = True @@ -121,6 +135,7 @@ class TokenizerProcessorStep(ObservationProcessorStep): raise ValueError("Complementary data is None so no task can be extracted from it") task = complementary_data[self.task_key] + if task is None: raise ValueError("Task extracted from Complementary data is None") @@ -132,6 +147,60 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None + def get_high_level_task(self, transition: EnvTransition) -> list[str] | None: + """ + Extracts the high-level task description(s) from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of high-level task strings, or None if the high-level task key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + high_level_task = complementary_data.get(self.high_level_task_key) + + if high_level_task is None: + return None + + # Standardize to a list of strings for the tokenizer + if isinstance(high_level_task, str): + return [high_level_task] + elif isinstance(high_level_task, list) and all(isinstance(t, str) for t in high_level_task): + return high_level_task + + return None + + def get_subtask(self, transition: EnvTransition) -> list[str] | None: + """ + Extracts the subtask description(s) from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of subtask strings, or None if the subtask key is not found or the value is None. 
+ """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + subtask = complementary_data.get(self.subtask_key) + + if subtask is None: + return None + + # Standardize to a list of strings for the tokenizer + if isinstance(subtask, str): + return [subtask] + elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask): + return subtask + + return None + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: """ Tokenizes the task description and adds it to the observation dictionary. @@ -169,6 +238,40 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + # Also tokenize high-level task if available + high_level_task = self.get_high_level_task(self.transition) + if high_level_task is not None: + # Tokenize the high-level task + tokenized_high_level_prompt = self._tokenize_text(high_level_task) + + # Move to the same device + if target_device is not None: + tokenized_high_level_prompt = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_high_level_prompt.items() + } + + # Add high-level tokenized data to the observation + new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = tokenized_high_level_prompt["input_ids"] + new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = tokenized_high_level_prompt["attention_mask"].to(dtype=torch.bool) + + # Also tokenize subtask if available + subtask = self.get_subtask(self.transition) + if subtask is not None: + # Tokenize the subtask + tokenized_subtask_prompt = self._tokenize_text(subtask) + + # Move to the same device + if target_device is not None: + tokenized_subtask_prompt = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_subtask_prompt.items() + } + + # Add subtask tokenized data to the observation + new_observation[OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = tokenized_subtask_prompt["input_ids"] + new_observation[OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = tokenized_subtask_prompt["attention_mask"].to(dtype=torch.bool) + return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: @@ -229,6 +332,7 @@ class TokenizerProcessorStep(ObservationProcessorStep): config = { "max_length": self.max_length, "task_key": self.task_key, + "high_level_task_key": self.high_level_task_key, "padding_side": self.padding_side, "padding": self.padding, "truncation": self.truncation, @@ -267,4 +371,25 @@ class TokenizerProcessorStep(ObservationProcessorStep): type=FeatureType.LANGUAGE, shape=(self.max_length,) ) + # Add features for high-level task tokens and attention mask if they don't already exist + if OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + if OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + if OBS_LANGUAGE_SUBTASK_ONLY_TOKENS not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = 
PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + if OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + return features diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index dfa10b2e5..c8e19eb56 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -26,7 +26,12 @@ OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" - +OBS_LANGUAGE_HIGH_LEVEL_TASK = OBS_STR + ".user_prompt" +OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".tokens" +OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".attention_mask" +OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask" +OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens" +OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask" ACTION = "action" REWARD = "next.reward" TRUNCATED = "next.truncated" diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py index 0d5244e1c..2bfb43148 100644 --- a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py @@ -266,7 +266,7 @@ def create_original_observation_with_openpi_preprocessing(batch): elif len(tasks) == 1: tasks = tasks * batch_size - # Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep) + # Use pi05 state and input tokenizer logic (same as Pi05PrepareStateAndLanguageTokenizerProcessorStep) state = batch["observation.state"] state = deepcopy(state)
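Finally, for reference, a compact standalone sketch (outside the patch) of the masked subtask cross-entropy that PI05Pytorch.forward in modeling_pi05.py now computes during training; the function name and toy shapes are illustrative only. The logits are assumed to come from the positions preceding each target token, so no target is predicted from its own embedding.

import torch
import torch.nn.functional as F

def masked_next_token_loss(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Cross-entropy averaged over valid target tokens only.

    logits:  (B, T, vocab), taken from the positions preceding each target token (teacher forcing).
    targets: (B, T) token ids; mask: (B, T) bool, True for real (non-padding) tokens.
    """
    loss_per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none"
    ).reshape(targets.shape)
    return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1)

B, T, vocab = 2, 5, 11
logits = torch.randn(B, T, vocab)
targets = torch.randint(0, vocab, (B, T))
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)
print(masked_next_token_loss(logits, targets, mask))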