feat(pi05): implement Classifier-Free Guidance (CFG) inference

Add dual-path denoising with configurable cfg_beta scale for language-
conditioned action generation. When cfg_beta > 1.0, VLM prefills both
conditioned and unconditional prompts, and action expert velocities are
interpolated via v = v_uncond + β*(v_cond - v_uncond).
This commit is contained in:
Khalil Meftah
2026-06-22 17:37:33 +02:00
parent 7d1e1b0357
commit 2d4be80425
7 changed files with 428 additions and 12 deletions
@@ -92,6 +92,12 @@ class PI05Config(PreTrainedConfig):
# the recipe YAML before prompt construction.
recipe_path: str | None = None
# Classifier-Free Guidance (CFG) scale for inference (Eq. 13 in RECAP paper).
# 1.0 = no guidance (default). >1.0 enables dual-path denoising where:
# v = v_uncond + cfg_beta * (v_cond - v_uncond)
# VLM runs twice (cond + uncond prompts), action expert runs 2x per step.
cfg_beta: float = 1.0
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
+141 -2
View File
@@ -52,6 +52,8 @@ from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
OBS_LANGUAGE_UNCOND_TOKENS,
OPENPI_ATTENTION_MASK_VALUE,
)
@@ -148,6 +150,20 @@ def clone_past_key_values(past_key_values):
)
def cat_past_key_values(kv_a, kv_b):
"""Concatenate two DynamicCaches along the batch dimension for batched CFG."""
return DynamicCache(
tuple(
(
torch.cat([ka, kb], dim=0),
torch.cat([va, vb], dim=0),
sw_a,
)
for (ka, va, sw_a), (kb, vb, _sw_b) in zip(kv_a, kv_b, strict=True)
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -797,9 +813,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
masks,
noise=None,
num_steps=None,
uncond_tokens=None,
uncond_masks=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
"""Do a full inference forward and compute the action.
When cfg_beta > 1.0 and uncond_tokens/uncond_masks are provided, performs
Classifier-Free Guidance: VLM runs twice (conditioned + unconditional), action
expert runs twice per denoising step, and velocities are interpolated via
v = v_uncond + cfg_beta * (v_cond - v_uncond).
"""
if num_steps is None:
num_steps = self.config.num_inference_steps
@@ -815,6 +839,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device)
cfg_enabled = self.config.cfg_beta > 1.0 and uncond_tokens is not None and uncond_masks is not None
# Prefill VLM for conditioned prompt
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
@@ -830,6 +857,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_cache=True,
)
# Prefill VLM for unconditional prompt (CFG)
if cfg_enabled:
uncond_prefix_embs, uncond_prefix_pad_masks, uncond_prefix_att_masks = self.embed_prefix(
images, img_masks, uncond_tokens, uncond_masks
)
uncond_prefix_att_2d_masks = make_att_2d_masks(uncond_prefix_pad_masks, uncond_prefix_att_masks)
uncond_prefix_position_ids = torch.cumsum(uncond_prefix_pad_masks, dim=1) - 1
uncond_prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(uncond_prefix_att_2d_masks)
_, uncond_past_key_values = self.paligemma_with_expert.forward(
attention_mask=uncond_prefix_att_2d_masks_4d,
position_ids=uncond_prefix_position_ids,
past_key_values=None,
inputs_embeds=[uncond_prefix_embs, None],
use_cache=True,
)
dt = -1.0 / num_steps
x_t = noise
@@ -838,6 +882,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
if cfg_enabled:
return self.denoise_step_cfg_batched(
cond_prefix_pad_masks=prefix_pad_masks,
cond_past_key_values=past_key_values,
uncond_prefix_pad_masks=uncond_prefix_pad_masks,
uncond_past_key_values=uncond_past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
@@ -907,6 +960,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)
def denoise_step_cfg_batched(
self,
cond_prefix_pad_masks,
cond_past_key_values,
uncond_prefix_pad_masks,
uncond_past_key_values,
x_t,
timestep,
):
"""Batched CFG denoising: runs cond + uncond in a single forward pass.
Concatenates cond and uncond inputs along the batch dimension, runs one
action expert forward (2x batch), then splits and applies CFG interpolation.
This is ~1.5x faster than two sequential denoise_step calls due to better
GPU utilization (inspired by Qwen2.5-Omni DiT / diffusers batched CFG).
"""
# Embed suffix once (same x_t and timestep for both branches)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
bsize = cond_prefix_pad_masks.shape[0]
suffix_len = suffix_pad_masks.shape[1]
cond_prefix_len = cond_prefix_pad_masks.shape[1]
uncond_prefix_len = uncond_prefix_pad_masks.shape[1]
# Build attention masks for cond branch
cond_prefix_2d = cond_prefix_pad_masks[:, None, :].expand(bsize, suffix_len, cond_prefix_len)
cond_suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
cond_full_att = torch.cat([cond_prefix_2d, cond_suffix_att_2d], dim=2)
cond_prefix_offsets = torch.sum(cond_prefix_pad_masks, dim=-1)[:, None]
cond_position_ids = cond_prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
# Build attention masks for uncond branch
uncond_prefix_2d = uncond_prefix_pad_masks[:, None, :].expand(bsize, suffix_len, uncond_prefix_len)
uncond_suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
uncond_full_att = torch.cat([uncond_prefix_2d, uncond_suffix_att_2d], dim=2)
uncond_prefix_offsets = torch.sum(uncond_prefix_pad_masks, dim=-1)[:, None]
uncond_position_ids = uncond_prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
# Concatenate on batch dim: [cond_batch; uncond_batch]
batched_full_att = torch.cat([cond_full_att, uncond_full_att], dim=0)
batched_full_att_4d = self._prepare_attention_masks_4d(batched_full_att)
batched_position_ids = torch.cat([cond_position_ids, uncond_position_ids], dim=0)
batched_suffix_embs = torch.cat([suffix_embs, suffix_embs], dim=0)
batched_adarms_cond = torch.cat([adarms_cond, adarms_cond], dim=0)
# Concatenate KV caches on batch dim
batched_past_kv = cat_past_key_values(
clone_past_key_values(cond_past_key_values),
clone_past_key_values(uncond_past_key_values),
)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
# Single forward pass for both branches
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=batched_full_att_4d,
position_ids=batched_position_ids,
past_key_values=batched_past_kv,
inputs_embeds=[None, batched_suffix_embs],
use_cache=False,
adarms_cond=[None, batched_adarms_cond],
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_all = self.action_out_proj(suffix_out)
# Split: first half = cond, second half = uncond
v_cond, v_uncond = v_all.chunk(2, dim=0)
# CFG interpolation: v = v_uncond + beta * (v_cond - v_uncond)
return v_uncond + self.config.cfg_beta * (v_cond - v_uncond)
class PI05Policy(PreTrainedPolicy):
"""PI05 Policy for LeRobot."""
@@ -1243,8 +1370,20 @@ class PI05Policy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# CFG: pass unconditional tokens if available
uncond_tokens = batch.get(f"{OBS_LANGUAGE_UNCOND_TOKENS}")
uncond_masks = batch.get(f"{OBS_LANGUAGE_UNCOND_ATTENTION_MASK}")
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
actions = self.model.sample_actions(
images,
img_masks,
tokens,
masks,
uncond_tokens=uncond_tokens,
uncond_masks=uncond_masks,
**kwargs,
)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
+44 -4
View File
@@ -40,6 +40,8 @@ from lerobot.processor import (
)
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
OBS_LANGUAGE_UNCOND_TOKENS,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
@@ -57,6 +59,7 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
max_state_dim: int = 32
task_key: str = "task"
cfg_enabled: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
@@ -84,8 +87,25 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
# Build unconditional prompts for CFG (same state but original task without advantage)
if self.cfg_enabled:
base_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("base_task")
if base_tasks is None:
base_tasks = tasks
if isinstance(base_tasks, str):
base_tasks = [base_tasks] * len(tasks)
uncond_prompts = []
for i, base_task in enumerate(base_tasks):
cleaned_text = base_task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
uncond_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
uncond_prompts.append(uncond_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA]["uncond_task"] = uncond_prompts
return transition
def transform_features(
@@ -158,19 +178,39 @@ def make_pi05_pre_post_processors(
input_steps.append(RenderMessagesStep(recipe=recipe))
input_steps.append(RenderedMessagesToTaskStep())
cfg_enabled = config.cfg_beta > 1.0
input_steps.extend(
[
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
Pi05PrepareStateTokenizerProcessorStep(
max_state_dim=config.max_state_dim,
cfg_enabled=cfg_enabled,
),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
]
)
# Add unconditional prompt tokenizer for CFG inference
if cfg_enabled:
input_steps.append(
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
task_key="uncond_task",
output_tokens_key=OBS_LANGUAGE_UNCOND_TOKENS,
output_mask_key=OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
)
)
input_steps.append(DeviceProcessorStep(device=config.device))
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
@@ -64,6 +64,8 @@ class RenderedMessagesToTaskStep(ComplementaryDataProcessorStep):
if user_parts:
task = complementary_data.get("task")
# Preserve the original task for CFG unconditional prompt
new_complementary_data["base_task"] = task
# Wrap in list if the original task was a list (batched)
joined = "\n".join(user_parts)
if isinstance(task, list):
+8 -6
View File
@@ -81,6 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
output_tokens_key: str = OBS_LANGUAGE_TOKENS
output_mask_key: str = OBS_LANGUAGE_ATTENTION_MASK
# Internal tokenizer instance (not part of the config)
input_tokenizer: Any = field(default=None, init=False, repr=False)
@@ -201,8 +203,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
new_observation = dict(observation)
# Add tokenized data to the observation
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
new_observation[self.output_tokens_key] = tokenized_prompt["input_ids"]
new_observation[self.output_mask_key] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
# Tokenize subtask if available
subtask = self.get_subtask(self.transition)
@@ -309,14 +311,14 @@ class TokenizerProcessorStep(ObservationProcessorStep):
The updated dictionary of policy features.
"""
# Add a feature for the token IDs if it doesn't already exist
if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature(
if self.output_tokens_key not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][self.output_tokens_key] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
# Add a feature for the attention mask if it doesn't already exist
if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
if self.output_mask_key not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][self.output_mask_key] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
+3
View File
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
OBS_LANGUAGE_UNCOND = OBS_STR + ".language_uncond"
OBS_LANGUAGE_UNCOND_TOKENS = OBS_LANGUAGE_UNCOND + ".tokens"
OBS_LANGUAGE_UNCOND_ATTENTION_MASK = OBS_LANGUAGE_UNCOND + ".attention_mask"
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask"
+224
View File
@@ -0,0 +1,224 @@
#!/usr/bin/env python
"""Tests for PI05 Classifier-Free Guidance (CFG) inference."""
import pytest
pytest.importorskip("transformers", reason="transformers is required for PI05")
import torch # noqa: E402
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.pi05 import PI05Config, make_pi05_pre_post_processors # noqa: E402
from lerobot.processor.converters import create_transition # noqa: E402
from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep # noqa: E402
from lerobot.types import TransitionKey # noqa: E402
from lerobot.utils.constants import ( # noqa: E402
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
OBS_LANGUAGE_UNCOND_TOKENS,
)
class TestRenderedMessagesToTaskBaseTaskPreservation:
"""Tests that RenderedMessagesToTaskStep preserves base_task for CFG."""
def test_preserves_string_base_task(self):
transition = create_transition(
complementary_data={
"task": "pick up the cup",
"messages": [
{"role": "user", "content": "pick up the cup, Advantage: positive"},
],
}
)
step = RenderedMessagesToTaskStep()
out = step(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["base_task"] == "pick up the cup"
assert data["task"] == "pick up the cup, Advantage: positive"
def test_preserves_list_base_task(self):
transition = create_transition(
complementary_data={
"task": ["task1", "task2"],
"messages": [
{"role": "user", "content": "rendered with advantage"},
],
}
)
step = RenderedMessagesToTaskStep()
out = step(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["base_task"] == ["task1", "task2"]
def test_no_base_task_when_messages_absent(self):
transition = create_transition(complementary_data={"task": "pick up the cup"})
step = RenderedMessagesToTaskStep()
out = step(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert "base_task" not in data
class TestPi05PrepareStateTokenizerCfg:
"""Tests for Pi05PrepareStateTokenizerProcessorStep with cfg_enabled."""
def _make_transition(self, task, base_task=None):
complementary_data = {"task": task}
if base_task is not None:
complementary_data["base_task"] = base_task
return create_transition(
observation={"observation.state": torch.zeros(1, 14)},
complementary_data=complementary_data,
)
def test_cfg_disabled_no_uncond_task(self):
from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep
step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=False)
transition = self._make_transition(task=["pick up the cup, Advantage: positive"])
out = step(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert "uncond_task" not in data
def test_cfg_enabled_produces_uncond_task_from_base(self):
from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep
step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=True)
transition = self._make_transition(
task=["pick up the cup, Advantage: positive"],
base_task=["pick up the cup"],
)
out = step(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert "uncond_task" in data
assert len(data["uncond_task"]) == 1
# Unconditional prompt uses base_task (no advantage)
assert "Advantage" not in data["uncond_task"][0]
assert "pick up the cup" in data["uncond_task"][0]
assert "State:" in data["uncond_task"][0]
def test_cfg_enabled_falls_back_to_task_when_no_base(self):
from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep
step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=True)
transition = self._make_transition(task=["pick up the cup"])
out = step(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
# Falls back to using task itself as unconditional
assert "uncond_task" in data
assert "pick up the cup" in data["uncond_task"][0]
class TestCfgPipelineConstruction:
"""Tests that the processor pipeline is constructed correctly for CFG."""
def _make_config(self, cfg_beta=1.0, recipe_path=None):
config = PI05Config(
max_action_dim=7,
max_state_dim=14,
cfg_beta=cfg_beta,
recipe_path=recipe_path,
device="cpu",
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
return config
def _make_dataset_stats(self):
return {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"min": torch.zeros(14),
"max": torch.ones(14),
"q01": torch.zeros(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"min": torch.zeros(7),
"max": torch.ones(7),
"q01": torch.zeros(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
},
}
def test_no_uncond_tokenizer_when_cfg_disabled(self):
from lerobot.processor import TokenizerProcessorStep
config = self._make_config(cfg_beta=1.0)
preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats())
tokenizer_steps = [s for s in preprocessor.steps if isinstance(s, TokenizerProcessorStep)]
assert len(tokenizer_steps) == 1
def test_uncond_tokenizer_added_when_cfg_enabled(self):
from lerobot.processor import TokenizerProcessorStep
config = self._make_config(cfg_beta=2.0)
preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats())
tokenizer_steps = [s for s in preprocessor.steps if isinstance(s, TokenizerProcessorStep)]
assert len(tokenizer_steps) == 2
uncond_tokenizer = tokenizer_steps[1]
assert uncond_tokenizer.task_key == "uncond_task"
assert uncond_tokenizer.output_tokens_key == OBS_LANGUAGE_UNCOND_TOKENS
assert uncond_tokenizer.output_mask_key == OBS_LANGUAGE_UNCOND_ATTENTION_MASK
def test_cfg_pipeline_produces_both_token_sets(self):
config = self._make_config(cfg_beta=2.0)
preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats())
batch = {
"observation.state": torch.randn(14),
"observation.images.base_0_rgb": torch.rand(3, 224, 224),
"task": "pick up the cup",
}
processed = preprocessor(batch)
assert OBS_LANGUAGE_TOKENS in processed
assert OBS_LANGUAGE_ATTENTION_MASK in processed
assert OBS_LANGUAGE_UNCOND_TOKENS in processed
assert OBS_LANGUAGE_UNCOND_ATTENTION_MASK in processed
# Both should be tensors with the same shape
assert processed[OBS_LANGUAGE_TOKENS].shape == processed[OBS_LANGUAGE_UNCOND_TOKENS].shape
assert (
processed[OBS_LANGUAGE_ATTENTION_MASK].shape
== processed[OBS_LANGUAGE_UNCOND_ATTENTION_MASK].shape
)
def test_cfg_beta_1_no_uncond_tokens_in_output(self):
config = self._make_config(cfg_beta=1.0)
preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats())
batch = {
"observation.state": torch.randn(14),
"observation.images.base_0_rgb": torch.rand(3, 224, 224),
"task": "pick up the cup",
}
processed = preprocessor(batch)
assert OBS_LANGUAGE_TOKENS in processed
assert OBS_LANGUAGE_UNCOND_TOKENS not in processed