mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
run inference, attention mask
This commit is contained in:
@@ -34,6 +34,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
@@ -390,6 +391,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, PI05FullConfig):
|
||||
from lerobot.policies.pi05_full.processor_pi05 import make_pi05_full_pre_post_processors
|
||||
|
||||
processors = make_pi05_full_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi05 import PI05Config
|
||||
from .modeling_pi05 import PI05Policy
|
||||
from .processor_pi05 import make_pi05_pre_post_processors
|
||||
from .configuration_pi05 import PI05FullConfig
|
||||
from .modeling_pi05 import PI05FullPolicy
|
||||
from .processor_pi05 import make_pi05_full_pre_post_processors
|
||||
|
||||
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
|
||||
__all__ = ["PI05FullConfig", "PI05FullPolicy", "make_pi05_full_pre_post_processors"]
|
||||
|
||||
@@ -26,9 +26,9 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@PreTrainedConfig.register_subclass("pi05_full")
|
||||
@dataclass
|
||||
class PI05Config(PreTrainedConfig):
|
||||
class PI05FullConfig(PreTrainedConfig):
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
@@ -71,6 +71,11 @@ class PI05Config(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
max_action_tokens: int = 256
|
||||
fast_skip_tokens: int = 128
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
# import make_pre_post_processors
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||
|
||||
# rename map --rename_map='{
|
||||
# "observation.images.side": "observation.images.base_0_rgb",
|
||||
# "observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
# }'
|
||||
rename_map = {
|
||||
"observation.images.side": "observation.images.base_0_rgb",
|
||||
"observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
}
|
||||
policy = make_policy(
|
||||
cfg=cfg,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=rename_map,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
batch = pre_processor(batch)
|
||||
policy.train()
|
||||
# run inference
|
||||
# action = policy.select_action(batch)
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
breakpoint()
|
||||
# import requests
|
||||
# from PIL import Image
|
||||
# from transformers import AutoProcessor
|
||||
# model = policy.model.paligemma_with_expert.paligemma
|
||||
# model = model.to(device="cuda", dtype=torch.bfloat16)
|
||||
# model.eval()
|
||||
# prompt = "Describe this image."
|
||||
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
# image = Image.open(requests.get(url, stream=True).raw)
|
||||
# processor = AutoProcessor.from_pretrained(
|
||||
# "google/paligemma-3b-pt-224",
|
||||
# )
|
||||
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
|
||||
# print("generating...")
|
||||
# output = model.generate(
|
||||
# **inputs,
|
||||
# max_new_tokens=50,
|
||||
# use_cache=True, # default dynamic cache
|
||||
# )
|
||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
|
||||
# # other model
|
||||
# from transformers import PaliGemmaForConditionalGeneration
|
||||
# model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
# "google/paligemma2-3b-pt-224",
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# device_map="auto",
|
||||
# )
|
||||
# model.eval()
|
||||
# print("generating...")
|
||||
# output = model.generate(
|
||||
# **inputs,
|
||||
# max_new_tokens=100,
|
||||
# use_cache=True, # default dynamic cache
|
||||
# )
|
||||
# print("Model 2 output:")
|
||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||
@@ -41,7 +41,7 @@ else:
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pi05_full.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05FullConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
@@ -49,6 +49,10 @@ from lerobot.utils.constants import (
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
OBS_LANGUAGE_USER_PROMPT_TOKENS,
|
||||
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
|
||||
ACTION_TOKENS,
|
||||
ACTION_TOKEN_MASK,
|
||||
)
|
||||
|
||||
|
||||
@@ -534,7 +538,7 @@ class PaliGemmaWithExpertModel(
|
||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI05 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
||||
def __init__(self, config: PI05FullConfig, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
@@ -631,13 +635,104 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
|
||||
"""Create custom 2D attention mask for the new attention pattern.
|
||||
|
||||
Attention rules:
|
||||
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||
- FAST: attend to images + language + subtask, causal among themselves
|
||||
|
||||
Args:
|
||||
att_mask_segments: List of (type, length) tuples
|
||||
pad_masks: Padding masks [B, total_seq_len]
|
||||
bsize: Batch size
|
||||
|
||||
Returns:
|
||||
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
|
||||
"""
|
||||
total_len = sum(length for _, length in att_mask_segments)
|
||||
device = pad_masks.device
|
||||
|
||||
# start initializing attention mask as False (cannot attend)
|
||||
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||
|
||||
# track positions for each segment
|
||||
positions = []
|
||||
current_pos = 0
|
||||
for seg_type, seg_len in att_mask_segments:
|
||||
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||
current_pos += seg_len
|
||||
|
||||
# apply attention rules
|
||||
for i, (query_type, query_start, query_end) in enumerate(positions):
|
||||
for j, (key_type, key_start, key_end) in enumerate(positions):
|
||||
# Images and Language can attend to each other bidirectionally
|
||||
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# Subtask tokens attend to images + language
|
||||
elif query_type == 'subtask' and key_type in ['image', 'language']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# Subtask tokens attend causally to themselves
|
||||
elif query_type == 'subtask' and key_type == 'subtask':
|
||||
# create causal mask for subtask tokens
|
||||
subtask_len = query_end - query_start
|
||||
causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device))
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||
|
||||
# FAST tokens attend to images + language + subtask
|
||||
elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# FAST tokens attend causally to themselves
|
||||
elif query_type == 'fast' and key_type == 'fast':
|
||||
fast_len = query_end - query_start
|
||||
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||
|
||||
# apply padding masks
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
|
||||
return att_2d_masks
|
||||
def embed_prefix(
|
||||
self, images, img_masks, tokens, masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer."""
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
subtask_tokens,
|
||||
masks,
|
||||
subtask_masks,
|
||||
fast_action_tokens=None,
|
||||
fast_action_masks=None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
||||
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
||||
|
||||
Args:
|
||||
images: List of image tensors
|
||||
img_masks: List of image masks
|
||||
tokens: Language instruction tokens
|
||||
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
||||
masks: Attention masks for tokens
|
||||
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
|
||||
fast_action_masks: Padding masks for FAST action tokens (can be None)
|
||||
|
||||
Returns:
|
||||
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
|
||||
pad_masks: Padding masks
|
||||
att_masks: Custom 2D attention mask implementing the required pattern
|
||||
total_T_images: Total number of image tokens
|
||||
num_subtask_embs: Number of subtask token embeddings
|
||||
num_fast_embs: Number of FAST action token embeddings
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
att_mask_segments = [] # Store info about each segment for custom mask creation
|
||||
total_T_images = 0
|
||||
num_subtask_embs = 0
|
||||
num_fast_embs = 0
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
@@ -650,9 +745,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
att_masks += [0] * num_img_embs
|
||||
att_mask_segments.append(('image', num_img_embs))
|
||||
total_T_images += num_img_embs
|
||||
|
||||
# Process language tokens
|
||||
# Process language instruction tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
@@ -663,14 +759,57 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
pad_masks.append(masks)
|
||||
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
att_mask_segments.append(('language', num_lang_embs))
|
||||
|
||||
# process subtask tokens (these are predicted, so use causal masking)
|
||||
if subtask_tokens is not None:
|
||||
def subtask_embed_func(subtask_tokens):
|
||||
subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
|
||||
subtask_emb_dim = subtask_emb.shape[-1]
|
||||
return subtask_emb * math.sqrt(subtask_emb_dim)
|
||||
|
||||
subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
|
||||
embs.append(subtask_emb)
|
||||
|
||||
# create subtask pad masks (non-zero tokens are valid)
|
||||
pad_masks.append(subtask_masks)
|
||||
|
||||
num_subtask_embs = subtask_emb.shape[1]
|
||||
att_mask_segments.append(('subtask', num_subtask_embs))
|
||||
|
||||
# Process FAST action tokens (discrete token IDs)
|
||||
if fast_action_tokens is not None:
|
||||
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
|
||||
num_fast_embs = fast_action_tokens.shape[1]
|
||||
pad_masks.append(fast_action_masks)
|
||||
att_mask_segments.append(("fast", num_fast_embs))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
|
||||
bsize = pad_masks.shape[0]
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
# create custom 2D attention mask
|
||||
# Attention rules:
|
||||
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||
# - FAST: attend to images + language + subtask, causal among themselves
|
||||
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
|
||||
|
||||
# # Optionally visualize the attention mask
|
||||
# self.visualize_attention_mask(
|
||||
# att_mask_segments=att_mask_segments,
|
||||
# att_2d_masks=att_masks,
|
||||
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
|
||||
# batch_idx=0,
|
||||
# max_display_tokens=512 # Limit display for very long sequences
|
||||
# )
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
@@ -721,7 +860,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||
def forward(self, images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
@@ -733,9 +872,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
#TODO jadechoghari
|
||||
# this attention part should be reworked
|
||||
breakpoint()
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
@@ -762,6 +904,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
return suffix_out
|
||||
|
||||
# TODO: jadechoghri
|
||||
# add subtask prediction loss
|
||||
# add fast action prediction loss
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
@@ -895,15 +1040,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
|
||||
class PI05Policy(PreTrainedPolicy):
|
||||
class PI05FullPolicy(PreTrainedPolicy):
|
||||
"""PI05 Policy for LeRobot."""
|
||||
|
||||
config_class = PI05Config
|
||||
name = "pi05"
|
||||
config_class = PI05FullConfig
|
||||
name = "pi05_full"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI05Config,
|
||||
config: PI05FullConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -1245,12 +1390,15 @@ class PI05Policy(PreTrainedPolicy):
|
||||
"""
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
|
||||
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
action_tokens, action_masks = batch[f"{ACTION_TOKENS}"], batch[f"{ACTION_TOKEN_MASK}"]
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
losses = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -22,9 +22,10 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
|
||||
from lerobot.policies.pi05_full.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
@@ -45,15 +46,16 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
||||
@ProcessorStepRegistry.register(name="pi05_full_prepare_state_tokenizer_processor_step")
|
||||
@dataclass
|
||||
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to prepare the state and tokenize the language input.
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
user_prompt_key: str = "user_prompt"
|
||||
command_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
@@ -61,9 +63,12 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("State is required for PI05")
|
||||
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key)
|
||||
if user_prompts is None:
|
||||
raise ValueError("No user prompts found in complementary data")
|
||||
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key)
|
||||
if commands is None:
|
||||
raise ValueError("No commands found in complementary data")
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
@@ -77,13 +82,24 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
for i, user_prompt in enumerate(user_prompts):
|
||||
cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompt = f"User prompt: {cleaned_text}, State: {state_str};\n"
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts
|
||||
|
||||
# process commands
|
||||
full_commands = []
|
||||
for i, command in enumerate(commands):
|
||||
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
|
||||
full_command = f"Subtask: {cleaned_text};\n"
|
||||
full_commands.append(full_command)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands
|
||||
|
||||
# note: action tokens will be processed in the ActionTokenizerProcessorStep
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
return transition
|
||||
@@ -97,8 +113,8 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_pre_post_processors(
|
||||
config: PI05Config,
|
||||
def make_pi05_full_pre_post_processors(
|
||||
config: PI05FullConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -140,13 +156,19 @@ def make_pi05_pre_post_processors(
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
Pi05FullPrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
tokenizer_name=config.text_tokenizer_name,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
fast_skip_tokens=config.fast_skip_tokens,
|
||||
paligemma_tokenizer_name=config.text_tokenizer_name,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
|
||||
@@ -170,9 +170,11 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key, **user_prompt_key, **subtask_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -35,6 +35,9 @@ from lerobot.utils.constants import (
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_USER_PROMPT,
|
||||
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_USER_PROMPT_TOKENS,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -139,18 +142,44 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_user_prompt(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the user_prompt from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of user_prompt strings, or None if the user_prompt key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
user_prompt = complementary_data.get("user_prompt")
|
||||
if user_prompt is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(user_prompt, str):
|
||||
return [user_prompt]
|
||||
elif isinstance(user_prompt, list) and all(isinstance(t, str) for t in user_prompt):
|
||||
return user_prompt
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary.
|
||||
|
||||
This method retrieves the task, tokenizes it, moves the resulting tensors to the
|
||||
This method retrieves the task and user_prompt, tokenizes them, moves the resulting tensors to the
|
||||
same device as other data in the transition, and updates the observation.
|
||||
|
||||
Args:
|
||||
observation: The original observation dictionary.
|
||||
|
||||
Returns:
|
||||
The updated observation dictionary including token IDs and an attention mask.
|
||||
The updated observation dictionary including token IDs and attention masks.
|
||||
"""
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
@@ -176,6 +205,22 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize user_prompt if available
|
||||
user_prompt = self.get_user_prompt(self.transition)
|
||||
if user_prompt is not None:
|
||||
tokenized_user_prompt = self._tokenize_text(user_prompt)
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_user_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_user_prompt.items()
|
||||
}
|
||||
|
||||
# Add tokenized user_prompt to the observation
|
||||
new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
@@ -274,6 +319,17 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for user_prompt tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_USER_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
|
||||
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
OBS_LANGUAGE_USER_PROMPT = OBS_STR + ".user_prompt"
|
||||
OBS_LANGUAGE_USER_PROMPT_TOKENS = OBS_LANGUAGE_USER_PROMPT + ".tokens"
|
||||
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_USER_PROMPT_TOKENS + ".attention_mask"
|
||||
|
||||
ACTION = "action"
|
||||
ACTION_PREFIX = ACTION + "."
|
||||
|
||||
Reference in New Issue
Block a user