mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
add cached subtask inference
This commit is contained in:
@@ -1809,9 +1809,34 @@ class PI05FullPolicy(PreTrainedPolicy):
|
||||
|
||||
self.eval()
|
||||
|
||||
# generate subtask tokens with time-based caching (independent of action queue)
|
||||
# only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate)
|
||||
current_time = time.time()
|
||||
interval = self.config.subtask_regeneration_interval
|
||||
should_regenerate = (
|
||||
self._cached_subtask_tokens is None
|
||||
or self._last_subtask_time is None
|
||||
or interval <= 0 # 0 means regenerate every call
|
||||
or (current_time - self._last_subtask_time) >= interval
|
||||
)
|
||||
|
||||
if should_regenerate:
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
high_level_task_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
high_level_task_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
|
||||
images, img_masks, high_level_task_tokens, high_level_task_masks,
|
||||
max_decoding_steps=self.config.tokenizer_max_length
|
||||
)
|
||||
self._cached_subtask_tokens = subtask_tokens
|
||||
self._cached_subtask_masks = subtask_masks
|
||||
self._last_subtask_time = current_time
|
||||
# log and decode the generate subtask tokens
|
||||
print(f"Generated subtask tokens: {self.model._paligemma_tokenizer.decode(subtask_tokens[0].tolist(), skip_special_tokens=True)}")
|
||||
# REMOVE
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
if len(self._action_queue) == 0:
|
||||
# TODO: jadechoghari, generate subtask tokens here - ideally every 1 second
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
# Transpose to get shape (n_action_steps, batch_size, action_dim)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
@@ -1825,28 +1850,18 @@ class PI05FullPolicy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
# tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
|
||||
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Generate subtask tokens with time-based caching
|
||||
# Only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate)
|
||||
current_time = time.time()
|
||||
interval = self.config.subtask_regeneration_interval
|
||||
should_regenerate = (
|
||||
self._cached_subtask_tokens is None
|
||||
or self._last_subtask_time is None
|
||||
or interval <= 0 # 0 means regenerate every call
|
||||
or (current_time - self._last_subtask_time) >= interval
|
||||
)
|
||||
|
||||
if should_regenerate:
|
||||
# Use cached subtask tokens (generated in select_action based on time interval)
|
||||
# If called directly without select_action, generate subtask tokens
|
||||
if self._cached_subtask_tokens is None:
|
||||
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
|
||||
images, img_masks, high_level_task_tokens, high_level_task_masks,
|
||||
max_decoding_steps=self.config.tokenizer_max_length
|
||||
)
|
||||
self._cached_subtask_tokens = subtask_tokens
|
||||
self._cached_subtask_masks = subtask_masks
|
||||
self._last_subtask_time = current_time
|
||||
self._last_subtask_time = time.time()
|
||||
else:
|
||||
subtask_tokens = self._cached_subtask_tokens
|
||||
subtask_masks = self._cached_subtask_masks
|
||||
|
||||
@@ -54,8 +54,8 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
user_prompt_key: str = "task"
|
||||
command_key: str = "subtask"
|
||||
task_key: str = "task"
|
||||
subtask_key: str = "subtask"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
@@ -63,12 +63,10 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("State is required for PI05")
|
||||
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key)
|
||||
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if user_prompts is None:
|
||||
raise ValueError("No user prompts found in complementary data")
|
||||
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key)
|
||||
if commands is None:
|
||||
raise ValueError("No commands found in complementary data")
|
||||
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.subtask_key)
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
@@ -89,17 +87,18 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
|
||||
# process commands
|
||||
full_commands = []
|
||||
for i, command in enumerate(commands):
|
||||
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
|
||||
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
|
||||
full_command = f"Subtask: {cleaned_text};\n"
|
||||
full_commands.append(full_command)
|
||||
# process commands (optional)
|
||||
if commands is not None:
|
||||
full_commands = []
|
||||
for i, command in enumerate(commands):
|
||||
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
|
||||
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
|
||||
full_command = f"Subtask: {cleaned_text};\n"
|
||||
full_commands.append(full_command)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_key] = full_commands
|
||||
|
||||
# note: action tokens will be processed in the ActionTokenizerProcessorStep
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
|
||||
Reference in New Issue
Block a user