From 2d4be804258944f9581b38154c500d97db5cfbbf Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 22 Jun 2026 17:37:33 +0200 Subject: [PATCH] feat(pi05): implement Classifier-Free Guidance (CFG) inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add dual-path denoising with configurable cfg_beta scale for language- conditioned action generation. When cfg_beta > 1.0, VLM prefills both conditioned and unconditional prompts, and action expert velocities are interpolated via v = v_uncond + β*(v_cond - v_uncond). --- .../policies/pi05/configuration_pi05.py | 6 + src/lerobot/policies/pi05/modeling_pi05.py | 143 ++++++++++- src/lerobot/policies/pi05/processor_pi05.py | 48 +++- .../processor/rendered_messages_to_task.py | 2 + src/lerobot/processor/tokenizer_processor.py | 14 +- src/lerobot/utils/constants.py | 3 + tests/policies/pi0_pi05/test_pi05_cfg.py | 224 ++++++++++++++++++ 7 files changed, 428 insertions(+), 12 deletions(-) create mode 100644 tests/policies/pi0_pi05/test_pi05_cfg.py diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 06df47b87..48a9ad37c 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -92,6 +92,12 @@ class PI05Config(PreTrainedConfig): # the recipe YAML before prompt construction. recipe_path: str | None = None + # Classifier-Free Guidance (CFG) scale for inference (Eq. 13 in RECAP paper). + # 1.0 = no guidance (default). >1.0 enables dual-path denoising where: + # v = v_uncond + cfg_beta * (v_cond - v_uncond) + # VLM runs twice (cond + uncond prompts), action expert runs 2x per step. + cfg_beta: float = 1.0 + # Optimizer settings: see openpi `AdamW` optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` optimizer_betas: tuple[float, float] = (0.9, 0.95) diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index aabd04c6f..a112c1e39 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -52,6 +52,8 @@ from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, + OBS_LANGUAGE_UNCOND_ATTENTION_MASK, + OBS_LANGUAGE_UNCOND_TOKENS, OPENPI_ATTENTION_MASK_VALUE, ) @@ -148,6 +150,20 @@ def clone_past_key_values(past_key_values): ) +def cat_past_key_values(kv_a, kv_b): + """Concatenate two DynamicCaches along the batch dimension for batched CFG.""" + return DynamicCache( + tuple( + ( + torch.cat([ka, kb], dim=0), + torch.cat([va, vb], dim=0), + sw_a, + ) + for (ka, va, sw_a), (kb, vb, _sw_b) in zip(kv_a, kv_b, strict=True) + ) + ) + + def pad_vector(vector, new_dim): """Pad the last dimension of a vector to new_dim with zeros. @@ -797,9 +813,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` masks, noise=None, num_steps=None, + uncond_tokens=None, + uncond_masks=None, **kwargs: Unpack[ActionSelectKwargs], ) -> Tensor: - """Do a full inference forward and compute the action.""" + """Do a full inference forward and compute the action. + + When cfg_beta > 1.0 and uncond_tokens/uncond_masks are provided, performs + Classifier-Free Guidance: VLM runs twice (conditioned + unconditional), action + expert runs twice per denoising step, and velocities are interpolated via + v = v_uncond + cfg_beta * (v_cond - v_uncond). + """ if num_steps is None: num_steps = self.config.num_inference_steps @@ -815,6 +839,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) # Use config max_action_dim for internal processing noise = self.sample_noise(actions_shape, device) + cfg_enabled = self.config.cfg_beta > 1.0 and uncond_tokens is not None and uncond_masks is not None + + # Prefill VLM for conditioned prompt prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 @@ -830,6 +857,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` use_cache=True, ) + # Prefill VLM for unconditional prompt (CFG) + if cfg_enabled: + uncond_prefix_embs, uncond_prefix_pad_masks, uncond_prefix_att_masks = self.embed_prefix( + images, img_masks, uncond_tokens, uncond_masks + ) + uncond_prefix_att_2d_masks = make_att_2d_masks(uncond_prefix_pad_masks, uncond_prefix_att_masks) + uncond_prefix_position_ids = torch.cumsum(uncond_prefix_pad_masks, dim=1) - 1 + uncond_prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(uncond_prefix_att_2d_masks) + + _, uncond_past_key_values = self.paligemma_with_expert.forward( + attention_mask=uncond_prefix_att_2d_masks_4d, + position_ids=uncond_prefix_position_ids, + past_key_values=None, + inputs_embeds=[uncond_prefix_embs, None], + use_cache=True, + ) + dt = -1.0 / num_steps x_t = noise @@ -838,6 +882,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): + if cfg_enabled: + return self.denoise_step_cfg_batched( + cond_prefix_pad_masks=prefix_pad_masks, + cond_past_key_values=past_key_values, + uncond_prefix_pad_masks=uncond_prefix_pad_masks, + uncond_past_key_values=uncond_past_key_values, + x_t=input_x_t, + timestep=current_timestep, + ) return self.denoise_step( prefix_pad_masks=prefix_pad_masks, past_key_values=past_key_values, @@ -907,6 +960,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` suffix_out = suffix_out.to(dtype=torch.float32) return self.action_out_proj(suffix_out) + def denoise_step_cfg_batched( + self, + cond_prefix_pad_masks, + cond_past_key_values, + uncond_prefix_pad_masks, + uncond_past_key_values, + x_t, + timestep, + ): + """Batched CFG denoising: runs cond + uncond in a single forward pass. + + Concatenates cond and uncond inputs along the batch dimension, runs one + action expert forward (2x batch), then splits and applies CFG interpolation. + This is ~1.5x faster than two sequential denoise_step calls due to better + GPU utilization (inspired by Qwen2.5-Omni DiT / diffusers batched CFG). + """ + # Embed suffix once (same x_t and timestep for both branches) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep) + + bsize = cond_prefix_pad_masks.shape[0] + suffix_len = suffix_pad_masks.shape[1] + cond_prefix_len = cond_prefix_pad_masks.shape[1] + uncond_prefix_len = uncond_prefix_pad_masks.shape[1] + + # Build attention masks for cond branch + cond_prefix_2d = cond_prefix_pad_masks[:, None, :].expand(bsize, suffix_len, cond_prefix_len) + cond_suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + cond_full_att = torch.cat([cond_prefix_2d, cond_suffix_att_2d], dim=2) + cond_prefix_offsets = torch.sum(cond_prefix_pad_masks, dim=-1)[:, None] + cond_position_ids = cond_prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + # Build attention masks for uncond branch + uncond_prefix_2d = uncond_prefix_pad_masks[:, None, :].expand(bsize, suffix_len, uncond_prefix_len) + uncond_suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + uncond_full_att = torch.cat([uncond_prefix_2d, uncond_suffix_att_2d], dim=2) + uncond_prefix_offsets = torch.sum(uncond_prefix_pad_masks, dim=-1)[:, None] + uncond_position_ids = uncond_prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + # Concatenate on batch dim: [cond_batch; uncond_batch] + batched_full_att = torch.cat([cond_full_att, uncond_full_att], dim=0) + batched_full_att_4d = self._prepare_attention_masks_4d(batched_full_att) + batched_position_ids = torch.cat([cond_position_ids, uncond_position_ids], dim=0) + batched_suffix_embs = torch.cat([suffix_embs, suffix_embs], dim=0) + batched_adarms_cond = torch.cat([adarms_cond, adarms_cond], dim=0) + + # Concatenate KV caches on batch dim + batched_past_kv = cat_past_key_values( + clone_past_key_values(cond_past_key_values), + clone_past_key_values(uncond_past_key_values), + ) + + self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + + # Single forward pass for both branches + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=batched_full_att_4d, + position_ids=batched_position_ids, + past_key_values=batched_past_kv, + inputs_embeds=[None, batched_suffix_embs], + use_cache=False, + adarms_cond=[None, batched_adarms_cond], + ) + + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + v_all = self.action_out_proj(suffix_out) + + # Split: first half = cond, second half = uncond + v_cond, v_uncond = v_all.chunk(2, dim=0) + + # CFG interpolation: v = v_uncond + beta * (v_cond - v_uncond) + return v_uncond + self.config.cfg_beta * (v_cond - v_uncond) + class PI05Policy(PreTrainedPolicy): """PI05 Policy for LeRobot.""" @@ -1243,8 +1370,20 @@ class PI05Policy(PreTrainedPolicy): images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + # CFG: pass unconditional tokens if available + uncond_tokens = batch.get(f"{OBS_LANGUAGE_UNCOND_TOKENS}") + uncond_masks = batch.get(f"{OBS_LANGUAGE_UNCOND_ATTENTION_MASK}") + # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05) - actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs) + actions = self.model.sample_actions( + images, + img_masks, + tokens, + masks, + uncond_tokens=uncond_tokens, + uncond_masks=uncond_masks, + **kwargs, + ) # Unpad actions to actual action dimension original_action_dim = self.config.output_features[ACTION].shape[0] diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index df5b932e0..17aead0f4 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -40,6 +40,8 @@ from lerobot.processor import ( ) from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( + OBS_LANGUAGE_UNCOND_ATTENTION_MASK, + OBS_LANGUAGE_UNCOND_TOKENS, OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, @@ -57,6 +59,7 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): max_state_dim: int = 32 task_key: str = "task" + cfg_enabled: bool = False def __call__(self, transition: EnvTransition) -> EnvTransition: transition = transition.copy() @@ -84,8 +87,25 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): full_prompts.append(full_prompt) transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts - # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!) - # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + + # Build unconditional prompts for CFG (same state but original task without advantage) + if self.cfg_enabled: + base_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("base_task") + if base_tasks is None: + base_tasks = tasks + + if isinstance(base_tasks, str): + base_tasks = [base_tasks] * len(tasks) + + uncond_prompts = [] + for i, base_task in enumerate(base_tasks): + cleaned_text = base_task.strip().replace("_", " ").replace("\n", " ") + state_str = " ".join(map(str, discretized_states[i])) + uncond_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + uncond_prompts.append(uncond_prompt) + + transition[TransitionKey.COMPLEMENTARY_DATA]["uncond_task"] = uncond_prompts + return transition def transform_features( @@ -158,19 +178,39 @@ def make_pi05_pre_post_processors( input_steps.append(RenderMessagesStep(recipe=recipe)) input_steps.append(RenderedMessagesToTaskStep()) + cfg_enabled = config.cfg_beta > 1.0 + input_steps.extend( [ - Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), + Pi05PrepareStateTokenizerProcessorStep( + max_state_dim=config.max_state_dim, + cfg_enabled=cfg_enabled, + ), TokenizerProcessorStep( tokenizer_name="google/paligemma-3b-pt-224", max_length=config.tokenizer_max_length, padding_side="right", padding="max_length", ), - DeviceProcessorStep(device=config.device), ] ) + # Add unconditional prompt tokenizer for CFG inference + if cfg_enabled: + input_steps.append( + TokenizerProcessorStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + task_key="uncond_task", + output_tokens_key=OBS_LANGUAGE_UNCOND_TOKENS, + output_mask_key=OBS_LANGUAGE_UNCOND_ATTENTION_MASK, + ) + ) + + input_steps.append(DeviceProcessorStep(device=config.device)) + output_steps: list[ProcessorStep] = [ UnnormalizerProcessorStep( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats diff --git a/src/lerobot/processor/rendered_messages_to_task.py b/src/lerobot/processor/rendered_messages_to_task.py index c4cce25bb..39ea0448f 100644 --- a/src/lerobot/processor/rendered_messages_to_task.py +++ b/src/lerobot/processor/rendered_messages_to_task.py @@ -64,6 +64,8 @@ class RenderedMessagesToTaskStep(ComplementaryDataProcessorStep): if user_parts: task = complementary_data.get("task") + # Preserve the original task for CFG unconditional prompt + new_complementary_data["base_task"] = task # Wrap in list if the original task was a list (batched) joined = "\n".join(user_parts) if isinstance(task, list): diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index a808e6127..0d0589a66 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -81,6 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep): padding_side: str = "right" padding: str = "max_length" truncation: bool = True + output_tokens_key: str = OBS_LANGUAGE_TOKENS + output_mask_key: str = OBS_LANGUAGE_ATTENTION_MASK # Internal tokenizer instance (not part of the config) input_tokenizer: Any = field(default=None, init=False, repr=False) @@ -201,8 +203,8 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation = dict(observation) # Add tokenized data to the observation - new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] - new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + new_observation[self.output_tokens_key] = tokenized_prompt["input_ids"] + new_observation[self.output_mask_key] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) # Tokenize subtask if available subtask = self.get_subtask(self.transition) @@ -309,14 +311,14 @@ class TokenizerProcessorStep(ObservationProcessorStep): The updated dictionary of policy features. """ # Add a feature for the token IDs if it doesn't already exist - if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]: - features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature( + if self.output_tokens_key not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][self.output_tokens_key] = PolicyFeature( type=FeatureType.LANGUAGE, shape=(self.max_length,) ) # Add a feature for the attention mask if it doesn't already exist - if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: - features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature( + if self.output_mask_key not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][self.output_mask_key] = PolicyFeature( type=FeatureType.LANGUAGE, shape=(self.max_length,) ) diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 482394ff6..f99845d0f 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" +OBS_LANGUAGE_UNCOND = OBS_STR + ".language_uncond" +OBS_LANGUAGE_UNCOND_TOKENS = OBS_LANGUAGE_UNCOND + ".tokens" +OBS_LANGUAGE_UNCOND_ATTENTION_MASK = OBS_LANGUAGE_UNCOND + ".attention_mask" OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask" OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens" OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask" diff --git a/tests/policies/pi0_pi05/test_pi05_cfg.py b/tests/policies/pi0_pi05/test_pi05_cfg.py new file mode 100644 index 000000000..a5bb4e8b5 --- /dev/null +++ b/tests/policies/pi0_pi05/test_pi05_cfg.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python + +"""Tests for PI05 Classifier-Free Guidance (CFG) inference.""" + +import pytest + +pytest.importorskip("transformers", reason="transformers is required for PI05") + +import torch # noqa: E402 + +from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402 +from lerobot.policies.pi05 import PI05Config, make_pi05_pre_post_processors # noqa: E402 +from lerobot.processor.converters import create_transition # noqa: E402 +from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep # noqa: E402 +from lerobot.types import TransitionKey # noqa: E402 +from lerobot.utils.constants import ( # noqa: E402 + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_LANGUAGE_UNCOND_ATTENTION_MASK, + OBS_LANGUAGE_UNCOND_TOKENS, +) + + +class TestRenderedMessagesToTaskBaseTaskPreservation: + """Tests that RenderedMessagesToTaskStep preserves base_task for CFG.""" + + def test_preserves_string_base_task(self): + transition = create_transition( + complementary_data={ + "task": "pick up the cup", + "messages": [ + {"role": "user", "content": "pick up the cup, Advantage: positive"}, + ], + } + ) + step = RenderedMessagesToTaskStep() + out = step(transition) + data = out[TransitionKey.COMPLEMENTARY_DATA] + + assert data["base_task"] == "pick up the cup" + assert data["task"] == "pick up the cup, Advantage: positive" + + def test_preserves_list_base_task(self): + transition = create_transition( + complementary_data={ + "task": ["task1", "task2"], + "messages": [ + {"role": "user", "content": "rendered with advantage"}, + ], + } + ) + step = RenderedMessagesToTaskStep() + out = step(transition) + data = out[TransitionKey.COMPLEMENTARY_DATA] + + assert data["base_task"] == ["task1", "task2"] + + def test_no_base_task_when_messages_absent(self): + transition = create_transition(complementary_data={"task": "pick up the cup"}) + step = RenderedMessagesToTaskStep() + out = step(transition) + data = out[TransitionKey.COMPLEMENTARY_DATA] + + assert "base_task" not in data + + +class TestPi05PrepareStateTokenizerCfg: + """Tests for Pi05PrepareStateTokenizerProcessorStep with cfg_enabled.""" + + def _make_transition(self, task, base_task=None): + complementary_data = {"task": task} + if base_task is not None: + complementary_data["base_task"] = base_task + return create_transition( + observation={"observation.state": torch.zeros(1, 14)}, + complementary_data=complementary_data, + ) + + def test_cfg_disabled_no_uncond_task(self): + from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep + + step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=False) + transition = self._make_transition(task=["pick up the cup, Advantage: positive"]) + out = step(transition) + data = out[TransitionKey.COMPLEMENTARY_DATA] + + assert "uncond_task" not in data + + def test_cfg_enabled_produces_uncond_task_from_base(self): + from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep + + step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=True) + transition = self._make_transition( + task=["pick up the cup, Advantage: positive"], + base_task=["pick up the cup"], + ) + out = step(transition) + data = out[TransitionKey.COMPLEMENTARY_DATA] + + assert "uncond_task" in data + assert len(data["uncond_task"]) == 1 + # Unconditional prompt uses base_task (no advantage) + assert "Advantage" not in data["uncond_task"][0] + assert "pick up the cup" in data["uncond_task"][0] + assert "State:" in data["uncond_task"][0] + + def test_cfg_enabled_falls_back_to_task_when_no_base(self): + from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep + + step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=True) + transition = self._make_transition(task=["pick up the cup"]) + out = step(transition) + data = out[TransitionKey.COMPLEMENTARY_DATA] + + # Falls back to using task itself as unconditional + assert "uncond_task" in data + assert "pick up the cup" in data["uncond_task"][0] + + +class TestCfgPipelineConstruction: + """Tests that the processor pipeline is constructed correctly for CFG.""" + + def _make_config(self, cfg_beta=1.0, recipe_path=None): + config = PI05Config( + max_action_dim=7, + max_state_dim=14, + cfg_beta=cfg_beta, + recipe_path=recipe_path, + device="cpu", + ) + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + return config + + def _make_dataset_stats(self): + return { + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "min": torch.zeros(14), + "max": torch.ones(14), + "q01": torch.zeros(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "min": torch.zeros(7), + "max": torch.ones(7), + "q01": torch.zeros(7), + "q99": torch.ones(7), + }, + "observation.images.base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + }, + } + + def test_no_uncond_tokenizer_when_cfg_disabled(self): + from lerobot.processor import TokenizerProcessorStep + + config = self._make_config(cfg_beta=1.0) + preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats()) + + tokenizer_steps = [s for s in preprocessor.steps if isinstance(s, TokenizerProcessorStep)] + assert len(tokenizer_steps) == 1 + + def test_uncond_tokenizer_added_when_cfg_enabled(self): + from lerobot.processor import TokenizerProcessorStep + + config = self._make_config(cfg_beta=2.0) + preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats()) + + tokenizer_steps = [s for s in preprocessor.steps if isinstance(s, TokenizerProcessorStep)] + assert len(tokenizer_steps) == 2 + + uncond_tokenizer = tokenizer_steps[1] + assert uncond_tokenizer.task_key == "uncond_task" + assert uncond_tokenizer.output_tokens_key == OBS_LANGUAGE_UNCOND_TOKENS + assert uncond_tokenizer.output_mask_key == OBS_LANGUAGE_UNCOND_ATTENTION_MASK + + def test_cfg_pipeline_produces_both_token_sets(self): + config = self._make_config(cfg_beta=2.0) + preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats()) + + batch = { + "observation.state": torch.randn(14), + "observation.images.base_0_rgb": torch.rand(3, 224, 224), + "task": "pick up the cup", + } + processed = preprocessor(batch) + + assert OBS_LANGUAGE_TOKENS in processed + assert OBS_LANGUAGE_ATTENTION_MASK in processed + assert OBS_LANGUAGE_UNCOND_TOKENS in processed + assert OBS_LANGUAGE_UNCOND_ATTENTION_MASK in processed + + # Both should be tensors with the same shape + assert processed[OBS_LANGUAGE_TOKENS].shape == processed[OBS_LANGUAGE_UNCOND_TOKENS].shape + assert ( + processed[OBS_LANGUAGE_ATTENTION_MASK].shape + == processed[OBS_LANGUAGE_UNCOND_ATTENTION_MASK].shape + ) + + def test_cfg_beta_1_no_uncond_tokens_in_output(self): + config = self._make_config(cfg_beta=1.0) + preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats()) + + batch = { + "observation.state": torch.randn(14), + "observation.images.base_0_rgb": torch.rand(3, 224, 224), + "task": "pick up the cup", + } + processed = preprocessor(batch) + + assert OBS_LANGUAGE_TOKENS in processed + assert OBS_LANGUAGE_UNCOND_TOKENS not in processed