diff --git a/examples/dataset/inference_pi05.py b/examples/dataset/inference_pi05.py index 73522694d..6f823cf1d 100644 --- a/examples/dataset/inference_pi05.py +++ b/examples/dataset/inference_pi05.py @@ -45,45 +45,14 @@ dataloader = torch.utils.data.DataLoader( batch = next(iter(dataloader)) batch = pre_processor(batch) + +# Test training forward pass policy.train() -# run inference -# action = policy.select_action(batch) loss, loss_dict = policy.forward(batch) -# import requests -# from PIL import Image -# from transformers import AutoProcessor -# model = policy.model.paligemma_with_expert.paligemma -# model = model.to(device="cuda", dtype=torch.bfloat16) -# model.eval() -# prompt = "Describe this image." -# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" -# image = Image.open(requests.get(url, stream=True).raw) -# processor = AutoProcessor.from_pretrained( -# "google/paligemma-3b-pt-224", -# ) -# inputs = processor(image, prompt, return_tensors="pt").to(model.device) -# print("generating...") -# output = model.generate( -# **inputs, -# max_new_tokens=50, -# use_cache=True, # default dynamic cache -# ) -# print(processor.decode(output[0], skip_special_tokens=True)) +print(f"Training loss: {loss_dict}") - -# # other model -# from transformers import PaliGemmaForConditionalGeneration -# model = PaliGemmaForConditionalGeneration.from_pretrained( -# "google/paligemma2-3b-pt-224", -# torch_dtype=torch.bfloat16, -# device_map="auto", -# ) -# model.eval() -# print("generating...") -# output = model.generate( -# **inputs, -# max_new_tokens=100, -# use_cache=True, # default dynamic cache -# ) -# print("Model 2 output:") -# print(processor.decode(output[0], skip_special_tokens=True)) +# Test inference +policy.eval() +with torch.no_grad(): + actions = policy.predict_action_chunk(batch) + print(f"Predicted actions shape: {actions.shape}") \ No newline at end of file diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 62a79c95c..713520f93 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -48,10 +48,10 @@ from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, - OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS, - OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK, - OBS_LANGUAGE_SUBTASK_ONLY_TOKENS, - OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK, + OBS_LANGUAGE_PROMPT_TOKENS, + OBS_LANGUAGE_PROMPT_ATTENTION_MASK, + OBS_LANGUAGE_TARGET_TOKENS, + OBS_LANGUAGE_TARGET_ATTENTION_MASK, OPENPI_ATTENTION_MASK_VALUE, ) @@ -609,21 +609,22 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return time.to(dtype=torch.float32, device=device) def embed_prefix( - self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks + self, images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks=None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: - """Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer. + """Embed images with SigLIP, prompt tokens, and optionally target tokens with embedding layer. Args: images: List of image tensors img_masks: List of image masks - tokens: Language instruction tokens - subtask_tokens: Subtask tokens to predict (can be None for inference) - masks: Attention masks for tokens + prompt_tokens: Prompt tokens (input for generation) + target_tokens: Target tokens to predict (can be None for inference) + prompt_masks: Attention masks for prompt tokens + target_masks: Attention masks for target tokens Returns: - embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)] + embs: Concatenated embeddings [images, prompt_tokens, (target_tokens if provided)] pad_masks: Padding masks - att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided) + att_masks: Attention masks (with causal masking for target prediction if target_tokens provided) total_T_images: Total number of image tokens """ embs = [] @@ -645,36 +646,36 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` att_masks += [0] * num_img_embs # Images can attend to all previous tokens total_T_images += num_img_embs - # Process language instruction tokens - def lang_embed_func(tokens): - lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) + # Process prompt tokens + def prompt_embed_func(prompt_tokens): + prompt_emb = self.paligemma_with_expert.embed_language_tokens(prompt_tokens) + prompt_emb_dim = prompt_emb.shape[-1] + return prompt_emb * math.sqrt(prompt_emb_dim) - lang_emb = self._apply_checkpoint(lang_embed_func, tokens) - embs.append(lang_emb) - pad_masks.append(masks) + prompt_emb = self._apply_checkpoint(prompt_embed_func, prompt_tokens) + embs.append(prompt_emb) + pad_masks.append(prompt_masks) - num_lang_embs = lang_emb.shape[1] - att_masks += [0] * num_lang_embs # Language tokens can attend to all previous tokens (images + tokens) + num_prompt_embs = prompt_emb.shape[1] + att_masks += [0] * num_prompt_embs # Prompt tokens can attend to all previous tokens (images + prompt) - # Process subtask tokens if provided (these are predicted, so use causal masking) - if subtask_tokens is not None: - def subtask_embed_func(subtask_tokens): - subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens) - subtask_emb_dim = subtask_emb.shape[-1] - return subtask_emb * math.sqrt(subtask_emb_dim) + # Process target tokens if provided (these are predicted, so use causal masking) + if target_tokens is not None: + def target_embed_func(target_tokens): + target_emb = self.paligemma_with_expert.embed_language_tokens(target_tokens) + target_emb_dim = target_emb.shape[-1] + return target_emb * math.sqrt(target_emb_dim) - subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens) - embs.append(subtask_emb) + target_emb = self._apply_checkpoint(target_embed_func, target_tokens) + embs.append(target_emb) - # Create subtask pad masks (non-zero tokens are valid) - pad_masks.append(subtask_masks) + # Create target pad masks (non-zero tokens are valid) + pad_masks.append(target_masks) - num_subtask_embs = subtask_emb.shape[1] - # Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens, - # and previous subtask tokens - att_masks += [1] * num_subtask_embs # Use 1 for causal attention on subtask tokens + num_target_embs = target_emb.shape[1] + # Causal masking for target tokens: each target token can attend to images, all prompt tokens, + # and previous target tokens + att_masks += [1] * num_target_embs # Use 1 for causal attention on target tokens embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) @@ -732,17 +733,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return embs, pad_masks, att_masks, adarms_cond - # loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions) - def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> Tensor: + def forward(self, images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss. Args: images: List of image tensors img_masks: List of image masks - high_level_task: Instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:") - high_level_task_masks: Attention masks for high_level_task - subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup") - subtask_masks: Attention masks for subtask_tokens + prompt_tokens: Prompt tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:") + prompt_masks: Attention masks for prompt_tokens + target_tokens: Target tokens to predict (e.g., tokens for "pick up the cup") + target_masks: Attention masks for target_tokens actions: Ground truth actions noise: Optional noise for flow matching time: Optional time for flow matching @@ -757,20 +757,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions - # Embed prefix (images + high_level_task + subtask_tokens) - # Use high_level_task (prompt WITHOUT subtask) + subtask_tokens to predict + # Embed prefix (images + prompt_tokens + target_tokens) prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix( - images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks + images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks ) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) - # Prepare attention masks for prefix-only pass (for subtask token prediction) + # Prepare attention masks for prefix-only pass (for target token prediction) att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1 att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype) - # prefix-only transformer run for subtask token prediction + # prefix-only transformer run for target token prediction (prefix_out, _), _ = self.paligemma_with_expert.forward( attention_mask=att_2d_prefix_4d, position_ids=position_ids_prefix, @@ -780,37 +779,33 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` adarms_cond=[None, None], ) - # LM HEAD → SUBTASK LOGITS - # prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_high_level_task + T_subtask + # LM HEAD → TARGET LOGITS + # prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_prompt + T_target lm_head = self.paligemma_with_expert.paligemma.lm_head logits = lm_head(prefix_out) # (B, T_prefix, vocab) - # Extract logits for subtask token prediction - # In autoregressive modeling, output at position i predicts token at position i+1 - # So we take logits from one position earlier: - # - Position (start_index-1) (last high_level_task token) predicts subtask_tokens[0] - # - Position (start_index) (first subtask token) predicts subtask_tokens[1], etc. - T_high_level_task = high_level_task.size(1) - T_subtask = subtask_tokens.size(1) - start_index = total_T_images + T_high_level_task - end_index = start_index + T_subtask - logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab) + # Extract logits for target token prediction (shifted by 1 for autoregressive training) + # Position i predicts token i+1, so we take logits from positions before target tokens: + # - Position (start_index-1) (last prompt token) predicts target_tokens[0] + # - Position (start_index) (first target token) predicts target_tokens[1], etc. + T_prompt = prompt_tokens.size(1) + T_target = target_tokens.size(1) + start_index = total_T_images + T_prompt + end_index = start_index + T_target + logits_target = logits[:, start_index-1:end_index-1, :] # (B, T_target, vocab) - targets = subtask_tokens # (B, T_subtask) # Compute cross-entropy loss loss_fct = torch.nn.CrossEntropyLoss(reduction='none') # Reshape for loss computation - logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) # (B*T_subtask, vocab) - targets_flat = targets.reshape(-1) # (B*T_subtask) + logits_flat = logits_target.reshape(-1, logits_target.size(-1)) # (B*T_target, vocab) + targets_flat = target_tokens.reshape(-1) # (B*T_target) - loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_subtask) - loss_per_token = loss_per_token.reshape(targets.shape) # (B, T_subtask) + loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_target) + loss_per_token = loss_per_token.reshape(target_tokens.shape) # (B, T_target) # Apply mask and compute mean loss over valid tokens - masked_loss = loss_per_token * subtask_masks.float() - subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1) - - breakpoint() + masked_loss = loss_per_token * target_masks.float() + target_loss = masked_loss.sum() / target_masks.sum().clamp(min=1) # Convert embeddings to bfloat16 if needed for the model if ( self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype @@ -819,7 +814,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` suffix_embs = suffix_embs.to(dtype=torch.bfloat16) prefix_embs = prefix_embs.to(dtype=torch.bfloat16) - # Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks + # Concatenate prefix (images + prompt_tokens + target_tokens) and suffix (actions) masks pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) @@ -856,27 +851,26 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return { "flow_loss": fm_loss, - "subtask_loss": subtask_loss, - "loss": 10 * fm_loss.mean() + subtask_loss, + "target_loss": target_loss, + "loss": 10 * fm_loss.mean() + target_loss, } @torch.no_grad() - def _generate_subtask_tokens( - self, images, img_masks, tokens, masks, tokenizer, max_length, device + def _generate_target_tokens( + self, images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_length, device ): - """Generate subtask tokens autoregressively using next token prediction.""" - bsize = tokens.shape[0] + """Generate target tokens autoregressively using next token prediction.""" + bsize = prompt_tokens.shape[0] # Get lm_head for token generation lm_head = self.paligemma_with_expert.paligemma.lm_head - # Embed prefix without subtask tokens first + # Embed prefix without target tokens first prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix( - images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None + images, img_masks, prompt_tokens, target_tokens=None, prompt_masks=prompt_masks, target_masks=None ) - # Initialize generated tokens list - start with BOS token or first token after instruction - # For PaliGemma, we'll start generation and accumulate tokens + # Initialize generated tokens list generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device) for t in range(max_length): @@ -912,7 +906,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Embed the generated token and append to prefix next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1) - breakpoint() def next_token_embed_func(next_token_unsqueezed): next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed) @@ -938,20 +931,20 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` self, images, img_masks, - tokens, - masks, + prompt_tokens, + prompt_masks, noise=None, num_steps=None, tokenizer=None, - max_subtask_tokens=50, + max_target_tokens=50, **kwargs: Unpack[ActionSelectKwargs], ) -> Tensor: """Do a full inference forward and compute the action.""" if num_steps is None: num_steps = self.config.num_inference_steps - bsize = tokens.shape[0] - device = tokens.device + bsize = prompt_tokens.shape[0] + device = prompt_tokens.device if noise is None: # Sample noise with padded dimension as expected by action_in_proj @@ -962,26 +955,28 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) # Use config max_action_dim for internal processing noise = self.sample_noise(actions_shape, device) - # Generate subtask tokens autoregressively during inference - generated_subtask_tokens = None + # Generate target tokens autoregressively during inference (if tokenizer provided) + generated_target_tokens = None + target_masks = None if tokenizer is not None: - generated_subtask_tokens = self._generate_subtask_tokens( - images, img_masks, tokens, masks, tokenizer, max_subtask_tokens, device + generated_target_tokens = self._generate_target_tokens( + images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_target_tokens, device ) - # Decode and print the generated subtask tokens + # Decode and print the generated target tokens for i in range(bsize): # Remove padding tokens (0) and special tokens - valid_tokens = generated_subtask_tokens[i][generated_subtask_tokens[i] != 0] + valid_tokens = generated_target_tokens[i][generated_target_tokens[i] != 0] decoded_text = tokenizer.decode(valid_tokens, skip_special_tokens=True) - print(f"[Inference] Generated subtask {i}: {decoded_text}") + print(f"[Inference] Generated target {i}: {decoded_text}") - # Create mask for generated tokens (all valid) - subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool) + # Create mask for generated tokens (all valid where token != 0) + target_masks = generated_target_tokens != 0 - # During inference, we don't have subtask_tokens yet, so pass None + # Embed prefix with prompt and optionally generated target tokens prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix( - images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks + images, img_masks, prompt_tokens, target_tokens=generated_target_tokens, + prompt_masks=prompt_masks, target_masks=target_masks ) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 @@ -1416,13 +1411,13 @@ class PI05Policy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) - # Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask - high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"] - high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"] - breakpoint() + # Use prompt tokens (WITHOUT target) for inference - we'll generate the target + prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"] + prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"] + # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05) actions = self.model.sample_actions( - images, img_masks, high_level_task, high_level_task_masks, + images, img_masks, prompt_tokens, prompt_masks, tokenizer=self.tokenizer, **kwargs ) @@ -1438,35 +1433,24 @@ class PI05Policy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) - high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"] - high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"] - subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK}"] + prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"] + prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"] + target_tokens, target_masks = batch[f"{OBS_LANGUAGE_TARGET_TOKENS}"], batch[f"{OBS_LANGUAGE_TARGET_ATTENTION_MASK}"] actions = self.prepare_action(batch) - # Decode and print ground truth subtask tokens during training - if self.tokenizer is not None and self.training: - bsize = subtask_tokens.shape[0] - for i in range(bsize): - # Remove padding tokens (0) and special tokens - valid_tokens = subtask_tokens[i][subtask_masks[i].bool()] - if len(valid_tokens) > 0: - decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True) - print(f"[Training] Ground truth subtask {i}: {decoded_text}") - - # Compute loss (no separate state needed for PI05) - # high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:") - # subtask_tokens = subtask tokens to predict (e.g., "pick up the cup") - loss_dict = self.model.forward(images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions) + # Compute loss + # prompt_tokens = instruction tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:") + # target_tokens = target tokens to predict (e.g., "pick up the cup") + loss_dict = self.model.forward(images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions) # Extract the total loss loss = loss_dict["loss"] - breakpoint() # Prepare detailed loss dictionary for logging detailed_loss_dict = { "loss": loss.item(), "flow_loss": loss_dict["flow_loss"].mean().item(), - "subtask_loss": loss_dict["subtask_loss"].item(), + "target_loss": loss_dict["target_loss"].item(), } return loss, detailed_loss_dict diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index 1b1fcf047..7563cecd0 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -54,8 +54,8 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep): max_state_dim: int = 32 task_key: str = "task" - high_level_task_key: str = "user_prompt" - subtask_only_key: str = "subtask" + prompt_key: str = "prompt" + target_key: str = "target" def __call__(self, transition: EnvTransition) -> EnvTransition: transition = transition.copy() @@ -67,7 +67,7 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep): if tasks is None: raise ValueError("No task found in complementary data") - high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.high_level_task_key) + high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("user_prompt") # TODO: check if this necessary state = deepcopy(state) @@ -86,36 +86,27 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep): for high_level_task in high_level_tasks: cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " ")) - # Process low level tasks with state information - low_level_prompts = [] - subtask_only_prompts = [] # Store only the subtask text for prediction + # Process tasks to create prompts (input) and targets (what to predict) + prompts = [] # Input prompts ending with "Subtask:" + targets = [] # Target text to predict (the subtask) for i, task in enumerate(tasks): cleaned_text = task.strip().replace("_", " ").replace("\n", " ") state_str = " ".join(map(str, discretized_states[i])) - # Store only the subtask text (used as prediction target) - subtask_only_prompts.append(cleaned_text) + # Store the subtask text as target for prediction + targets.append(cleaned_text) if cleaned_high_level_tasks: cleaned_high_level_task = cleaned_high_level_tasks[i] - full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}" + # Prompt ends with "Subtask:" - model will predict the target + prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:" else: - full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + raise ValueError("No high level tasks found") - low_level_prompts.append(full_prompt) + prompts.append(prompt) - transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = low_level_prompts - transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_only_key] = subtask_only_prompts - - # Process high level tasks without state information (if available) - if high_level_tasks is not None: - high_level_prompts = [] - for i, cleaned_high_level_task in enumerate(cleaned_high_level_tasks): - state_str = " ".join(map(str, discretized_states[i])) - full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:" - high_level_prompts.append(full_prompt) - - transition[TransitionKey.COMPLEMENTARY_DATA][self.high_level_task_key] = high_level_prompts + transition[TransitionKey.COMPLEMENTARY_DATA][self.prompt_key] = prompts + transition[TransitionKey.COMPLEMENTARY_DATA][self.target_key] = targets return transition def transform_features( diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index abdaf41fc..a77f0af88 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -31,11 +31,11 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.utils.constants import ( OBS_LANGUAGE_ATTENTION_MASK, - OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK, - OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS, + OBS_LANGUAGE_PROMPT_ATTENTION_MASK, + OBS_LANGUAGE_PROMPT_TOKENS, OBS_LANGUAGE_TOKENS, - OBS_LANGUAGE_SUBTASK_ONLY_TOKENS, - OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK, + OBS_LANGUAGE_TARGET_TOKENS, + OBS_LANGUAGE_TARGET_ATTENTION_MASK, ) from lerobot.utils.import_utils import _transformers_available @@ -59,8 +59,8 @@ class TokenizerProcessorStep(ObservationProcessorStep): tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting token IDs and attention mask to the `observation` dictionary. - Optionally, this step can also tokenize a high-level task (e.g., user prompt) and/or - a subtask if present in the complementary data, creating separate tokenized observations. + Optionally, this step can also tokenize a prompt (input for generation) and/or + a target (text to predict) if present in the complementary data, creating separate tokenized observations. Requires the `transformers` library to be installed. @@ -69,8 +69,8 @@ class TokenizerProcessorStep(ObservationProcessorStep): tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored. max_length: The maximum length to pad or truncate sequences to. task_key: The key in `complementary_data` where the task string is stored. - high_level_task_key: The key in `complementary_data` where the high-level task (user prompt) is stored. - subtask_key: The key in `complementary_data` where the subtask string is stored. + prompt_key: The key in `complementary_data` where the prompt (input for generation) is stored. + target_key: The key in `complementary_data` where the target (text to predict) is stored. padding_side: The side to pad on ('left' or 'right'). padding: The padding strategy ('max_length', 'longest', etc.). truncation: Whether to truncate sequences longer than `max_length`. @@ -81,8 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep): tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency max_length: int = 512 task_key: str = "task" - high_level_task_key: str = "user_prompt" - subtask_key: str = "subtask" + prompt_key: str = "prompt" + target_key: str = "target" padding_side: str = "right" padding: str = "max_length" truncation: bool = True @@ -147,57 +147,57 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None - def get_high_level_task(self, transition: EnvTransition) -> list[str] | None: + def get_prompt(self, transition: EnvTransition) -> list[str] | None: """ - Extracts the high-level task description(s) from the transition's complementary data. + Extracts the prompt (input for generation) from the transition's complementary data. Args: transition: The environment transition. Returns: - A list of high-level task strings, or None if the high-level task key is not found or the value is None. + A list of prompt strings, or None if the prompt key is not found or the value is None. """ complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) if complementary_data is None: return None - high_level_task = complementary_data.get(self.high_level_task_key) + prompt = complementary_data.get(self.prompt_key) - if high_level_task is None: + if prompt is None: return None # Standardize to a list of strings for the tokenizer - if isinstance(high_level_task, str): - return [high_level_task] - elif isinstance(high_level_task, list) and all(isinstance(t, str) for t in high_level_task): - return high_level_task + if isinstance(prompt, str): + return [prompt] + elif isinstance(prompt, list) and all(isinstance(t, str) for t in prompt): + return prompt return None - def get_subtask(self, transition: EnvTransition) -> list[str] | None: + def get_target(self, transition: EnvTransition) -> list[str] | None: """ - Extracts the subtask description(s) from the transition's complementary data. + Extracts the target (text to predict) from the transition's complementary data. Args: transition: The environment transition. Returns: - A list of subtask strings, or None if the subtask key is not found or the value is None. + A list of target strings, or None if the target key is not found or the value is None. """ complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) if complementary_data is None: return None - subtask = complementary_data.get(self.subtask_key) + target = complementary_data.get(self.target_key) - if subtask is None: + if target is None: return None # Standardize to a list of strings for the tokenizer - if isinstance(subtask, str): - return [subtask] - elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask): - return subtask + if isinstance(target, str): + return [target] + elif isinstance(target, list) and all(isinstance(t, str) for t in target): + return target return None @@ -238,39 +238,37 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) - # Also tokenize high-level task if available - high_level_task = self.get_high_level_task(self.transition) - if high_level_task is not None: - # Tokenize the high-level task - tokenized_high_level_prompt = self._tokenize_text(high_level_task) + # Also tokenize prompt (input for generation) if available + prompt = self.get_prompt(self.transition) + if prompt is not None: + tokenized_prompt_input = self._tokenize_text(prompt) # Move to the same device if target_device is not None: - tokenized_high_level_prompt = { + tokenized_prompt_input = { k: v.to(target_device) if isinstance(v, torch.Tensor) else v - for k, v in tokenized_high_level_prompt.items() + for k, v in tokenized_prompt_input.items() } - # Add high-level tokenized data to the observation - new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = tokenized_high_level_prompt["input_ids"] - new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = tokenized_high_level_prompt["attention_mask"].to(dtype=torch.bool) + # Add prompt tokenized data to the observation + new_observation[OBS_LANGUAGE_PROMPT_TOKENS] = tokenized_prompt_input["input_ids"] + new_observation[OBS_LANGUAGE_PROMPT_ATTENTION_MASK] = tokenized_prompt_input["attention_mask"].to(dtype=torch.bool) - # Also tokenize subtask if available - subtask = self.get_subtask(self.transition) - if subtask is not None: - # Tokenize the subtask - tokenized_subtask_prompt = self._tokenize_text(subtask) + # Also tokenize target (text to predict) if available + target = self.get_target(self.transition) + if target is not None: + tokenized_target = self._tokenize_text(target) # Move to the same device if target_device is not None: - tokenized_subtask_prompt = { + tokenized_target = { k: v.to(target_device) if isinstance(v, torch.Tensor) else v - for k, v in tokenized_subtask_prompt.items() + for k, v in tokenized_target.items() } - # Add subtask tokenized data to the observation - new_observation[OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = tokenized_subtask_prompt["input_ids"] - new_observation[OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = tokenized_subtask_prompt["attention_mask"].to(dtype=torch.bool) + # Add target tokenized data to the observation + new_observation[OBS_LANGUAGE_TARGET_TOKENS] = tokenized_target["input_ids"] + new_observation[OBS_LANGUAGE_TARGET_ATTENTION_MASK] = tokenized_target["attention_mask"].to(dtype=torch.bool) return new_observation @@ -332,7 +330,8 @@ class TokenizerProcessorStep(ObservationProcessorStep): config = { "max_length": self.max_length, "task_key": self.task_key, - "high_level_task_key": self.high_level_task_key, + "prompt_key": self.prompt_key, + "target_key": self.target_key, "padding_side": self.padding_side, "padding": self.padding, "truncation": self.truncation, @@ -371,24 +370,25 @@ class TokenizerProcessorStep(ObservationProcessorStep): type=FeatureType.LANGUAGE, shape=(self.max_length,) ) - # Add features for high-level task tokens and attention mask if they don't already exist - if OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]: - features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = PolicyFeature( + # Add features for prompt tokens and attention mask if they don't already exist + if OBS_LANGUAGE_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_PROMPT_TOKENS] = PolicyFeature( type=FeatureType.LANGUAGE, shape=(self.max_length,) ) - if OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: - features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = PolicyFeature( + if OBS_LANGUAGE_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_PROMPT_ATTENTION_MASK] = PolicyFeature( type=FeatureType.LANGUAGE, shape=(self.max_length,) ) - if OBS_LANGUAGE_SUBTASK_ONLY_TOKENS not in features[PipelineFeatureType.OBSERVATION]: - features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = PolicyFeature( + # Add features for target tokens and attention mask if they don't already exist + if OBS_LANGUAGE_TARGET_TOKENS not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TARGET_TOKENS] = PolicyFeature( type=FeatureType.LANGUAGE, shape=(self.max_length,) ) - if OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: - features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = PolicyFeature( + if OBS_LANGUAGE_TARGET_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TARGET_ATTENTION_MASK] = PolicyFeature( type=FeatureType.LANGUAGE, shape=(self.max_length,) ) diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index c8e19eb56..413a41a27 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -26,12 +26,12 @@ OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" -OBS_LANGUAGE_HIGH_LEVEL_TASK = OBS_STR + ".user_prompt" -OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".tokens" -OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".attention_mask" -OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask" -OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens" -OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask" +OBS_LANGUAGE_PROMPT = OBS_STR + ".prompt" +OBS_LANGUAGE_PROMPT_TOKENS = OBS_LANGUAGE_PROMPT + ".tokens" +OBS_LANGUAGE_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_PROMPT + ".attention_mask" +OBS_LANGUAGE_TARGET = OBS_STR + ".target" +OBS_LANGUAGE_TARGET_TOKENS = OBS_LANGUAGE_TARGET + ".tokens" +OBS_LANGUAGE_TARGET_ATTENTION_MASK = OBS_LANGUAGE_TARGET + ".attention_mask" ACTION = "action" REWARD = "next.reward" TRUNCATED = "next.truncated"