mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 12:47:18 +00:00
feat(pi05): implement Classifier-Free Guidance (CFG) inference
Add dual-path denoising with a configurable cfg_beta scale for language-conditioned action generation. When cfg_beta > 1.0, the VLM prefills both the conditioned and the unconditional prompts, and the action-expert velocities are interpolated via v = v_uncond + β*(v_cond - v_uncond).
This commit is contained in:
@@ -92,6 +92,12 @@ class PI05Config(PreTrainedConfig):
|
||||
# the recipe YAML before prompt construction.
|
||||
recipe_path: str | None = None
|
||||
|
||||
# Classifier-Free Guidance (CFG) scale for inference (Eq. 13 in RECAP paper).
|
||||
# 1.0 = no guidance (default). >1.0 enables dual-path denoising where:
|
||||
# v = v_uncond + cfg_beta * (v_cond - v_uncond)
|
||||
# VLM runs twice (cond + uncond prompts), action expert runs 2x per step.
|
||||
cfg_beta: float = 1.0
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
|
||||
@@ -52,6 +52,8 @@ from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_UNCOND_TOKENS,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -148,6 +150,20 @@ def clone_past_key_values(past_key_values):
|
||||
)
|
||||
|
||||
|
||||
def cat_past_key_values(kv_a, kv_b):
    """Stack two DynamicCaches along the batch dimension (used for batched CFG).

    Layers are paired in order (a layer-count mismatch raises via ``strict=True``).
    For every pair the keys and the values are joined on dim 0 with ``kv_a`` first.
    The third per-layer entry is taken from ``kv_a``; the one from ``kv_b`` is ignored
    (presumably a sliding-window setting shared by both caches -- confirm).
    """
    merged_layers = []
    for layer_a, layer_b in zip(kv_a, kv_b, strict=True):
        key_a, value_a, extra_a = layer_a
        key_b, value_b, _unused_extra_b = layer_b
        stacked_keys = torch.cat([key_a, key_b], dim=0)
        stacked_values = torch.cat([value_a, value_b], dim=0)
        merged_layers.append((stacked_keys, stacked_values, extra_a))
    return DynamicCache(tuple(merged_layers))
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -797,9 +813,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
uncond_tokens=None,
|
||||
uncond_masks=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
"""Do a full inference forward and compute the action.
|
||||
|
||||
When cfg_beta > 1.0 and uncond_tokens/uncond_masks are provided, performs
|
||||
Classifier-Free Guidance: VLM runs twice (conditioned + unconditional), action
|
||||
expert runs twice per denoising step, and velocities are interpolated via
|
||||
v = v_uncond + cfg_beta * (v_cond - v_uncond).
|
||||
"""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
|
||||
@@ -815,6 +839,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
) # Use config max_action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
cfg_enabled = self.config.cfg_beta > 1.0 and uncond_tokens is not None and uncond_masks is not None
|
||||
|
||||
# Prefill VLM for conditioned prompt
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
@@ -830,6 +857,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Prefill VLM for unconditional prompt (CFG)
|
||||
if cfg_enabled:
|
||||
uncond_prefix_embs, uncond_prefix_pad_masks, uncond_prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, uncond_tokens, uncond_masks
|
||||
)
|
||||
uncond_prefix_att_2d_masks = make_att_2d_masks(uncond_prefix_pad_masks, uncond_prefix_att_masks)
|
||||
uncond_prefix_position_ids = torch.cumsum(uncond_prefix_pad_masks, dim=1) - 1
|
||||
uncond_prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(uncond_prefix_att_2d_masks)
|
||||
|
||||
_, uncond_past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=uncond_prefix_att_2d_masks_4d,
|
||||
position_ids=uncond_prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[uncond_prefix_embs, None],
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
x_t = noise
|
||||
@@ -838,6 +882,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
if cfg_enabled:
|
||||
return self.denoise_step_cfg_batched(
|
||||
cond_prefix_pad_masks=prefix_pad_masks,
|
||||
cond_past_key_values=past_key_values,
|
||||
uncond_prefix_pad_masks=uncond_prefix_pad_masks,
|
||||
uncond_past_key_values=uncond_past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
@@ -907,6 +960,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
def denoise_step_cfg_batched(
    self,
    cond_prefix_pad_masks,
    cond_past_key_values,
    uncond_prefix_pad_masks,
    uncond_past_key_values,
    x_t,
    timestep,
):
    """Batched CFG denoising: runs the cond and uncond branches in one forward pass.

    The conditioned and unconditional inputs are concatenated along the batch
    dimension ([cond; uncond]), a single action-expert forward is run on the
    doubled batch, and the two velocity halves are combined as
    ``v = v_uncond + cfg_beta * (v_cond - v_uncond)``.

    Args:
        cond_prefix_pad_masks: 2D padding mask of the conditioned prefix.
        cond_past_key_values: KV cache produced by the conditioned prefill.
        uncond_prefix_pad_masks: 2D padding mask of the unconditional prefix.
            Must have the same shape as ``cond_prefix_pad_masks``.
        uncond_past_key_values: KV cache produced by the unconditional prefill.
        x_t: Current noisy actions, shared by both branches.
        timestep: Current denoising time, shared by both branches.

    Returns:
        The guided velocity for the current step.

    Raises:
        ValueError: If the two prefix masks differ in shape. Batching along
            dim 0 needs identical batch size and prefix length, and the
            mismatch would otherwise only surface as an opaque tensor error.
    """
    if cond_prefix_pad_masks.shape != uncond_prefix_pad_masks.shape:
        raise ValueError(
            "Batched CFG requires cond and uncond prefix masks of identical shape, got "
            f"{tuple(cond_prefix_pad_masks.shape)} and {tuple(uncond_prefix_pad_masks.shape)}."
        )

    # Embed suffix once (same x_t and timestep for both branches)
    suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)

    bsize = cond_prefix_pad_masks.shape[0]
    prefix_len = cond_prefix_pad_masks.shape[1]
    suffix_len = suffix_pad_masks.shape[1]

    # The suffix-only pieces depend on neither branch, so build them a single time.
    suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
    suffix_positions = torch.cumsum(suffix_pad_masks, dim=1) - 1

    # Per-branch attention mask and position ids, in [cond, uncond] order.
    full_att_parts = []
    position_id_parts = []
    for prefix_pad_masks in (cond_prefix_pad_masks, uncond_prefix_pad_masks):
        prefix_2d = prefix_pad_masks[:, None, :].expand(bsize, suffix_len, prefix_len)
        full_att_parts.append(torch.cat([prefix_2d, suffix_att_2d], dim=2))
        # Suffix positions continue after the number of valid prefix tokens.
        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_id_parts.append(prefix_offsets + suffix_positions)

    # Concatenate on batch dim: [cond_batch; uncond_batch]
    batched_full_att = torch.cat(full_att_parts, dim=0)
    batched_full_att_4d = self._prepare_attention_masks_4d(batched_full_att)
    batched_position_ids = torch.cat(position_id_parts, dim=0)
    batched_suffix_embs = torch.cat([suffix_embs, suffix_embs], dim=0)
    batched_adarms_cond = torch.cat([adarms_cond, adarms_cond], dim=0)

    # Concatenate KV caches on batch dim.
    # NOTE(review): cat_past_key_values builds new tensors via torch.cat, so these
    # clones look redundant -- confirm clone_past_key_values has no other purpose
    # before dropping them.
    batched_past_kv = cat_past_key_values(
        clone_past_key_values(cond_past_key_values),
        clone_past_key_values(uncond_past_key_values),
    )

    self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

    # Single forward pass for both branches
    outputs_embeds, _ = self.paligemma_with_expert.forward(
        attention_mask=batched_full_att_4d,
        position_ids=batched_position_ids,
        past_key_values=batched_past_kv,
        inputs_embeds=[None, batched_suffix_embs],
        use_cache=False,
        adarms_cond=[None, batched_adarms_cond],
    )

    # Keep only the trailing chunk_size positions of the expert output.
    suffix_out = outputs_embeds[1]
    suffix_out = suffix_out[:, -self.config.chunk_size :]
    suffix_out = suffix_out.to(dtype=torch.float32)
    v_all = self.action_out_proj(suffix_out)

    # Split: first half = cond, second half = uncond
    v_cond, v_uncond = v_all.chunk(2, dim=0)

    # CFG interpolation: v = v_uncond + beta * (v_cond - v_uncond)
    return v_uncond + self.config.cfg_beta * (v_cond - v_uncond)
