From 0059ca7924518d5c36dc52bf721acf636e3866c0 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 9 Feb 2026 07:33:12 +0000 Subject: [PATCH] add cached subtask inference --- .../policies/pi05_full/modeling_pi05.py | 47 ++++++++++++------- .../policies/pi05_full/processor_pi05.py | 29 ++++++------ 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index ca66becb7..7d0f8a0a1 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -1809,9 +1809,34 @@ class PI05FullPolicy(PreTrainedPolicy): self.eval() + # generate subtask tokens with time-based caching (independent of action queue) + # only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate) + current_time = time.time() + interval = self.config.subtask_regeneration_interval + should_regenerate = ( + self._cached_subtask_tokens is None + or self._last_subtask_time is None + or interval <= 0 # 0 means regenerate every call + or (current_time - self._last_subtask_time) >= interval + ) + + if should_regenerate: + images, img_masks = self._preprocess_images(batch) + high_level_task_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + high_level_task_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + subtask_tokens, subtask_masks = self.model.generate_subtask_tokens( + images, img_masks, high_level_task_tokens, high_level_task_masks, + max_decoding_steps=self.config.tokenizer_max_length + ) + self._cached_subtask_tokens = subtask_tokens + self._cached_subtask_masks = subtask_masks + self._last_subtask_time = current_time + # log and decode the generated subtask tokens + print(f"Generated subtask tokens: {self.model._paligemma_tokenizer.decode(subtask_tokens[0].tolist(), skip_special_tokens=True)}") + # REMOVE + # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - # TODO: jadechoghari, generate subtask 
tokens here - ideally every 1 second actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) @@ -1825,28 +1850,18 @@ class PI05FullPolicy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) - # tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"] + high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - # Generate subtask tokens with time-based caching - # Only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate) - current_time = time.time() - interval = self.config.subtask_regeneration_interval - should_regenerate = ( - self._cached_subtask_tokens is None - or self._last_subtask_time is None - or interval <= 0 # 0 means regenerate every call - or (current_time - self._last_subtask_time) >= interval - ) - - if should_regenerate: + # Use cached subtask tokens (generated in select_action based on time interval) + # If called directly without select_action, generate subtask tokens + if self._cached_subtask_tokens is None: subtask_tokens, subtask_masks = self.model.generate_subtask_tokens( images, img_masks, high_level_task_tokens, high_level_task_masks, max_decoding_steps=self.config.tokenizer_max_length ) self._cached_subtask_tokens = subtask_tokens self._cached_subtask_masks = subtask_masks - self._last_subtask_time = current_time + self._last_subtask_time = time.time() else: subtask_tokens = self._cached_subtask_tokens subtask_masks = self._cached_subtask_masks diff --git a/src/lerobot/policies/pi05_full/processor_pi05.py b/src/lerobot/policies/pi05_full/processor_pi05.py index 80059e9c9..43b643f0b 100644 --- 
a/src/lerobot/policies/pi05_full/processor_pi05.py +++ b/src/lerobot/policies/pi05_full/processor_pi05.py @@ -54,8 +54,8 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): """ max_state_dim: int = 32 - user_prompt_key: str = "task" - command_key: str = "subtask" + task_key: str = "task" + subtask_key: str = "subtask" def __call__(self, transition: EnvTransition) -> EnvTransition: transition = transition.copy() @@ -63,12 +63,10 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE) if state is None: raise ValueError("State is required for PI05") - user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key) + user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) if user_prompts is None: raise ValueError("No user prompts found in complementary data") - commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key) - if commands is None: - raise ValueError("No commands found in complementary data") + commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.subtask_key) # TODO: check if this necessary state = deepcopy(state) @@ -89,17 +87,18 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): full_prompt = f"Task: {cleaned_text}, State: {state_str};\n" full_prompts.append(full_prompt) - transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts + transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts - # process commands - full_commands = [] - for i, command in enumerate(commands): - cleaned_text = command.strip().replace("_", " ").replace("\n", " ") - cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari) - full_command = f"Subtask: {cleaned_text};\n" - full_commands.append(full_command) + # process commands (optional) + if commands is not None: + full_commands = [] + for i, 
command in enumerate(commands): + cleaned_text = command.strip().replace("_", " ").replace("\n", " ") + cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari) + full_command = f"Subtask: {cleaned_text};\n" + full_commands.append(full_command) - transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands + transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_key] = full_commands # note: action tokens will be processed in the ActionTokenizerProcessorStep # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)