add cached subtask inference

Jade Choghari
2026-02-09 07:33:12 +00:00
parent 6c94fcd1b1
commit 0059ca7924
2 changed files with 45 additions and 31 deletions
+31 -16
@@ -1809,9 +1809,34 @@ class PI05FullPolicy(PreTrainedPolicy):
self.eval()
# generate subtask tokens with time-based caching (independent of action queue)
# only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate)
current_time = time.time()
interval = self.config.subtask_regeneration_interval
should_regenerate = (
self._cached_subtask_tokens is None
or self._last_subtask_time is None
or interval <= 0 # 0 means regenerate every call
or (current_time - self._last_subtask_time) >= interval
)
if should_regenerate:
images, img_masks = self._preprocess_images(batch)
high_level_task_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
high_level_task_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
images, img_masks, high_level_task_tokens, high_level_task_masks,
max_decoding_steps=self.config.tokenizer_max_length
)
self._cached_subtask_tokens = subtask_tokens
self._cached_subtask_masks = subtask_masks
self._last_subtask_time = current_time
# log and decode the generated subtask tokens
print(f"Generated subtask tokens: {self.model._paligemma_tokenizer.decode(subtask_tokens[0].tolist(), skip_special_tokens=True)}")
# REMOVE
# Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0:
# TODO: jadechoghari, generate subtask tokens here - ideally every 1 second
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1))
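
For reference, the regeneration gate added above boils down to the following pattern. This is a minimal standalone sketch, not the policy class itself; SubtaskCache, regeneration_interval, and generate_fn are illustrative names.

import time


class SubtaskCache:
    """Time-based cache: re-run an expensive generator at most once per interval."""

    def __init__(self, regeneration_interval: float):
        # Mirrors config.subtask_regeneration_interval; <= 0 means regenerate every call.
        self.regeneration_interval = regeneration_interval
        self._cached = None
        self._last_time = None

    def get(self, generate_fn):
        now = time.time()
        should_regenerate = (
            self._cached is None
            or self._last_time is None
            or self.regeneration_interval <= 0
            or (now - self._last_time) >= self.regeneration_interval
        )
        if should_regenerate:
            self._cached = generate_fn()  # e.g. the autoregressive subtask decode
            self._last_time = now
        return self._cached
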
@@ -1825,28 +1850,18 @@ class PI05FullPolicy(PreTrainedPolicy):
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
# tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Generate subtask tokens with time-based caching
# Only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate)
current_time = time.time()
interval = self.config.subtask_regeneration_interval
should_regenerate = (
self._cached_subtask_tokens is None
or self._last_subtask_time is None
or interval <= 0 # 0 means regenerate every call
or (current_time - self._last_subtask_time) >= interval
)
if should_regenerate:
# Use cached subtask tokens (generated in select_action based on time interval)
# If called directly without select_action, generate subtask tokens
if self._cached_subtask_tokens is None:
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
images, img_masks, high_level_task_tokens, high_level_task_masks,
max_decoding_steps=self.config.tokenizer_max_length
)
self._cached_subtask_tokens = subtask_tokens
self._cached_subtask_masks = subtask_masks
self._last_subtask_time = current_time
self._last_subtask_time = time.time()
else:
subtask_tokens = self._cached_subtask_tokens
subtask_masks = self._cached_subtask_masks
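
Note the asymmetry with the first hunk: this path no longer owns a timer. It only seeds the cache when it is empty (e.g. when the chunk predictor is called directly), so select_action remains the single place that decides when to regenerate. A hedged usage sketch, assuming a constructed policy and a preprocessed batch:

# Typical control loop (policy/batch construction omitted):
action = policy.select_action(batch)        # owns the timer; regenerates when the interval elapses
chunk = policy.predict_action_chunk(batch)  # reuses the cache; generates only if it is still empty
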
@@ -54,8 +54,8 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
"""
max_state_dim: int = 32
user_prompt_key: str = "task"
command_key: str = "subtask"
task_key: str = "task"
subtask_key: str = "subtask"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
@@ -63,12 +63,10 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None:
raise ValueError("State is required for PI05")
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key)
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if user_prompts is None:
raise ValueError("No user prompts found in complementary data")
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key)
if commands is None:
raise ValueError("No commands found in complementary data")
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.subtask_key)
# TODO: check if this is necessary
state = deepcopy(state)
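
The lookup change above makes the subtask strictly optional: only a missing task raises. A minimal sketch of the resulting pattern, with a plain dict standing in for the complementary data and lookup_prompts as an illustrative name:

def lookup_prompts(complementary_data: dict, task_key: str = "task", subtask_key: str = "subtask"):
    tasks = complementary_data.get(task_key)
    if tasks is None:
        raise ValueError("No user prompts found in complementary data")
    subtasks = complementary_data.get(subtask_key)  # may be None; formatting skips it downstream
    return tasks, subtasks
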
@@ -89,17 +87,18 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# process commands
full_commands = []
for i, command in enumerate(commands):
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
full_command = f"Subtask: {cleaned_text};\n"
full_commands.append(full_command)
# process commands (optional)
if commands is not None:
full_commands = []
for i, command in enumerate(commands):
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
full_command = f"Subtask: {cleaned_text};\n"
full_commands.append(full_command)
transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_key] = full_commands
# note: action tokens will be processed in the ActionTokenizerProcessorStep
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by the normalizer processor step!)
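
Taken together, the two formatting loops produce prompts of the shape sketched below. This is a minimal sketch assuming the state string is already built and that tasks are cleaned the same way as subtasks (the task-cleaning lines fall outside this hunk); format_prompts is an illustrative name, not part of the processor API.

def format_prompts(tasks, subtasks, state_strs):
    # Task prompts are required; subtask prompts are emitted only when subtasks exist.
    full_prompts = []
    for task, state_str in zip(tasks, state_strs):
        cleaned = task.strip().replace("_", " ").replace("\n", " ")
        full_prompts.append(f"Task: {cleaned}, State: {state_str};\n")
    full_commands = None
    if subtasks is not None:
        full_commands = []
        for command in subtasks:
            cleaned = command.strip().replace("_", " ").replace("\n", " ").lower()
            full_commands.append(f"Subtask: {cleaned};\n")
    return full_prompts, full_commands

# format_prompts(["pick_up the block"], ["reach\nblock"], ["12 87 3"])
# -> (["Task: pick up the block, State: 12 87 3;\n"], ["Subtask: reach block;\n"])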