diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index a593e5bcb..7286f764e 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -34,6 +34,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig @@ -390,6 +391,13 @@ def make_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, PI05FullConfig): + from lerobot.policies.pi05_full.processor_pi05 import make_pi05_full_pre_post_processors + + processors = make_pi05_full_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) else: try: diff --git a/src/lerobot/policies/pi05_full/__init__.py b/src/lerobot/policies/pi05_full/__init__.py index 4f9a9de4a..5455cc1a8 100644 --- a/src/lerobot/policies/pi05_full/__init__.py +++ b/src/lerobot/policies/pi05_full/__init__.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .configuration_pi05 import PI05Config -from .modeling_pi05 import PI05Policy -from .processor_pi05 import make_pi05_pre_post_processors +from .configuration_pi05 import PI05FullConfig +from .modeling_pi05 import PI05FullPolicy +from .processor_pi05 import make_pi05_full_pre_post_processors -__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"] +__all__ = ["PI05FullConfig", "PI05FullPolicy", "make_pi05_full_pre_post_processors"] diff --git a/src/lerobot/policies/pi05_full/configuration_pi05.py b/src/lerobot/policies/pi05_full/configuration_pi05.py index b96e6d196..aa9c6c0bb 100644 --- a/src/lerobot/policies/pi05_full/configuration_pi05.py +++ b/src/lerobot/policies/pi05_full/configuration_pi05.py @@ -26,9 +26,9 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 -@PreTrainedConfig.register_subclass("pi05") +@PreTrainedConfig.register_subclass("pi05_full") @dataclass -class PI05Config(PreTrainedConfig): +class PI05FullConfig(PreTrainedConfig): paligemma_variant: str = "gemma_2b" action_expert_variant: str = "gemma_300m" dtype: str = "float32" # Options: "bfloat16", "float32" @@ -71,6 +71,11 @@ class PI05Config(PreTrainedConfig): } ) + action_tokenizer_name: str = "physical-intelligence/fast" + text_tokenizer_name: str = "google/paligemma-3b-pt-224" + max_action_tokens: int = 256 + fast_skip_tokens: int = 128 + # Training settings gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization compile_model: bool = False # Whether to use torch.compile for model optimization diff --git a/src/lerobot/policies/pi05_full/inference.py b/src/lerobot/policies/pi05_full/inference.py new file mode 100644 index 000000000..68aee5e42 --- /dev/null +++ b/src/lerobot/policies/pi05_full/inference.py @@ -0,0 +1,91 @@ +import torch +from huggingface_hub import HfApi + +import lerobot +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +# import make_pre_post_processors +from lerobot.policies.factory import make_pre_post_processors +from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.policies.factory import make_policy, make_policy_config +from lerobot.configs.policies import PreTrainedConfig + +cfg = PreTrainedConfig.from_pretrained( + pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base", +) +cfg.dtype = "bfloat16" + +pre_processor, post_processor = make_pre_post_processors( + policy_cfg=cfg, + pretrained_path="/fsx/jade_choghari/models/pi05-base", +) + +delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]} + +dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps) + +# rename map --rename_map='{ +# "observation.images.side": "observation.images.base_0_rgb", +# "observation.images.up": "observation.images.left_wrist_0_rgb" +# }' +rename_map = { + "observation.images.side": "observation.images.base_0_rgb", + "observation.images.up": "observation.images.left_wrist_0_rgb" +} +policy = make_policy( + cfg=cfg, + ds_meta=dataset.meta, + rename_map=rename_map, +) + +dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=0, + batch_size=4, + shuffle=True, +) + +batch = next(iter(dataloader)) +batch = pre_processor(batch) +policy.train() +# run inference +# action = policy.select_action(batch) +loss, loss_dict = policy.forward(batch) +breakpoint() +# import requests +# from PIL import Image +# from transformers import AutoProcessor +# model = policy.model.paligemma_with_expert.paligemma +# model = model.to(device="cuda", dtype=torch.bfloat16) +# model.eval() +# prompt = "Describe this image." +# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" +# image = Image.open(requests.get(url, stream=True).raw) +# processor = AutoProcessor.from_pretrained( +# "google/paligemma-3b-pt-224", +# ) +# inputs = processor(image, prompt, return_tensors="pt").to(model.device) +# print("generating...") +# output = model.generate( +# **inputs, +# max_new_tokens=50, +# use_cache=True, # default dynamic cache +# ) +# print(processor.decode(output[0], skip_special_tokens=True)) + + +# # other model +# from transformers import PaliGemmaForConditionalGeneration +# model = PaliGemmaForConditionalGeneration.from_pretrained( +# "google/paligemma2-3b-pt-224", +# torch_dtype=torch.bfloat16, +# device_map="auto", +# ) +# model.eval() +# print("generating...") +# output = model.generate( +# **inputs, +# max_new_tokens=100, +# use_cache=True, # default dynamic cache +# ) +# print("Model 2 output:") +# print(processor.decode(output[0], skip_special_tokens=True)) \ No newline at end of file diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index 11d8b4d68..7216c63b3 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -41,7 +41,7 @@ else: PaliGemmaForConditionalGeneration = None from lerobot.configs.policies import PreTrainedConfig -from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config +from lerobot.policies.pi05_full.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05FullConfig from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( @@ -49,6 +49,10 @@ from lerobot.utils.constants import ( OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OPENPI_ATTENTION_MASK_VALUE, + OBS_LANGUAGE_USER_PROMPT_TOKENS, + OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK, + ACTION_TOKENS, + ACTION_TOKEN_MASK, ) @@ -534,7 +538,7 @@ class PaliGemmaWithExpertModel( class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI05 PyTorch model.""" - def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None): + def __init__(self, config: PI05FullConfig, rtc_processor: RTCProcessor | None = None): super().__init__() self.config = config self.rtc_processor = rtc_processor @@ -630,15 +634,106 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset return time.to(dtype=torch.float32, device=device) - + + def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize): + """Create custom 2D attention mask for the new attention pattern. + + Attention rules: + - Images + Language: bidirectional among themselves, don't attend to subtask or FAST + - Subtask: attend to images + language, causal among themselves, don't attend to FAST + - FAST: attend to images + language + subtask, causal among themselves + + Args: + att_mask_segments: List of (type, length) tuples + pad_masks: Padding masks [B, total_seq_len] + bsize: Batch size + + Returns: + att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len] + """ + total_len = sum(length for _, length in att_mask_segments) + device = pad_masks.device + + # start initializing attention mask as False (cannot attend) + att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device) + + # track positions for each segment + positions = [] + current_pos = 0 + for seg_type, seg_len in att_mask_segments: + positions.append((seg_type, current_pos, current_pos + seg_len)) + current_pos += seg_len + + # apply attention rules + for i, (query_type, query_start, query_end) in enumerate(positions): + for j, (key_type, key_start, key_end) in enumerate(positions): + # Images and Language can attend to each other bidirectionally + if query_type in ['image', 'language'] and key_type in ['image', 'language']: + att_2d_masks[:, query_start:query_end, key_start:key_end] = True + + # Subtask tokens attend to images + language + elif query_type == 'subtask' and key_type in ['image', 'language']: + att_2d_masks[:, query_start:query_end, key_start:key_end] = True + + # Subtask tokens attend causally to themselves + elif query_type == 'subtask' and key_type == 'subtask': + # create causal mask for subtask tokens + subtask_len = query_end - query_start + causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device)) + att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :] + + # FAST tokens attend to images + language + subtask + elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']: + att_2d_masks[:, query_start:query_end, key_start:key_end] = True + + # FAST tokens attend causally to themselves + elif query_type == 'fast' and key_type == 'fast': + fast_len = query_end - query_start + causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device)) + att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :] + + # apply padding masks + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + + return att_2d_masks def embed_prefix( - self, images, img_masks, tokens, masks - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Embed images with SigLIP and language tokens with embedding layer.""" + self, + images, + img_masks, + tokens, + subtask_tokens, + masks, + subtask_masks, + fast_action_tokens=None, + fast_action_masks=None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer. + + Args: + images: List of image tensors + img_masks: List of image masks + tokens: Language instruction tokens + subtask_tokens: Subtask tokens to predict (can be None for inference) + masks: Attention masks for tokens + fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs + fast_action_masks: Padding masks for FAST action tokens (can be None) + + Returns: + embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)] + pad_masks: Padding masks + att_masks: Custom 2D attention mask implementing the required pattern + total_T_images: Total number of image tokens + num_subtask_embs: Number of subtask token embeddings + num_fast_embs: Number of FAST action token embeddings + """ embs = [] pad_masks = [] - att_masks = [] - + att_mask_segments = [] # Store info about each segment for custom mask creation + total_T_images = 0 + num_subtask_embs = 0 + num_fast_embs = 0 + # Process images for img, img_mask in zip(images, img_masks, strict=True): @@ -650,9 +745,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` embs.append(img_emb) pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) - att_masks += [0] * num_img_embs - - # Process language tokens + att_mask_segments.append(('image', num_img_embs)) + total_T_images += num_img_embs + + # Process language instruction tokens def lang_embed_func(tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) lang_emb_dim = lang_emb.shape[-1] @@ -663,14 +759,57 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` pad_masks.append(masks) num_lang_embs = lang_emb.shape[1] - att_masks += [0] * num_lang_embs + att_mask_segments.append(('language', num_lang_embs)) + + # process subtask tokens (these are predicted, so use causal masking) + if subtask_tokens is not None: + def subtask_embed_func(subtask_tokens): + subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens) + subtask_emb_dim = subtask_emb.shape[-1] + return subtask_emb * math.sqrt(subtask_emb_dim) + + subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens) + embs.append(subtask_emb) + + # create subtask pad masks (non-zero tokens are valid) + pad_masks.append(subtask_masks) + + num_subtask_embs = subtask_emb.shape[1] + att_mask_segments.append(('subtask', num_subtask_embs)) + + # Process FAST action tokens (discrete token IDs) + if fast_action_tokens is not None: + + def fast_action_embed_func(fast_action_tokens): + fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens) + fast_emb_dim = fast_emb.shape[-1] + return fast_emb * math.sqrt(fast_emb_dim) + + fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) + embs.append(fast_action_emb) + + num_fast_embs = fast_action_tokens.shape[1] + pad_masks.append(fast_action_masks) + att_mask_segments.append(("fast", num_fast_embs)) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + # create custom 2D attention mask + # Attention rules: + # - Images + Language: bidirectional among themselves, don't attend to subtask or FAST + # - Subtask: attend to images + language, causal among themselves, don't attend to FAST + # - FAST: attend to images + language + subtask, causal among themselves + att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize) - bsize = pad_masks.shape[0] - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + # # Optionally visualize the attention mask + # self.visualize_attention_mask( + # att_mask_segments=att_mask_segments, + # att_2d_masks=att_masks, + # save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png", + # batch_idx=0, + # max_display_tokens=512 # Limit display for very long sequences + # ) return embs, pad_masks, att_masks @@ -721,7 +860,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return embs, pad_masks, att_masks, adarms_cond - def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor: + def forward(self, images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss.""" if noise is None: noise = self.sample_noise(actions.shape, actions.device) @@ -733,9 +872,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) + #TODO jadechoghari + # this attention part should be reworked + breakpoint() if ( self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 @@ -762,6 +904,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) return suffix_out + # TODO: jadechoghri + # add subtask prediction loss + # add fast action prediction loss suffix_out = self._apply_checkpoint( forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) @@ -895,15 +1040,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return self.action_out_proj(suffix_out) -class PI05Policy(PreTrainedPolicy): +class PI05FullPolicy(PreTrainedPolicy): """PI05 Policy for LeRobot.""" - config_class = PI05Config - name = "pi05" + config_class = PI05FullConfig + name = "pi05_full" def __init__( self, - config: PI05Config, + config: PI05FullConfig, **kwargs, ): """ @@ -1245,12 +1390,15 @@ class PI05Policy(PreTrainedPolicy): """ # Prepare inputs images, img_masks = self._preprocess_images(batch) - tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + + high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"] + subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + action_tokens, action_masks = batch[f"{ACTION_TOKENS}"], batch[f"{ACTION_TOKEN_MASK}"] actions = self.prepare_action(batch) # Compute loss (no separate state needed for PI05) - losses = self.model.forward(images, img_masks, tokens, masks, actions) + losses = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0] diff --git a/src/lerobot/policies/pi05_full/processor_pi05.py b/src/lerobot/policies/pi05_full/processor_pi05.py index e29bc4c23..cff9480f7 100644 --- a/src/lerobot/policies/pi05_full/processor_pi05.py +++ b/src/lerobot/policies/pi05_full/processor_pi05.py @@ -22,9 +22,10 @@ import numpy as np import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.policies.pi05.configuration_pi05 import PI05Config -from lerobot.policies.pi05.modeling_pi05 import pad_vector +from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig +from lerobot.policies.pi05_full.modeling_pi05 import pad_vector from lerobot.processor import ( + ActionTokenizerProcessorStep, AddBatchDimensionProcessorStep, DeviceProcessorStep, NormalizerProcessorStep, @@ -45,15 +46,16 @@ from lerobot.utils.constants import ( ) -@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step") +@ProcessorStepRegistry.register(name="pi05_full_prepare_state_tokenizer_processor_step") @dataclass -class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): +class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): """ Processor step to prepare the state and tokenize the language input. """ max_state_dim: int = 32 - task_key: str = "task" + user_prompt_key: str = "user_prompt" + command_key: str = "task" def __call__(self, transition: EnvTransition) -> EnvTransition: transition = transition.copy() @@ -61,9 +63,12 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE) if state is None: raise ValueError("State is required for PI05") - tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) - if tasks is None: - raise ValueError("No task found in complementary data") + user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key) + if user_prompts is None: + raise ValueError("No user prompts found in complementary data") + commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key) + if commands is None: + raise ValueError("No commands found in complementary data") # TODO: check if this necessary state = deepcopy(state) @@ -77,13 +82,24 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 full_prompts = [] - for i, task in enumerate(tasks): - cleaned_text = task.strip().replace("_", " ").replace("\n", " ") + for i, user_prompt in enumerate(user_prompts): + cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ") state_str = " ".join(map(str, discretized_states[i])) - full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + full_prompt = f"User prompt: {cleaned_text}, State: {state_str};\n" full_prompts.append(full_prompt) + + transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts - transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts + # process commands + full_commands = [] + for i, command in enumerate(commands): + cleaned_text = command.strip().replace("_", " ").replace("\n", " ") + full_command = f"Subtask: {cleaned_text};\n" + full_commands.append(full_command) + + transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands + + # note: action tokens will be processed in the ActionTokenizerProcessorStep # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!) # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) return transition @@ -97,8 +113,8 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): return features -def make_pi05_pre_post_processors( - config: PI05Config, +def make_pi05_full_pre_post_processors( + config: PI05FullConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, ) -> tuple[ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], @@ -140,13 +156,19 @@ def make_pi05_pre_post_processors( norm_map=config.normalization_mapping, stats=dataset_stats, ), - Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), + Pi05FullPrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), TokenizerProcessorStep( - tokenizer_name="google/paligemma-3b-pt-224", + tokenizer_name=config.text_tokenizer_name, max_length=config.tokenizer_max_length, padding_side="right", padding="max_length", ), + ActionTokenizerProcessorStep( + action_tokenizer_name=config.action_tokenizer_name, + max_action_tokens=config.max_action_tokens, + fast_skip_tokens=config.fast_skip_tokens, + paligemma_tokenizer_name=config.text_tokenizer_name, + ), DeviceProcessorStep(device=config.device), ] diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 4f9485fee..fd4833c28 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -170,9 +170,11 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: task_key = {"task": batch["task"]} if "task" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} + user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {} + subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {} episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} - return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key} + return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key, **user_prompt_key, **subtask_key} def create_transition( diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 5cd1bebb0..eedd9ec4c 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -35,6 +35,9 @@ from lerobot.utils.constants import ( ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, + OBS_LANGUAGE_USER_PROMPT, + OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK, + OBS_LANGUAGE_USER_PROMPT_TOKENS, ) from lerobot.utils.import_utils import _transformers_available @@ -139,18 +142,44 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None + def get_user_prompt(self, transition: EnvTransition) -> list[str] | None: + """ + Extracts the user_prompt from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of user_prompt strings, or None if the user_prompt key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + user_prompt = complementary_data.get("user_prompt") + if user_prompt is None: + return None + + # Standardize to a list of strings for the tokenizer + if isinstance(user_prompt, str): + return [user_prompt] + elif isinstance(user_prompt, list) and all(isinstance(t, str) for t in user_prompt): + return user_prompt + + return None + def observation(self, observation: RobotObservation) -> RobotObservation: """ - Tokenizes the task description and adds it to the observation dictionary. + Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary. - This method retrieves the task, tokenizes it, moves the resulting tensors to the + This method retrieves the task and user_prompt, tokenizes them, moves the resulting tensors to the same device as other data in the transition, and updates the observation. Args: observation: The original observation dictionary. Returns: - The updated observation dictionary including token IDs and an attention mask. + The updated observation dictionary including token IDs and attention masks. """ task = self.get_task(self.transition) if task is None: @@ -176,6 +205,22 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + # Tokenize user_prompt if available + user_prompt = self.get_user_prompt(self.transition) + if user_prompt is not None: + tokenized_user_prompt = self._tokenize_text(user_prompt) + + # Move new tokenized tensors to the detected device + if target_device is not None: + tokenized_user_prompt = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_user_prompt.items() + } + + # Add tokenized user_prompt to the observation + new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"] + new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool) + return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: @@ -274,6 +319,17 @@ class TokenizerProcessorStep(ObservationProcessorStep): type=FeatureType.LANGUAGE, shape=(self.max_length,) ) + # Add features for user_prompt tokens and attention mask if they don't already exist + if OBS_LANGUAGE_USER_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_TOKENS] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + if OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + return features diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 43a61b4f7..b52a1b80d 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" +OBS_LANGUAGE_USER_PROMPT = OBS_STR + ".user_prompt" +OBS_LANGUAGE_USER_PROMPT_TOKENS = OBS_LANGUAGE_USER_PROMPT + ".tokens" +OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_USER_PROMPT_TOKENS + ".attention_mask" ACTION = "action" ACTION_PREFIX = ACTION + "."