diff --git a/src/lerobot/policies/pi05_full/configuration_pi05.py b/src/lerobot/policies/pi05_full/configuration_pi05.py
index fa91e3edb..a95645220 100644
--- a/src/lerobot/policies/pi05_full/configuration_pi05.py
+++ b/src/lerobot/policies/pi05_full/configuration_pi05.py
@@ -77,6 +77,7 @@ class PI05FullConfig(PreTrainedConfig):
     # subtask stuff
     max_decoding_steps: int = 200
     temperature: float = 0.0
+    subtask_regeneration_interval: float = 1.0  # Regenerate subtask tokens every N seconds (<= 0 = every call)
 
     # Training settings
     gradient_checkpointing: bool = False  # Enable gradient checkpointing for memory optimization
diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py
index b9d41b3de..996ef2afe 100644
--- a/src/lerobot/policies/pi05_full/modeling_pi05.py
+++ b/src/lerobot/policies/pi05_full/modeling_pi05.py
@@ -17,6 +17,7 @@
 import builtins
 import logging
 import math
+import time
 from collections import deque
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal, TypedDict
@@ -1613,6 +1614,10 @@ class PI05FullPolicy(PreTrainedPolicy):
         self._queues = {
             ACTION: deque(maxlen=self.config.n_action_steps),
         }
+        # Subtask caching state - regenerate every `subtask_regeneration_interval` seconds
+        self._cached_subtask_tokens: Tensor | None = None
+        self._cached_subtask_masks: Tensor | None = None
+        self._last_subtask_time: float | None = None  # time.monotonic() reading at last generation
 
     def init_rtc_processor(self):
         """Initialize RTC processor if RTC is enabled in config."""
@@ -1728,12 +1733,31 @@ class PI05FullPolicy(PreTrainedPolicy):
         # tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
         high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
 
-        # we will need to generate subtask tokens here - ideally every 1 second
-        # TODO: jadechoghari: this should be called every 1 second or when the user input a prompt
-        subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(images, img_masks, high_level_task_tokens, high_level_task_masks, max_decoding_steps=self.config.tokenizer_max_length)
+        # Generate subtask tokens with time-based caching: regenerate when there is no cache,
+        # the interval is <= 0 (every call), or the interval has elapsed on the monotonic clock.
+        current_time = time.monotonic()  # monotonic: unaffected by wall-clock (NTP/manual) adjustments
+        interval = self.config.subtask_regeneration_interval
+        should_regenerate = (
+            self._cached_subtask_tokens is None
+            or self._last_subtask_time is None
+            or interval <= 0  # <= 0 means regenerate every call
+            or (current_time - self._last_subtask_time) >= interval
+        )
+
+        if should_regenerate:
+            subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
+                images, img_masks, high_level_task_tokens, high_level_task_masks,
+                max_decoding_steps=self.config.tokenizer_max_length
+            )
+            self._cached_subtask_tokens = subtask_tokens
+            self._cached_subtask_masks = subtask_masks
+            self._last_subtask_time = current_time
+        else:
+            subtask_tokens = self._cached_subtask_tokens
+            subtask_masks = self._cached_subtask_masks
+
         # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
         actions = self.model.sample_actions(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, **kwargs)
-        breakpoint()
         # Unpad actions to actual action dimension
         original_action_dim = self.config.output_features[ACTION].shape[0]
         actions = actions[:, :, :original_action_dim]