mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 00:27:15 +00:00
add generation inference for subtask
This commit is contained in:
@@ -77,6 +77,7 @@ class PI05FullConfig(PreTrainedConfig):
|
||||
# subtask stuff
|
||||
max_decoding_steps: int = 200
|
||||
temperature: float = 0.0
|
||||
subtask_regeneration_interval: float = 1.0 # Regenerate subtask tokens every N seconds (0 = every call)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import builtins
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
@@ -1613,6 +1614,10 @@ class PI05FullPolicy(PreTrainedPolicy):
|
||||
self._queues = {
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
# Subtask caching state - regenerate every `subtask_regeneration_interval` seconds
|
||||
self._cached_subtask_tokens: Tensor | None = None
|
||||
self._cached_subtask_masks: Tensor | None = None
|
||||
self._last_subtask_time: float | None = None
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
@@ -1728,12 +1733,31 @@ class PI05FullPolicy(PreTrainedPolicy):
|
||||
# tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
|
||||
|
||||
# we will need to generate subtask tokens here - ideally every 1 second
|
||||
# TODO: jadechoghari: this should be called every 1 second or when the user input a prompt
|
||||
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(images, img_masks, high_level_task_tokens, high_level_task_masks, max_decoding_steps=self.config.tokenizer_max_length)
|
||||
# Generate subtask tokens with time-based caching
|
||||
# Only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate)
|
||||
current_time = time.time()
|
||||
interval = self.config.subtask_regeneration_interval
|
||||
should_regenerate = (
|
||||
self._cached_subtask_tokens is None
|
||||
or self._last_subtask_time is None
|
||||
or interval <= 0 # 0 means regenerate every call
|
||||
or (current_time - self._last_subtask_time) >= interval
|
||||
)
|
||||
|
||||
if should_regenerate:
|
||||
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
|
||||
images, img_masks, high_level_task_tokens, high_level_task_masks,
|
||||
max_decoding_steps=self.config.tokenizer_max_length
|
||||
)
|
||||
self._cached_subtask_tokens = subtask_tokens
|
||||
self._cached_subtask_masks = subtask_masks
|
||||
self._last_subtask_time = current_time
|
||||
else:
|
||||
subtask_tokens = self._cached_subtask_tokens
|
||||
subtask_masks = self._cached_subtask_masks
|
||||
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, **kwargs)
|
||||
breakpoint()
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
Reference in New Issue
Block a user