From 477204d485b1a7d8419a29557b7391bb0632d17a Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Wed, 28 Jan 2026 12:32:13 +0000 Subject: [PATCH] add eos to subtask token --- .../pi05_full/annotate/load_lerobot_high.py | 14 ++++++++++++- src/lerobot/policies/pi05_full/inference.py | 1 + .../policies/pi05_full/modeling_pi05.py | 1 + src/lerobot/processor/tokenizer_processor.py | 20 ++++++++++++++++++- 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py b/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py index 5a5d6efb1..7a48d6903 100644 --- a/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py +++ b/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py @@ -3,7 +3,8 @@ from huggingface_hub import HfApi import lerobot from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata - +from lerobot.policies.factory import make_pre_post_processors +from lerobot.configs.policies import PreTrainedConfig dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/libero-10-annotate") dataloader = torch.utils.data.DataLoader( @@ -13,7 +14,18 @@ dataloader = torch.utils.data.DataLoader( shuffle=True, ) +cfg = PreTrainedConfig.from_pretrained( + pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base", +) +cfg.dtype = "bfloat16" + +pre_processor, post_processor = make_pre_post_processors( + policy_cfg=cfg, + pretrained_path="/fsx/jade_choghari/models/pi05-base", +) batch = next(iter(dataloader)) +batch1 = pre_processor(batch) +breakpoint() print(batch.keys()) # print(batch['task_index_high_level'].shape) # print(batch['task_index_high_level']) diff --git a/src/lerobot/policies/pi05_full/inference.py b/src/lerobot/policies/pi05_full/inference.py index 68aee5e42..8c5b7fcd6 100644 --- a/src/lerobot/policies/pi05_full/inference.py +++ b/src/lerobot/policies/pi05_full/inference.py @@ -45,6 +45,7 @@ dataloader = torch.utils.data.DataLoader( ) 
batch = next(iter(dataloader)) +breakpoint() batch = pre_processor(batch) policy.train() # run inference diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index 996ef2afe..b1181765e 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -1717,6 +1717,7 @@ class PI05FullPolicy(PreTrainedPolicy): # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: + # TODO: jadechoghari, generate subtask tokens here - ideally every 1 second actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index c70773eee..c278f5c0e 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -254,6 +254,24 @@ class TokenizerProcessorStep(ObservationProcessorStep): if subtask is not None: tokenized_subtask = self._tokenize_text(subtask) + # Add EOS token at the end of each subtask sequence (before padding) + eos_token_id = self.input_tokenizer.eos_token_id + input_ids = tokenized_subtask["input_ids"] + attention_mask = tokenized_subtask["attention_mask"] + for i in range(input_ids.size(0)): + # Find the length of actual tokens (sum of attention mask) + seq_len = attention_mask[i].sum().item() + + max_len = input_ids.size(1) + if seq_len >= max_len: + raise ValueError( + f"No room to append EOS: seq_len={seq_len} >= max_length={max_len}. " + "Increase max_length or tokenize with padding=False then pad after adding EOS." 
+ ) + # Add EOS token at the end + input_ids[i, seq_len] = eos_token_id + attention_mask[i, seq_len] = 1 + # Move new tokenized tensors to the detected device if target_device is not None: tokenized_subtask = { @@ -638,4 +656,4 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): Returns: The updated dictionary of policy features. """ - return features + return features \ No newline at end of file