run inference, attention mask

This commit is contained in:
Jade Choghari
2026-01-14 11:52:31 +00:00
parent 72f7aaedb5
commit b57504b89e
9 changed files with 384 additions and 49 deletions
+8
View File
@@ -34,6 +34,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -390,6 +391,13 @@ def make_pre_post_processors(
config=policy_cfg, config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"), dataset_stats=kwargs.get("dataset_stats"),
) )
elif isinstance(policy_cfg, PI05FullConfig):
from lerobot.policies.pi05_full.processor_pi05 import make_pi05_full_pre_post_processors
processors = make_pi05_full_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else: else:
try: try:
+4 -4
View File
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .configuration_pi05 import PI05Config from .configuration_pi05 import PI05FullConfig
from .modeling_pi05 import PI05Policy from .modeling_pi05 import PI05FullPolicy
from .processor_pi05 import make_pi05_pre_post_processors from .processor_pi05 import make_pi05_full_pre_post_processors
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"] __all__ = ["PI05FullConfig", "PI05FullPolicy", "make_pi05_full_pre_post_processors"]
@@ -26,9 +26,9 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224 DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi05") @PreTrainedConfig.register_subclass("pi05_full")
@dataclass @dataclass
class PI05Config(PreTrainedConfig): class PI05FullConfig(PreTrainedConfig):
paligemma_variant: str = "gemma_2b" paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m" action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32" dtype: str = "float32" # Options: "bfloat16", "float32"
@@ -71,6 +71,11 @@ class PI05Config(PreTrainedConfig):
} }
) )
action_tokenizer_name: str = "physical-intelligence/fast"
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
max_action_tokens: int = 256
fast_skip_tokens: int = 128
# Training settings # Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization compile_model: bool = False # Whether to use torch.compile for model optimization
@@ -0,0 +1,91 @@
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
# import make_pre_post_processors
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.configs.policies import PreTrainedConfig
cfg = PreTrainedConfig.from_pretrained(
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
)
cfg.dtype = "bfloat16"
pre_processor, post_processor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path="/fsx/jade_choghari/models/pi05-base",
)
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
# rename map --rename_map='{
# "observation.images.side": "observation.images.base_0_rgb",
# "observation.images.up": "observation.images.left_wrist_0_rgb"
# }'
rename_map = {
"observation.images.side": "observation.images.base_0_rgb",
"observation.images.up": "observation.images.left_wrist_0_rgb"
}
policy = make_policy(
cfg=cfg,
ds_meta=dataset.meta,
rename_map=rename_map,
)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
batch = pre_processor(batch)
policy.train()
# run inference
# action = policy.select_action(batch)
loss, loss_dict = policy.forward(batch)
breakpoint()
# import requests
# from PIL import Image
# from transformers import AutoProcessor
# model = policy.model.paligemma_with_expert.paligemma
# model = model.to(device="cuda", dtype=torch.bfloat16)
# model.eval()
# prompt = "Describe this image."
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
# image = Image.open(requests.get(url, stream=True).raw)
# processor = AutoProcessor.from_pretrained(
# "google/paligemma-3b-pt-224",
# )
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
# print("generating...")
# output = model.generate(
# **inputs,
# max_new_tokens=50,
# use_cache=True, # default dynamic cache
# )
# print(processor.decode(output[0], skip_special_tokens=True))
# # other model
# from transformers import PaliGemmaForConditionalGeneration
# model = PaliGemmaForConditionalGeneration.from_pretrained(
# "google/paligemma2-3b-pt-224",
# torch_dtype=torch.bfloat16,
# device_map="auto",
# )
# model.eval()
# print("generating...")
# output = model.generate(
# **inputs,
# max_new_tokens=100,
# use_cache=True, # default dynamic cache
# )
# print("Model 2 output:")
# print(processor.decode(output[0], skip_special_tokens=True))
+168 -20
View File
@@ -41,7 +41,7 @@ else:
PaliGemmaForConditionalGeneration = None PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config from lerobot.policies.pi05_full.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05FullConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import ( from lerobot.utils.constants import (
@@ -49,6 +49,10 @@ from lerobot.utils.constants import (
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_TOKENS,
OPENPI_ATTENTION_MASK_VALUE, OPENPI_ATTENTION_MASK_VALUE,
OBS_LANGUAGE_USER_PROMPT_TOKENS,
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
ACTION_TOKENS,
ACTION_TOKEN_MASK,
) )
@@ -534,7 +538,7 @@ class PaliGemmaWithExpertModel(
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model.""" """Core PI05 PyTorch model."""
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None): def __init__(self, config: PI05FullConfig, rtc_processor: RTCProcessor | None = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.rtc_processor = rtc_processor self.rtc_processor = rtc_processor
@@ -631,13 +635,104 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
return time.to(dtype=torch.float32, device=device) return time.to(dtype=torch.float32, device=device)
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
"""Create custom 2D attention mask for the new attention pattern.
Attention rules:
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
- FAST: attend to images + language + subtask, causal among themselves
Args:
att_mask_segments: List of (type, length) tuples
pad_masks: Padding masks [B, total_seq_len]
bsize: Batch size
Returns:
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
"""
total_len = sum(length for _, length in att_mask_segments)
device = pad_masks.device
# start initializing attention mask as False (cannot attend)
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
# track positions for each segment
positions = []
current_pos = 0
for seg_type, seg_len in att_mask_segments:
positions.append((seg_type, current_pos, current_pos + seg_len))
current_pos += seg_len
# apply attention rules
for i, (query_type, query_start, query_end) in enumerate(positions):
for j, (key_type, key_start, key_end) in enumerate(positions):
# Images and Language can attend to each other bidirectionally
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# Subtask tokens attend to images + language
elif query_type == 'subtask' and key_type in ['image', 'language']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# Subtask tokens attend causally to themselves
elif query_type == 'subtask' and key_type == 'subtask':
# create causal mask for subtask tokens
subtask_len = query_end - query_start
causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device))
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# FAST tokens attend to images + language + subtask
elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# FAST tokens attend causally to themselves
elif query_type == 'fast' and key_type == 'fast':
fast_len = query_end - query_start
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# apply padding masks
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def embed_prefix( def embed_prefix(
self, images, img_masks, tokens, masks self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: images,
"""Embed images with SigLIP and language tokens with embedding layer.""" img_masks,
tokens,
subtask_tokens,
masks,
subtask_masks,
fast_action_tokens=None,
fast_action_masks=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
Args:
images: List of image tensors
img_masks: List of image masks
tokens: Language instruction tokens
subtask_tokens: Subtask tokens to predict (can be None for inference)
masks: Attention masks for tokens
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
fast_action_masks: Padding masks for FAST action tokens (can be None)
Returns:
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
pad_masks: Padding masks
att_masks: Custom 2D attention mask implementing the required pattern
total_T_images: Total number of image tokens
num_subtask_embs: Number of subtask token embeddings
num_fast_embs: Number of FAST action token embeddings
"""
embs = [] embs = []
pad_masks = [] pad_masks = []
att_masks = [] att_mask_segments = [] # Store info about each segment for custom mask creation
total_T_images = 0
num_subtask_embs = 0
num_fast_embs = 0
# Process images # Process images
for img, img_mask in zip(images, img_masks, strict=True): for img, img_mask in zip(images, img_masks, strict=True):
@@ -650,9 +745,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
embs.append(img_emb) embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_masks += [0] * num_img_embs att_mask_segments.append(('image', num_img_embs))
total_T_images += num_img_embs
# Process language tokens # Process language instruction tokens
def lang_embed_func(tokens): def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1] lang_emb_dim = lang_emb.shape[-1]
@@ -663,14 +759,57 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
pad_masks.append(masks) pad_masks.append(masks)
num_lang_embs = lang_emb.shape[1] num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs att_mask_segments.append(('language', num_lang_embs))
# process subtask tokens (these are predicted, so use causal masking)
if subtask_tokens is not None:
def subtask_embed_func(subtask_tokens):
subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
subtask_emb_dim = subtask_emb.shape[-1]
return subtask_emb * math.sqrt(subtask_emb_dim)
subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
embs.append(subtask_emb)
# create subtask pad masks (non-zero tokens are valid)
pad_masks.append(subtask_masks)
num_subtask_embs = subtask_emb.shape[1]
att_mask_segments.append(('subtask', num_subtask_embs))
# Process FAST action tokens (discrete token IDs)
if fast_action_tokens is not None:
def fast_action_embed_func(fast_action_tokens):
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
fast_emb_dim = fast_emb.shape[-1]
return fast_emb * math.sqrt(fast_emb_dim)
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb)
num_fast_embs = fast_action_tokens.shape[1]
pad_masks.append(fast_action_masks)
att_mask_segments.append(("fast", num_fast_embs))
embs = torch.cat(embs, dim=1) embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1) pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
bsize = pad_masks.shape[0] # create custom 2D attention mask
att_masks = att_masks[None, :].expand(bsize, len(att_masks)) # Attention rules:
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
# - FAST: attend to images + language + subtask, causal among themselves
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
# # Optionally visualize the attention mask
# self.visualize_attention_mask(
# att_mask_segments=att_mask_segments,
# att_2d_masks=att_masks,
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
# batch_idx=0,
# max_display_tokens=512 # Limit display for very long sequences
# )
return embs, pad_masks, att_masks return embs, pad_masks, att_masks
@@ -721,7 +860,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return embs, pad_masks, att_masks, adarms_cond return embs, pad_masks, att_masks, adarms_cond
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor: def forward(self, images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss.""" """Do a full training forward pass and compute the loss."""
if noise is None: if noise is None:
noise = self.sample_noise(actions.shape, actions.device) noise = self.sample_noise(actions.shape, actions.device)
@@ -733,9 +872,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
x_t = time_expanded * noise + (1 - time_expanded) * actions x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
#TODO jadechoghari
# this attention part should be reworked
breakpoint()
if ( if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16 == torch.bfloat16
@@ -762,6 +904,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
) )
return suffix_out return suffix_out
# TODO: jadechoghri
# add subtask prediction loss
# add fast action prediction loss
suffix_out = self._apply_checkpoint( suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
) )
@@ -895,15 +1040,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return self.action_out_proj(suffix_out) return self.action_out_proj(suffix_out)
class PI05Policy(PreTrainedPolicy): class PI05FullPolicy(PreTrainedPolicy):
"""PI05 Policy for LeRobot.""" """PI05 Policy for LeRobot."""
config_class = PI05Config config_class = PI05FullConfig
name = "pi05" name = "pi05_full"
def __init__( def __init__(
self, self,
config: PI05Config, config: PI05FullConfig,
**kwargs, **kwargs,
): ):
""" """
@@ -1245,12 +1390,15 @@ class PI05Policy(PreTrainedPolicy):
""" """
# Prepare inputs # Prepare inputs
images, img_masks = self._preprocess_images(batch) images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
action_tokens, action_masks = batch[f"{ACTION_TOKENS}"], batch[f"{ACTION_TOKEN_MASK}"]
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
# Compute loss (no separate state needed for PI05) # Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions) losses = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions)
# Truncate losses to actual action dimensions # Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -22,9 +22,10 @@ import numpy as np
import torch import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
from lerobot.policies.pi05.modeling_pi05 import pad_vector from lerobot.policies.pi05_full.modeling_pi05 import pad_vector
from lerobot.processor import ( from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep, AddBatchDimensionProcessorStep,
DeviceProcessorStep, DeviceProcessorStep,
NormalizerProcessorStep, NormalizerProcessorStep,
@@ -45,15 +46,16 @@ from lerobot.utils.constants import (
) )
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step") @ProcessorStepRegistry.register(name="pi05_full_prepare_state_tokenizer_processor_step")
@dataclass @dataclass
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
""" """
Processor step to prepare the state and tokenize the language input. Processor step to prepare the state and tokenize the language input.
""" """
max_state_dim: int = 32 max_state_dim: int = 32
task_key: str = "task" user_prompt_key: str = "user_prompt"
command_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy() transition = transition.copy()
@@ -61,9 +63,12 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE) state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None: if state is None:
raise ValueError("State is required for PI05") raise ValueError("State is required for PI05")
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key)
if tasks is None: if user_prompts is None:
raise ValueError("No task found in complementary data") raise ValueError("No user prompts found in complementary data")
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key)
if commands is None:
raise ValueError("No commands found in complementary data")
# TODO: check if this necessary # TODO: check if this necessary
state = deepcopy(state) state = deepcopy(state)
@@ -77,13 +82,24 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
full_prompts = [] full_prompts = []
for i, task in enumerate(tasks): for i, user_prompt in enumerate(user_prompts):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ") cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i])) state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " full_prompt = f"User prompt: {cleaned_text}, State: {state_str};\n"
full_prompts.append(full_prompt) full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts
# process commands
full_commands = []
for i, command in enumerate(commands):
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
full_command = f"Subtask: {cleaned_text};\n"
full_commands.append(full_command)
transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands
# note: action tokens will be processed in the ActionTokenizerProcessorStep
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!) # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition return transition
@@ -97,8 +113,8 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
return features return features
def make_pi05_pre_post_processors( def make_pi05_full_pre_post_processors(
config: PI05Config, config: PI05FullConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[ ) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -140,13 +156,19 @@ def make_pi05_pre_post_processors(
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
), ),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), Pi05FullPrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep( TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224", tokenizer_name=config.text_tokenizer_name,
max_length=config.tokenizer_max_length, max_length=config.tokenizer_max_length,
padding_side="right", padding_side="right",
padding="max_length", padding="max_length",
), ),
ActionTokenizerProcessorStep(
action_tokenizer_name=config.action_tokenizer_name,
max_action_tokens=config.max_action_tokens,
fast_skip_tokens=config.fast_skip_tokens,
paligemma_tokenizer_name=config.text_tokenizer_name,
),
DeviceProcessorStep(device=config.device), DeviceProcessorStep(device=config.device),
] ]
+3 -1
View File
@@ -170,9 +170,11 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
task_key = {"task": batch["task"]} if "task" in batch else {} task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key} return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key, **user_prompt_key, **subtask_key}
def create_transition( def create_transition(
+59 -3
View File
@@ -35,6 +35,9 @@ from lerobot.utils.constants import (
ACTION_TOKENS, ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_TOKENS,
OBS_LANGUAGE_USER_PROMPT,
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
OBS_LANGUAGE_USER_PROMPT_TOKENS,
) )
from lerobot.utils.import_utils import _transformers_available from lerobot.utils.import_utils import _transformers_available
@@ -139,18 +142,44 @@ class TokenizerProcessorStep(ObservationProcessorStep):
return None return None
def get_user_prompt(self, transition: EnvTransition) -> list[str] | None:
"""
Extracts the user_prompt from the transition's complementary data.
Args:
transition: The environment transition.
Returns:
A list of user_prompt strings, or None if the user_prompt key is not found or the value is None.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
user_prompt = complementary_data.get("user_prompt")
if user_prompt is None:
return None
# Standardize to a list of strings for the tokenizer
if isinstance(user_prompt, str):
return [user_prompt]
elif isinstance(user_prompt, list) and all(isinstance(t, str) for t in user_prompt):
return user_prompt
return None
def observation(self, observation: RobotObservation) -> RobotObservation: def observation(self, observation: RobotObservation) -> RobotObservation:
""" """
Tokenizes the task description and adds it to the observation dictionary. Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary.
This method retrieves the task, tokenizes it, moves the resulting tensors to the This method retrieves the task and user_prompt, tokenizes them, moves the resulting tensors to the
same device as other data in the transition, and updates the observation. same device as other data in the transition, and updates the observation.
Args: Args:
observation: The original observation dictionary. observation: The original observation dictionary.
Returns: Returns:
The updated observation dictionary including token IDs and an attention mask. The updated observation dictionary including token IDs and attention masks.
""" """
task = self.get_task(self.transition) task = self.get_task(self.transition)
if task is None: if task is None:
@@ -176,6 +205,22 @@ class TokenizerProcessorStep(ObservationProcessorStep):
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
# Tokenize user_prompt if available
user_prompt = self.get_user_prompt(self.transition)
if user_prompt is not None:
tokenized_user_prompt = self._tokenize_text(user_prompt)
# Move new tokenized tensors to the detected device
if target_device is not None:
tokenized_user_prompt = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_user_prompt.items()
}
# Add tokenized user_prompt to the observation
new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"]
new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool)
return new_observation return new_observation
def _detect_device(self, transition: EnvTransition) -> torch.device | None: def _detect_device(self, transition: EnvTransition) -> torch.device | None:
@@ -274,6 +319,17 @@ class TokenizerProcessorStep(ObservationProcessorStep):
type=FeatureType.LANGUAGE, shape=(self.max_length,) type=FeatureType.LANGUAGE, shape=(self.max_length,)
) )
# Add features for user_prompt tokens and attention mask if they don't already exist
if OBS_LANGUAGE_USER_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_TOKENS] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
if OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
return features return features
+3
View File
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
OBS_LANGUAGE_USER_PROMPT = OBS_STR + ".user_prompt"
OBS_LANGUAGE_USER_PROMPT_TOKENS = OBS_LANGUAGE_USER_PROMPT + ".tokens"
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_USER_PROMPT_TOKENS + ".attention_mask"
ACTION = "action" ACTION = "action"
ACTION_PREFIX = ACTION + "." ACTION_PREFIX = ACTION + "."