mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
run inference, attention mask
This commit is contained in:
@@ -34,6 +34,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
|||||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||||
|
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||||
@@ -390,6 +391,13 @@ def make_pre_post_processors(
|
|||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
elif isinstance(policy_cfg, PI05FullConfig):
|
||||||
|
from lerobot.policies.pi05_full.processor_pi05 import make_pi05_full_pre_post_processors
|
||||||
|
|
||||||
|
processors = make_pi05_full_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -14,8 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .configuration_pi05 import PI05Config
|
from .configuration_pi05 import PI05FullConfig
|
||||||
from .modeling_pi05 import PI05Policy
|
from .modeling_pi05 import PI05FullPolicy
|
||||||
from .processor_pi05 import make_pi05_pre_post_processors
|
from .processor_pi05 import make_pi05_full_pre_post_processors
|
||||||
|
|
||||||
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
|
__all__ = ["PI05FullConfig", "PI05FullPolicy", "make_pi05_full_pre_post_processors"]
|
||||||
|
|||||||
@@ -26,9 +26,9 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
|||||||
DEFAULT_IMAGE_SIZE = 224
|
DEFAULT_IMAGE_SIZE = 224
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("pi05")
|
@PreTrainedConfig.register_subclass("pi05_full")
|
||||||
@dataclass
|
@dataclass
|
||||||
class PI05Config(PreTrainedConfig):
|
class PI05FullConfig(PreTrainedConfig):
|
||||||
paligemma_variant: str = "gemma_2b"
|
paligemma_variant: str = "gemma_2b"
|
||||||
action_expert_variant: str = "gemma_300m"
|
action_expert_variant: str = "gemma_300m"
|
||||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||||
@@ -71,6 +71,11 @@ class PI05Config(PreTrainedConfig):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||||
|
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||||
|
max_action_tokens: int = 256
|
||||||
|
fast_skip_tokens: int = 128
|
||||||
|
|
||||||
# Training settings
|
# Training settings
|
||||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||||
|
|||||||
@@ -0,0 +1,91 @@
|
|||||||
|
import torch
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
|
import lerobot
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
|
# import make_pre_post_processors
|
||||||
|
from lerobot.policies.factory import make_pre_post_processors
|
||||||
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||||
|
from lerobot.policies.factory import make_policy, make_policy_config
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
|
||||||
|
cfg = PreTrainedConfig.from_pretrained(
|
||||||
|
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||||
|
)
|
||||||
|
cfg.dtype = "bfloat16"
|
||||||
|
|
||||||
|
pre_processor, post_processor = make_pre_post_processors(
|
||||||
|
policy_cfg=cfg,
|
||||||
|
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||||
|
)
|
||||||
|
|
||||||
|
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||||
|
|
||||||
|
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||||
|
|
||||||
|
# rename map --rename_map='{
|
||||||
|
# "observation.images.side": "observation.images.base_0_rgb",
|
||||||
|
# "observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||||
|
# }'
|
||||||
|
rename_map = {
|
||||||
|
"observation.images.side": "observation.images.base_0_rgb",
|
||||||
|
"observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||||
|
}
|
||||||
|
policy = make_policy(
|
||||||
|
cfg=cfg,
|
||||||
|
ds_meta=dataset.meta,
|
||||||
|
rename_map=rename_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=4,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch = next(iter(dataloader))
|
||||||
|
batch = pre_processor(batch)
|
||||||
|
policy.train()
|
||||||
|
# run inference
|
||||||
|
# action = policy.select_action(batch)
|
||||||
|
loss, loss_dict = policy.forward(batch)
|
||||||
|
breakpoint()
|
||||||
|
# import requests
|
||||||
|
# from PIL import Image
|
||||||
|
# from transformers import AutoProcessor
|
||||||
|
# model = policy.model.paligemma_with_expert.paligemma
|
||||||
|
# model = model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
# model.eval()
|
||||||
|
# prompt = "Describe this image."
|
||||||
|
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||||
|
# image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
# processor = AutoProcessor.from_pretrained(
|
||||||
|
# "google/paligemma-3b-pt-224",
|
||||||
|
# )
|
||||||
|
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
|
||||||
|
# print("generating...")
|
||||||
|
# output = model.generate(
|
||||||
|
# **inputs,
|
||||||
|
# max_new_tokens=50,
|
||||||
|
# use_cache=True, # default dynamic cache
|
||||||
|
# )
|
||||||
|
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||||
|
|
||||||
|
|
||||||
|
# # other model
|
||||||
|
# from transformers import PaliGemmaForConditionalGeneration
|
||||||
|
# model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||||
|
# "google/paligemma2-3b-pt-224",
|
||||||
|
# torch_dtype=torch.bfloat16,
|
||||||
|
# device_map="auto",
|
||||||
|
# )
|
||||||
|
# model.eval()
|
||||||
|
# print("generating...")
|
||||||
|
# output = model.generate(
|
||||||
|
# **inputs,
|
||||||
|
# max_new_tokens=100,
|
||||||
|
# use_cache=True, # default dynamic cache
|
||||||
|
# )
|
||||||
|
# print("Model 2 output:")
|
||||||
|
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||||
@@ -41,7 +41,7 @@ else:
|
|||||||
PaliGemmaForConditionalGeneration = None
|
PaliGemmaForConditionalGeneration = None
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
from lerobot.policies.pi05_full.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05FullConfig
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
@@ -49,6 +49,10 @@ from lerobot.utils.constants import (
|
|||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
OPENPI_ATTENTION_MASK_VALUE,
|
OPENPI_ATTENTION_MASK_VALUE,
|
||||||
|
OBS_LANGUAGE_USER_PROMPT_TOKENS,
|
||||||
|
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
|
||||||
|
ACTION_TOKENS,
|
||||||
|
ACTION_TOKEN_MASK,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -534,7 +538,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||||
"""Core PI05 PyTorch model."""
|
"""Core PI05 PyTorch model."""
|
||||||
|
|
||||||
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
def __init__(self, config: PI05FullConfig, rtc_processor: RTCProcessor | None = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rtc_processor = rtc_processor
|
self.rtc_processor = rtc_processor
|
||||||
@@ -631,13 +635,104 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||||
return time.to(dtype=torch.float32, device=device)
|
return time.to(dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
|
||||||
|
"""Create custom 2D attention mask for the new attention pattern.
|
||||||
|
|
||||||
|
Attention rules:
|
||||||
|
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||||
|
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||||
|
- FAST: attend to images + language + subtask, causal among themselves
|
||||||
|
|
||||||
|
Args:
|
||||||
|
att_mask_segments: List of (type, length) tuples
|
||||||
|
pad_masks: Padding masks [B, total_seq_len]
|
||||||
|
bsize: Batch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
|
||||||
|
"""
|
||||||
|
total_len = sum(length for _, length in att_mask_segments)
|
||||||
|
device = pad_masks.device
|
||||||
|
|
||||||
|
# start initializing attention mask as False (cannot attend)
|
||||||
|
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
# track positions for each segment
|
||||||
|
positions = []
|
||||||
|
current_pos = 0
|
||||||
|
for seg_type, seg_len in att_mask_segments:
|
||||||
|
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||||
|
current_pos += seg_len
|
||||||
|
|
||||||
|
# apply attention rules
|
||||||
|
for i, (query_type, query_start, query_end) in enumerate(positions):
|
||||||
|
for j, (key_type, key_start, key_end) in enumerate(positions):
|
||||||
|
# Images and Language can attend to each other bidirectionally
|
||||||
|
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||||
|
|
||||||
|
# Subtask tokens attend to images + language
|
||||||
|
elif query_type == 'subtask' and key_type in ['image', 'language']:
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||||
|
|
||||||
|
# Subtask tokens attend causally to themselves
|
||||||
|
elif query_type == 'subtask' and key_type == 'subtask':
|
||||||
|
# create causal mask for subtask tokens
|
||||||
|
subtask_len = query_end - query_start
|
||||||
|
causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device))
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||||
|
|
||||||
|
# FAST tokens attend to images + language + subtask
|
||||||
|
elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']:
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||||
|
|
||||||
|
# FAST tokens attend causally to themselves
|
||||||
|
elif query_type == 'fast' and key_type == 'fast':
|
||||||
|
fast_len = query_end - query_start
|
||||||
|
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||||
|
|
||||||
|
# apply padding masks
|
||||||
|
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||||
|
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||||
|
|
||||||
|
return att_2d_masks
|
||||||
def embed_prefix(
|
def embed_prefix(
|
||||||
self, images, img_masks, tokens, masks
|
self,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
images,
|
||||||
"""Embed images with SigLIP and language tokens with embedding layer."""
|
img_masks,
|
||||||
|
tokens,
|
||||||
|
subtask_tokens,
|
||||||
|
masks,
|
||||||
|
subtask_masks,
|
||||||
|
fast_action_tokens=None,
|
||||||
|
fast_action_masks=None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
||||||
|
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: List of image tensors
|
||||||
|
img_masks: List of image masks
|
||||||
|
tokens: Language instruction tokens
|
||||||
|
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
||||||
|
masks: Attention masks for tokens
|
||||||
|
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
|
||||||
|
fast_action_masks: Padding masks for FAST action tokens (can be None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
|
||||||
|
pad_masks: Padding masks
|
||||||
|
att_masks: Custom 2D attention mask implementing the required pattern
|
||||||
|
total_T_images: Total number of image tokens
|
||||||
|
num_subtask_embs: Number of subtask token embeddings
|
||||||
|
num_fast_embs: Number of FAST action token embeddings
|
||||||
|
"""
|
||||||
embs = []
|
embs = []
|
||||||
pad_masks = []
|
pad_masks = []
|
||||||
att_masks = []
|
att_mask_segments = [] # Store info about each segment for custom mask creation
|
||||||
|
total_T_images = 0
|
||||||
|
num_subtask_embs = 0
|
||||||
|
num_fast_embs = 0
|
||||||
|
|
||||||
# Process images
|
# Process images
|
||||||
for img, img_mask in zip(images, img_masks, strict=True):
|
for img, img_mask in zip(images, img_masks, strict=True):
|
||||||
@@ -650,9 +745,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
embs.append(img_emb)
|
embs.append(img_emb)
|
||||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||||
att_masks += [0] * num_img_embs
|
att_mask_segments.append(('image', num_img_embs))
|
||||||
|
total_T_images += num_img_embs
|
||||||
|
|
||||||
# Process language tokens
|
# Process language instruction tokens
|
||||||
def lang_embed_func(tokens):
|
def lang_embed_func(tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
lang_emb_dim = lang_emb.shape[-1]
|
||||||
@@ -663,14 +759,57 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
pad_masks.append(masks)
|
pad_masks.append(masks)
|
||||||
|
|
||||||
num_lang_embs = lang_emb.shape[1]
|
num_lang_embs = lang_emb.shape[1]
|
||||||
att_masks += [0] * num_lang_embs
|
att_mask_segments.append(('language', num_lang_embs))
|
||||||
|
|
||||||
|
# process subtask tokens (these are predicted, so use causal masking)
|
||||||
|
if subtask_tokens is not None:
|
||||||
|
def subtask_embed_func(subtask_tokens):
|
||||||
|
subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
|
||||||
|
subtask_emb_dim = subtask_emb.shape[-1]
|
||||||
|
return subtask_emb * math.sqrt(subtask_emb_dim)
|
||||||
|
|
||||||
|
subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
|
||||||
|
embs.append(subtask_emb)
|
||||||
|
|
||||||
|
# create subtask pad masks (non-zero tokens are valid)
|
||||||
|
pad_masks.append(subtask_masks)
|
||||||
|
|
||||||
|
num_subtask_embs = subtask_emb.shape[1]
|
||||||
|
att_mask_segments.append(('subtask', num_subtask_embs))
|
||||||
|
|
||||||
|
# Process FAST action tokens (discrete token IDs)
|
||||||
|
if fast_action_tokens is not None:
|
||||||
|
|
||||||
|
def fast_action_embed_func(fast_action_tokens):
|
||||||
|
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||||
|
fast_emb_dim = fast_emb.shape[-1]
|
||||||
|
return fast_emb * math.sqrt(fast_emb_dim)
|
||||||
|
|
||||||
|
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||||
|
embs.append(fast_action_emb)
|
||||||
|
|
||||||
|
num_fast_embs = fast_action_tokens.shape[1]
|
||||||
|
pad_masks.append(fast_action_masks)
|
||||||
|
att_mask_segments.append(("fast", num_fast_embs))
|
||||||
|
|
||||||
embs = torch.cat(embs, dim=1)
|
embs = torch.cat(embs, dim=1)
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
pad_masks = torch.cat(pad_masks, dim=1)
|
||||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
|
||||||
|
|
||||||
bsize = pad_masks.shape[0]
|
# create custom 2D attention mask
|
||||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
# Attention rules:
|
||||||
|
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||||
|
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||||
|
# - FAST: attend to images + language + subtask, causal among themselves
|
||||||
|
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
|
||||||
|
|
||||||
|
# # Optionally visualize the attention mask
|
||||||
|
# self.visualize_attention_mask(
|
||||||
|
# att_mask_segments=att_mask_segments,
|
||||||
|
# att_2d_masks=att_masks,
|
||||||
|
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
|
||||||
|
# batch_idx=0,
|
||||||
|
# max_display_tokens=512 # Limit display for very long sequences
|
||||||
|
# )
|
||||||
|
|
||||||
return embs, pad_masks, att_masks
|
return embs, pad_masks, att_masks
|
||||||
|
|
||||||
@@ -721,7 +860,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
def forward(self, images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions, noise=None, time=None) -> Tensor:
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
if noise is None:
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
noise = self.sample_noise(actions.shape, actions.device)
|
||||||
@@ -733,9 +872,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks)
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||||
|
|
||||||
|
#TODO jadechoghari
|
||||||
|
# this attention part should be reworked
|
||||||
|
breakpoint()
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
@@ -762,6 +904,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
return suffix_out
|
return suffix_out
|
||||||
|
|
||||||
|
# TODO: jadechoghri
|
||||||
|
# add subtask prediction loss
|
||||||
|
# add fast action prediction loss
|
||||||
suffix_out = self._apply_checkpoint(
|
suffix_out = self._apply_checkpoint(
|
||||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||||
)
|
)
|
||||||
@@ -895,15 +1040,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
return self.action_out_proj(suffix_out)
|
return self.action_out_proj(suffix_out)
|
||||||
|
|
||||||
|
|
||||||
class PI05Policy(PreTrainedPolicy):
|
class PI05FullPolicy(PreTrainedPolicy):
|
||||||
"""PI05 Policy for LeRobot."""
|
"""PI05 Policy for LeRobot."""
|
||||||
|
|
||||||
config_class = PI05Config
|
config_class = PI05FullConfig
|
||||||
name = "pi05"
|
name = "pi05_full"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PI05Config,
|
config: PI05FullConfig,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1245,12 +1390,15 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
"""
|
"""
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
images, img_masks = self._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
|
||||||
|
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
|
||||||
|
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||||
|
action_tokens, action_masks = batch[f"{ACTION_TOKENS}"], batch[f"{ACTION_TOKEN_MASK}"]
|
||||||
|
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
# Compute loss (no separate state needed for PI05)
|
# Compute loss (no separate state needed for PI05)
|
||||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
losses = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Truncate losses to actual action dimensions
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -22,9 +22,10 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
|
||||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
from lerobot.policies.pi05_full.modeling_pi05 import pad_vector
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
|
ActionTokenizerProcessorStep,
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
@@ -45,15 +46,16 @@ from lerobot.utils.constants import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
@ProcessorStepRegistry.register(name="pi05_full_prepare_state_tokenizer_processor_step")
|
||||||
@dataclass
|
@dataclass
|
||||||
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||||
"""
|
"""
|
||||||
Processor step to prepare the state and tokenize the language input.
|
Processor step to prepare the state and tokenize the language input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_state_dim: int = 32
|
max_state_dim: int = 32
|
||||||
task_key: str = "task"
|
user_prompt_key: str = "user_prompt"
|
||||||
|
command_key: str = "task"
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
transition = transition.copy()
|
transition = transition.copy()
|
||||||
@@ -61,9 +63,12 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
|||||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||||
if state is None:
|
if state is None:
|
||||||
raise ValueError("State is required for PI05")
|
raise ValueError("State is required for PI05")
|
||||||
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key)
|
||||||
if tasks is None:
|
if user_prompts is None:
|
||||||
raise ValueError("No task found in complementary data")
|
raise ValueError("No user prompts found in complementary data")
|
||||||
|
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key)
|
||||||
|
if commands is None:
|
||||||
|
raise ValueError("No commands found in complementary data")
|
||||||
|
|
||||||
# TODO: check if this necessary
|
# TODO: check if this necessary
|
||||||
state = deepcopy(state)
|
state = deepcopy(state)
|
||||||
@@ -77,13 +82,24 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
|||||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||||
|
|
||||||
full_prompts = []
|
full_prompts = []
|
||||||
for i, task in enumerate(tasks):
|
for i, user_prompt in enumerate(user_prompts):
|
||||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ")
|
||||||
state_str = " ".join(map(str, discretized_states[i]))
|
state_str = " ".join(map(str, discretized_states[i]))
|
||||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
full_prompt = f"User prompt: {cleaned_text}, State: {state_str};\n"
|
||||||
full_prompts.append(full_prompt)
|
full_prompts.append(full_prompt)
|
||||||
|
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts
|
||||||
|
|
||||||
|
# process commands
|
||||||
|
full_commands = []
|
||||||
|
for i, command in enumerate(commands):
|
||||||
|
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
|
||||||
|
full_command = f"Subtask: {cleaned_text};\n"
|
||||||
|
full_commands.append(full_command)
|
||||||
|
|
||||||
|
transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands
|
||||||
|
|
||||||
|
# note: action tokens will be processed in the ActionTokenizerProcessorStep
|
||||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||||
return transition
|
return transition
|
||||||
@@ -97,8 +113,8 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def make_pi05_pre_post_processors(
|
def make_pi05_full_pre_post_processors(
|
||||||
config: PI05Config,
|
config: PI05FullConfig,
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
@@ -140,13 +156,19 @@ def make_pi05_pre_post_processors(
|
|||||||
norm_map=config.normalization_mapping,
|
norm_map=config.normalization_mapping,
|
||||||
stats=dataset_stats,
|
stats=dataset_stats,
|
||||||
),
|
),
|
||||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
Pi05FullPrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||||
TokenizerProcessorStep(
|
TokenizerProcessorStep(
|
||||||
tokenizer_name="google/paligemma-3b-pt-224",
|
tokenizer_name=config.text_tokenizer_name,
|
||||||
max_length=config.tokenizer_max_length,
|
max_length=config.tokenizer_max_length,
|
||||||
padding_side="right",
|
padding_side="right",
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
),
|
),
|
||||||
|
ActionTokenizerProcessorStep(
|
||||||
|
action_tokenizer_name=config.action_tokenizer_name,
|
||||||
|
max_action_tokens=config.max_action_tokens,
|
||||||
|
fast_skip_tokens=config.fast_skip_tokens,
|
||||||
|
paligemma_tokenizer_name=config.text_tokenizer_name,
|
||||||
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
DeviceProcessorStep(device=config.device),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -170,9 +170,11 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
|||||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||||
|
user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
|
||||||
|
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||||
|
|
||||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key, **user_prompt_key, **subtask_key}
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
|
|||||||
@@ -35,6 +35,9 @@ from lerobot.utils.constants import (
|
|||||||
ACTION_TOKENS,
|
ACTION_TOKENS,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_LANGUAGE_USER_PROMPT,
|
||||||
|
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_USER_PROMPT_TOKENS,
|
||||||
)
|
)
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
@@ -139,18 +142,44 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_user_prompt(self, transition: EnvTransition) -> list[str] | None:
|
||||||
|
"""
|
||||||
|
Extracts the user_prompt from the transition's complementary data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transition: The environment transition.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of user_prompt strings, or None if the user_prompt key is not found or the value is None.
|
||||||
|
"""
|
||||||
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
|
if complementary_data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_prompt = complementary_data.get("user_prompt")
|
||||||
|
if user_prompt is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Standardize to a list of strings for the tokenizer
|
||||||
|
if isinstance(user_prompt, str):
|
||||||
|
return [user_prompt]
|
||||||
|
elif isinstance(user_prompt, list) and all(isinstance(t, str) for t in user_prompt):
|
||||||
|
return user_prompt
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||||
"""
|
"""
|
||||||
Tokenizes the task description and adds it to the observation dictionary.
|
Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary.
|
||||||
|
|
||||||
This method retrieves the task, tokenizes it, moves the resulting tensors to the
|
This method retrieves the task and user_prompt, tokenizes them, moves the resulting tensors to the
|
||||||
same device as other data in the transition, and updates the observation.
|
same device as other data in the transition, and updates the observation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
observation: The original observation dictionary.
|
observation: The original observation dictionary.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The updated observation dictionary including token IDs and an attention mask.
|
The updated observation dictionary including token IDs and attention masks.
|
||||||
"""
|
"""
|
||||||
task = self.get_task(self.transition)
|
task = self.get_task(self.transition)
|
||||||
if task is None:
|
if task is None:
|
||||||
@@ -176,6 +205,22 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||||
|
|
||||||
|
# Tokenize user_prompt if available
|
||||||
|
user_prompt = self.get_user_prompt(self.transition)
|
||||||
|
if user_prompt is not None:
|
||||||
|
tokenized_user_prompt = self._tokenize_text(user_prompt)
|
||||||
|
|
||||||
|
# Move new tokenized tensors to the detected device
|
||||||
|
if target_device is not None:
|
||||||
|
tokenized_user_prompt = {
|
||||||
|
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in tokenized_user_prompt.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add tokenized user_prompt to the observation
|
||||||
|
new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"]
|
||||||
|
new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool)
|
||||||
|
|
||||||
return new_observation
|
return new_observation
|
||||||
|
|
||||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||||
@@ -274,6 +319,17 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add features for user_prompt tokens and attention mask if they don't already exist
|
||||||
|
if OBS_LANGUAGE_USER_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_TOKENS] = PolicyFeature(
|
||||||
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
|
)
|
||||||
|
|
||||||
|
if OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = PolicyFeature(
|
||||||
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
|
)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
|||||||
OBS_LANGUAGE = OBS_STR + ".language"
|
OBS_LANGUAGE = OBS_STR + ".language"
|
||||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||||
|
OBS_LANGUAGE_USER_PROMPT = OBS_STR + ".user_prompt"
|
||||||
|
OBS_LANGUAGE_USER_PROMPT_TOKENS = OBS_LANGUAGE_USER_PROMPT + ".tokens"
|
||||||
|
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_USER_PROMPT_TOKENS + ".attention_mask"
|
||||||
|
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
ACTION_PREFIX = ACTION + "."
|
ACTION_PREFIX = ACTION + "."
|
||||||
|
|||||||
Reference in New Issue
Block a user