|
||||
|
||||
|
||||
class PI05Policy(PreTrainedPolicy):
|
||||
"""PI05 Policy for LeRobot."""
|
||||
@@ -1243,8 +1370,20 @@ class PI05Policy(PreTrainedPolicy):
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# CFG: pass unconditional tokens if available
|
||||
uncond_tokens = batch.get(f"{OBS_LANGUAGE_UNCOND_TOKENS}")
|
||||
uncond_masks = batch.get(f"{OBS_LANGUAGE_UNCOND_ATTENTION_MASK}")
|
||||
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||
actions = self.model.sample_actions(
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
uncond_tokens=uncond_tokens,
|
||||
uncond_masks=uncond_masks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -40,6 +40,8 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_UNCOND_TOKENS,
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
@@ -57,6 +59,7 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
cfg_enabled: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
@@ -84,8 +87,25 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
|
||||
# Build unconditional prompts for CFG (same state but original task without advantage)
|
||||
if self.cfg_enabled:
|
||||
base_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("base_task")
|
||||
if base_tasks is None:
|
||||
base_tasks = tasks
|
||||
|
||||
if isinstance(base_tasks, str):
|
||||
base_tasks = [base_tasks] * len(tasks)
|
||||
|
||||
uncond_prompts = []
|
||||
for i, base_task in enumerate(base_tasks):
|
||||
cleaned_text = base_task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
uncond_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
uncond_prompts.append(uncond_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA]["uncond_task"] = uncond_prompts
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
@@ -158,19 +178,39 @@ def make_pi05_pre_post_processors(
|
||||
input_steps.append(RenderMessagesStep(recipe=recipe))
|
||||
input_steps.append(RenderedMessagesToTaskStep())
|
||||
|
||||
cfg_enabled = config.cfg_beta > 1.0
|
||||
|
||||
input_steps.extend(
|
||||
[
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
Pi05PrepareStateTokenizerProcessorStep(
|
||||
max_state_dim=config.max_state_dim,
|
||||
cfg_enabled=cfg_enabled,
|
||||
),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
)
|
||||
|
||||
# Add unconditional prompt tokenizer for CFG inference
|
||||
if cfg_enabled:
|
||||
input_steps.append(
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
task_key="uncond_task",
|
||||
output_tokens_key=OBS_LANGUAGE_UNCOND_TOKENS,
|
||||
output_mask_key=OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
|
||||
)
|
||||
)
|
||||
|
||||
input_steps.append(DeviceProcessorStep(device=config.device))
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
|
||||
@@ -64,6 +64,8 @@ class RenderedMessagesToTaskStep(ComplementaryDataProcessorStep):
|
||||
|
||||
if user_parts:
|
||||
task = complementary_data.get("task")
|
||||
# Preserve the original task for CFG unconditional prompt
|
||||
new_complementary_data["base_task"] = task
|
||||
# Wrap in list if the original task was a list (batched)
|
||||
joined = "\n".join(user_parts)
|
||||
if isinstance(task, list):
|
||||
|
||||
@@ -81,6 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
output_tokens_key: str = OBS_LANGUAGE_TOKENS
|
||||
output_mask_key: str = OBS_LANGUAGE_ATTENTION_MASK
|
||||
|
||||
# Internal tokenizer instance (not part of the config)
|
||||
input_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
@@ -201,8 +203,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Add tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
new_observation[self.output_tokens_key] = tokenized_prompt["input_ids"]
|
||||
new_observation[self.output_mask_key] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
@@ -309,14 +311,14 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Add a feature for the token IDs if it doesn't already exist
|
||||
if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature(
|
||||
if self.output_tokens_key not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][self.output_tokens_key] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add a feature for the attention mask if it doesn't already exist
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
if self.output_mask_key not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][self.output_mask_key] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
|
||||
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
OBS_LANGUAGE_UNCOND = OBS_STR + ".language_uncond"
|
||||
OBS_LANGUAGE_UNCOND_TOKENS = OBS_LANGUAGE_UNCOND + ".tokens"
|
||||
OBS_LANGUAGE_UNCOND_ATTENTION_MASK = OBS_LANGUAGE_UNCOND + ".attention_mask"
|
||||
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask"
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Tests for PI05 Classifier-Free Guidance (CFG) inference."""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("transformers", reason="transformers is required for PI05")
|
||||
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Config, make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor.converters import create_transition # noqa: E402
|
||||
from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep # noqa: E402
|
||||
from lerobot.types import TransitionKey # noqa: E402
|
||||
from lerobot.utils.constants import ( # noqa: E402
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_UNCOND_TOKENS,
|
||||
)
|
||||
|
||||
|
||||
class TestRenderedMessagesToTaskBaseTaskPreservation:
    """Checks that RenderedMessagesToTaskStep keeps the original task as ``base_task`` for CFG."""

    @staticmethod
    def _run_step(complementary_data):
        """Run a fresh RenderedMessagesToTaskStep and return the resulting complementary data."""
        processed = RenderedMessagesToTaskStep()(create_transition(complementary_data=complementary_data))
        return processed[TransitionKey.COMPLEMENTARY_DATA]

    def test_preserves_string_base_task(self):
        # A plain-string task is copied to base_task while task becomes the rendered message.
        result = self._run_step(
            {
                "task": "pick up the cup",
                "messages": [
                    {"role": "user", "content": "pick up the cup, Advantage: positive"},
                ],
            }
        )

        assert result["base_task"] == "pick up the cup"
        assert result["task"] == "pick up the cup, Advantage: positive"

    def test_preserves_list_base_task(self):
        # A batched (list) task is preserved unchanged under base_task.
        result = self._run_step(
            {
                "task": ["task1", "task2"],
                "messages": [
                    {"role": "user", "content": "rendered with advantage"},
                ],
            }
        )

        assert result["base_task"] == ["task1", "task2"]

    def test_no_base_task_when_messages_absent(self):
        # Without rendered messages no base_task key is written.
        result = self._run_step({"task": "pick up the cup"})

        assert "base_task" not in result
|
||||
|
||||
|
||||
class TestPi05PrepareStateTokenizerCfg:
    """Exercises Pi05PrepareStateTokenizerProcessorStep with the cfg_enabled flag on and off."""

    def _make_transition(self, task, base_task=None):
        """Build a transition with a zero state of shape (1, 14) and the given task fields."""
        extras = {"task": task}
        if base_task is not None:
            extras["base_task"] = base_task
        return create_transition(
            observation={"observation.state": torch.zeros(1, 14)},
            complementary_data=extras,
        )

    def _run_step(self, *, cfg_enabled, task, base_task=None):
        """Run the prepare-state step and return the resulting complementary data."""
        from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep

        step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=cfg_enabled)
        processed = step(self._make_transition(task=task, base_task=base_task))
        return processed[TransitionKey.COMPLEMENTARY_DATA]

    def test_cfg_disabled_no_uncond_task(self):
        result = self._run_step(cfg_enabled=False, task=["pick up the cup, Advantage: positive"])

        assert "uncond_task" not in result

    def test_cfg_enabled_produces_uncond_task_from_base(self):
        result = self._run_step(
            cfg_enabled=True,
            task=["pick up the cup, Advantage: positive"],
            base_task=["pick up the cup"],
        )

        assert "uncond_task" in result
        uncond_prompts = result["uncond_task"]
        assert len(uncond_prompts) == 1
        # The unconditional prompt is built from base_task, so it carries no advantage text.
        first_prompt = uncond_prompts[0]
        assert "Advantage" not in first_prompt
        assert "pick up the cup" in first_prompt
        assert "State:" in first_prompt

    def test_cfg_enabled_falls_back_to_task_when_no_base(self):
        result = self._run_step(cfg_enabled=True, task=["pick up the cup"])

        # With no base_task present, the task itself serves as the unconditional prompt.
        assert "uncond_task" in result
        assert "pick up the cup" in result["uncond_task"][0]
|
||||
|
||||
|
||||
class TestCfgPipelineConstruction:
    """Verifies how make_pi05_pre_post_processors assembles the pipeline with and without CFG."""

    def _make_config(self, cfg_beta=1.0, recipe_path=None):
        """Return a CPU PI05Config with one state input, one image input and a 7-dim action."""
        config = PI05Config(
            max_action_dim=7,
            max_state_dim=14,
            cfg_beta=cfg_beta,
            recipe_path=recipe_path,
            device="cpu",
        )
        state_feature = PolicyFeature(type=FeatureType.STATE, shape=(14,))
        image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
        action_feature = PolicyFeature(type=FeatureType.ACTION, shape=(7,))
        config.input_features = {
            "observation.state": state_feature,
            "observation.images.base_0_rgb": image_feature,
        }
        config.output_features = {"action": action_feature}
        return config

    def _make_dataset_stats(self):
        """Return dummy stats: zeros for mean/min/q01 and ones for std/max/q99."""
        zero_valued = ("mean", "min", "q01")

        def build(shape, stat_names):
            # Every entry gets its own freshly allocated tensor.
            return {
                stat: torch.zeros(*shape) if stat in zero_valued else torch.ones(*shape)
                for stat in stat_names
            }

        vector_stats = ("mean", "std", "min", "max", "q01", "q99")
        image_stats = ("mean", "std", "q01", "q99")
        return {
            "observation.state": build((14,), vector_stats),
            "action": build((7,), vector_stats),
            "observation.images.base_0_rgb": build((3, 224, 224), image_stats),
        }

    def _build_preprocessor(self, cfg_beta):
        """Create the preprocessor pipeline for the given cfg_beta."""
        pre, _post = make_pi05_pre_post_processors(
            self._make_config(cfg_beta=cfg_beta), self._make_dataset_stats()
        )
        return pre

    @staticmethod
    def _tokenizer_steps(preprocessor):
        """Collect the TokenizerProcessorStep instances of a pipeline, in order."""
        from lerobot.processor import TokenizerProcessorStep

        found = []
        for step in preprocessor.steps:
            if isinstance(step, TokenizerProcessorStep):
                found.append(step)
        return found

    @staticmethod
    def _sample_batch():
        """Return an unbatched observation with random state/image and a string task."""
        return {
            "observation.state": torch.randn(14),
            "observation.images.base_0_rgb": torch.rand(3, 224, 224),
            "task": "pick up the cup",
        }

    def test_no_uncond_tokenizer_when_cfg_disabled(self):
        steps = self._tokenizer_steps(self._build_preprocessor(1.0))

        assert len(steps) == 1

    def test_uncond_tokenizer_added_when_cfg_enabled(self):
        steps = self._tokenizer_steps(self._build_preprocessor(2.0))

        assert len(steps) == 2

        # The second tokenizer reads the unconditional prompt and writes to the uncond keys.
        uncond_step = steps[1]
        assert uncond_step.task_key == "uncond_task"
        assert uncond_step.output_tokens_key == OBS_LANGUAGE_UNCOND_TOKENS
        assert uncond_step.output_mask_key == OBS_LANGUAGE_UNCOND_ATTENTION_MASK

    def test_cfg_pipeline_produces_both_token_sets(self):
        output = self._build_preprocessor(2.0)(self._sample_batch())

        assert OBS_LANGUAGE_TOKENS in output
        assert OBS_LANGUAGE_ATTENTION_MASK in output
        assert OBS_LANGUAGE_UNCOND_TOKENS in output
        assert OBS_LANGUAGE_UNCOND_ATTENTION_MASK in output

        # Conditioned and unconditional outputs must agree in shape.
        assert output[OBS_LANGUAGE_TOKENS].shape == output[OBS_LANGUAGE_UNCOND_TOKENS].shape
        cond_mask_shape = output[OBS_LANGUAGE_ATTENTION_MASK].shape
        uncond_mask_shape = output[OBS_LANGUAGE_UNCOND_ATTENTION_MASK].shape
        assert cond_mask_shape == uncond_mask_shape

    def test_cfg_beta_1_no_uncond_tokens_in_output(self):
        output = self._build_preprocessor(1.0)(self._sample_batch())

        assert OBS_LANGUAGE_TOKENS in output
        assert OBS_LANGUAGE_UNCOND_TOKENS not in output
|
||||
Reference in New Issue
Block a user