Add time-based caching for subtask token generation during inference

This commit is contained in:
Jade Choghari
2026-01-27 16:21:44 +00:00
parent 6a6912ec37
commit 99dbbd56c2
2 changed files with 29 additions and 4 deletions
@@ -77,6 +77,7 @@ class PI05FullConfig(PreTrainedConfig):
# subtask stuff
max_decoding_steps: int = 200
temperature: float = 0.0
subtask_regeneration_interval: float = 1.0 # Minimum number of seconds between subtask token regenerations (<= 0 = regenerate on every call)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
@@ -17,6 +17,7 @@
import builtins
import logging
import math
import time
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
@@ -1613,6 +1614,10 @@ class PI05FullPolicy(PreTrainedPolicy):
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
# Subtask caching state - regenerate every `subtask_regeneration_interval` seconds
self._cached_subtask_tokens: Tensor | None = None
self._cached_subtask_masks: Tensor | None = None
self._last_subtask_time: float | None = None
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
@@ -1728,12 +1733,31 @@ class PI05FullPolicy(PreTrainedPolicy):
# tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
# Subtask tokens need to be generated here - ideally once every second.
# TODO(jadechoghari): this should be called every 1 second or whenever the user inputs a prompt.
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(images, img_masks, high_level_task_tokens, high_level_task_masks, max_decoding_steps=self.config.tokenizer_max_length)
# Generate subtask tokens with time-based caching
# Regenerate only if: nothing is cached yet, the interval is <= 0 (always regenerate), or the interval has elapsed since the last generation
current_time = time.time()
interval = self.config.subtask_regeneration_interval
should_regenerate = (
self._cached_subtask_tokens is None
or self._last_subtask_time is None
or interval <= 0 # 0 means regenerate every call
or (current_time - self._last_subtask_time) >= interval
)
if should_regenerate:
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
images, img_masks, high_level_task_tokens, high_level_task_masks,
max_decoding_steps=self.config.tokenizer_max_length
)
self._cached_subtask_tokens = subtask_tokens
self._cached_subtask_masks = subtask_masks
self._last_subtask_time = current_time
else:
subtask_tokens = self._cached_subtask_tokens
subtask_masks = self._cached_subtask_masks
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, **kwargs)
breakpoint()
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]