This commit is contained in:
Jade Choghari
2026-01-15 16:35:58 +00:00
parent 6e88d6f387
commit 966fedfeef
3 changed files with 9 additions and 1 deletions
@@ -20,4 +20,12 @@ print(batch['task_index_high_level'])
print(batch['user_prompt'][0])
print(batch['robot_utterance'][0])
print(batch['task'][0])
# read this parquet /fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquett
import pandas as pd
tasks_df = pd.read_parquet('/fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquet')
# print all
print(tasks_df.columns)
breakpoint()
@@ -1638,7 +1638,7 @@ class PI05FullPolicy(PreTrainedPolicy):
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(images, img_masks, high_level_task_tokens, high_level_task_masks, max_decoding_steps=self.config.tokenizer_max_length)
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, **kwargs)
breakpoint()
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]