Compare commits

...

9 Commits

Author SHA1 Message Date
Khalil Meftah 2d4be80425 feat(pi05): implement Classifier-Free Guidance (CFG) inference
Add dual-path denoising with configurable cfg_beta scale for language-
conditioned action generation. When cfg_beta > 1.0, VLM prefills both
conditioned and unconditional prompts, and action expert velocities are
interpolated via v = v_uncond + β*(v_cond - v_uncond).
2026-06-22 17:37:33 +02:00
Khalil Meftah 7d1e1b0357 feat(pi05): integrate RenderMessagesStep for advantage conditioning
Add RenderedMessagesToTaskStep adapter that bridges recipe-rendered chat
messages back into PI05's task-string prompt format. When recipe_path is
set on PI05Config, the preprocessor inserts RenderMessagesStep + adapter
before prompt construction, enabling RECAP advantage text to flow
end-to-end through the recipe YAML system.
2026-06-22 15:55:39 +02:00
Khalil Meftah 0d2ba54385 feat(rollout): add episode success labeling to DAgger strategy 2026-06-22 15:08:05 +02:00
Khalil Meftah 4b779b1e99 feat(recap): add advantage conditioning recipe YAMLs 2026-06-22 14:39:45 +02:00
Khalil Meftah ea908c0672 feat(recap): add advantage scoring annotation module
Implement the RECAP advantage scoring module as a new phase in
lerobot-annotate. Uses a frozen distributional VF to compute per-frame
advantages, binarizes into positive/negative indicators with per-task
threshold, and writes style=advantage persistent rows for policy
conditioning. Skips VF inference on intervention frames as an optimization.
2026-06-22 14:01:58 +02:00
Khalil Meftah e5c94c732f feat(recap): add lerobot-compute-returns script to compute MC returns 2026-06-22 12:17:37 +02:00
Khalil Meftah c18b8277f1 Merge branch 'main' into feat/add-recap
# Conflicts:
#	uv.lock
2026-06-18 17:14:59 +02:00
Khalil Meftah fa3eb9fce3 test(rewards): add unit tests for distributional value function model 2026-06-10 16:07:43 +02:00
Khalil Meftah 500c91ba92 feat(rewards): introduce distributional value function model
- Added a new distributional value function (DistributionalVF) model for RECAP, including its configuration, modeling, and processor components.
- Updated the rewards factory to support the new model type.
- Updated `pyproject.toml` to include the new model in the dependencies.
2026-06-10 15:24:50 +02:00
37 changed files with 4476 additions and 374 deletions
+3
View File
@@ -220,6 +220,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
recap = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
@@ -317,6 +318,7 @@ all = [
"lerobot[sarm]",
"lerobot[robometer]",
"lerobot[topreward]",
"lerobot[recap]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
]
@@ -340,6 +342,7 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
lerobot-compute-returns="lerobot.scripts.lerobot_compute_returns:main"
# ---------------- Tool Configurations ----------------
@@ -169,6 +169,43 @@ class ExecutorConfig:
episode_parallelism: int = 16
@dataclass
class AdvantageConfig:
    """``advantage`` module: RECAP advantage scoring via frozen value function.

    Numeric options are range-checked at construction time so that a bad value
    fails immediately instead of after the value function has been loaded and
    run over an episode.

    Raises:
        ValueError: If ``n_step``, ``threshold_percentile``, ``dropout_rate`` or
            ``batch_size`` is outside its valid range.
    """

    enabled: bool = True

    # Path or Hub repo ID of the trained distributional value function checkpoint.
    # When empty, the annotation module logs a warning and skips scoring.
    value_function_path: str = ""

    # Device to run the value function on.
    device: str = "cuda"

    # N-step lookahead for advantage estimation.
    #   None = MC (N=T): A_t = R_t - V(s_t), using mc_return from dataset.
    #   N >= 1 (e.g. 50 for fine-tuning): A_t = Σ r_{t:t+N-1} + V(s_{t+N}) - V(s_t).
    n_step: int | None = None

    # Fraction in [0, 1] selecting the percentile (value * 100) of advantages
    # that is used as the binarization threshold ε. The annotation module
    # evaluates this percentile over each episode's non-intervention frames.
    # Frames with advantage > ε get I_t = True (positive).
    threshold_percentile: float = 0.3

    # Probability in [0, 1] that a frame's advantage label is omitted (enables CFG).
    dropout_rate: float = 0.3

    # Force I_t = True for frames marked as human interventions.
    force_positive_on_intervention: bool = True

    # Column name in dataset for intervention flag.
    intervention_key: str = "intervention"

    # Column name for pre-computed MC returns (from lerobot-compute-returns).
    mc_return_key: str = "mc_return"

    # Batch size for value function inference (must be >= 1).
    batch_size: int = 32

    def __post_init__(self) -> None:
        """Reject option values that would only fail (or misbehave) later on."""
        if self.n_step is not None and self.n_step < 1:
            raise ValueError(f"n_step must be None or >= 1, got {self.n_step}")
        if not 0.0 <= self.threshold_percentile <= 1.0:
            raise ValueError(
                f"threshold_percentile must be a fraction in [0, 1], got {self.threshold_percentile}"
            )
        if not 0.0 <= self.dropout_rate <= 1.0:
            raise ValueError(f"dropout_rate must be in [0, 1], got {self.dropout_rate}")
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
@dataclass
class AnnotationPipelineConfig:
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
@@ -190,6 +227,7 @@ class AnnotationPipelineConfig:
plan: PlanConfig = field(default_factory=PlanConfig)
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
vqa: VqaConfig = field(default_factory=VqaConfig)
advantage: AdvantageConfig = field(default_factory=AdvantageConfig)
vlm: VlmConfig = field(default_factory=VlmConfig)
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
@@ -15,20 +15,24 @@
# limitations under the License.
"""In-process executor that runs the annotation phases.
The executor runs **six phases** in dependency order:
The executor runs **seven phases** in dependency order:
phase 1: ``plan`` module (plan + subtasks + memory)
phase 2: ``interjections`` module (interjections + speech)
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
interjection timestamp produced by phase 2
phase 4: ``vqa`` module (VQA)
phase 5: validator
phase 6: writer
phase 5: ``advantage`` module (advantage scoring via frozen VF)
phase 6: validator
phase 7: writer
Phase 3 is why the ``plan`` module must be re-entered after the
``interjections`` module — to refresh ``plan`` rows at interjection
timestamps.
Phase 5 (advantage) does not depend on the VLM modules; it uses a frozen
distributional value function to compute per-frame advantage indicators.
Distributed execution is provided by Hugging Face Jobs (see
``examples/annotations/run_hf_job.py``); the runner inside the job
invokes ``lerobot-annotate`` which uses this in-process executor.
@@ -74,7 +78,7 @@ class PipelineRunSummary:
@dataclass
class Executor:
"""Run all six phases over a dataset root in-process.
"""Run all seven phases over a dataset root in-process.
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
(a thread pool); cluster-level concurrency comes from running this
@@ -86,6 +90,7 @@ class Executor:
plan: Any # PlanSubtasksMemoryModule
interjections: Any # InterjectionsAndSpeechModule
vqa: Any # GeneralVqaModule
advantage: Any # AdvantageModule
writer: LanguageColumnsWriter
validator: StagingValidator
@@ -112,6 +117,8 @@ class Executor:
phases.append(self._run_plan_update_phase(records, staging_dir))
# Phase 4: ``vqa`` module (VQA)
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
# Phase 5: ``advantage`` module (advantage scoring via frozen VF)
phases.append(self._run_module_phase("advantage", records, staging_dir, self.advantage))
print("[annotate] running validator...", flush=True)
report = self.validator.validate(records, staging_dir)
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .advantage import AdvantageModule
from .general_vqa import GeneralVqaModule
from .interjections_and_speech import InterjectionsAndSpeechModule
from .plan_subtasks_memory import PlanSubtasksMemoryModule
__all__ = [
"AdvantageModule",
"GeneralVqaModule",
"InterjectionsAndSpeechModule",
"PlanSubtasksMemoryModule",
@@ -0,0 +1,263 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Advantage scoring module for RECAP.
Computes per-frame advantage values using a frozen distributional value function,
binarizes them into improvement indicators (I_t), and emits ``style="advantage"``
persistent rows for policy conditioning.
Paper reference: pi*0.6, Section IV-B and Appendix F.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import torch
from ..config import AdvantageConfig
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
logger = logging.getLogger(__name__)
@dataclass
class AdvantageModule:
    """Compute advantage indicators and emit persistent annotation rows.

    The module lazily loads a frozen distributional value function (VF) and
    scores the frames of an episode. Advantages are binarized into
    ``positive``/``negative`` indicators and written as ``style="advantage"``
    rows into the staging area.

    The binarization threshold is the configured percentile of the advantages
    of the episode being processed (intervention frames excluded).
    NOTE(review): the threshold is derived from one episode at a time, not
    pooled over all episodes of a task — confirm this matches the intended
    RECAP recipe.

    Requires ``mc_return`` column in the dataset (from lerobot-compute-returns).
    """

    config: AdvantageConfig

    _model: Any = field(default=None, init=False, repr=False)
    _preprocessor: Any = field(default=None, init=False, repr=False)
    # Not read or written by any method of this class; kept so the instance
    # layout stays unchanged.
    _threshold: float | None = field(default=None, init=False, repr=False)

    @property
    def enabled(self) -> bool:
        """Whether this module is switched on in the config."""
        return self.config.enabled

    def _ensure_model_loaded(self) -> None:
        """Lazy-load the frozen value function and its preprocessor on first use."""
        if self._model is not None:
            return

        # Imported here so the reward stack is only required when scoring runs.
        from lerobot.rewards import (
            make_reward_model,
            make_reward_model_config,
            make_reward_pre_post_processors,
        )

        cfg = make_reward_model_config(
            "distributional_value_function",
            pretrained_path=self.config.value_function_path,
            device=self.config.device,
        )
        self._model = make_reward_model(cfg)
        self._model.eval()
        for p in self._model.parameters():
            p.requires_grad_(False)

        self._preprocessor, _ = make_reward_pre_post_processors(cfg)
        logger.info("Loaded frozen VF from %s on %s", self.config.value_function_path, self.config.device)

    def compute_advantages_for_episode(self, record: EpisodeRecord) -> tuple[np.ndarray, np.ndarray]:
        """Compute raw advantage values for all frames in an episode.

        Returns:
            (advantages, intervention_mask) both shape [num_frames].
            advantages[t] = A_t, intervention_mask[t] = True if frame is intervention.

        Raises:
            KeyError: If the configured MC-return column is missing.
        """
        self._ensure_model_loaded()

        df = record.frames_df()
        num_frames = len(df)

        mc_return_key = self.config.mc_return_key
        if mc_return_key not in df.columns:
            raise KeyError(
                f"Column '{mc_return_key}' not found in episode {record.episode_index}. "
                "Run lerobot-compute-returns first."
            )
        mc_returns = df[mc_return_key].values.astype(np.float32)

        intervention_mask = np.zeros(num_frames, dtype=bool)
        if self.config.intervention_key in df.columns:
            intervention_mask = df[self.config.intervention_key].values.astype(bool)

        # When interventions are forced positive, V(s_t) of an intervention frame
        # never influences that frame's own label, so inference can be skipped —
        # but only in MC mode. In N-step mode V(s_{t+N}) of an intervention frame
        # is the bootstrap target of frame t, and a skipped (0.0) value would
        # corrupt that frame's advantage, so every frame must be scored.
        can_skip = self.config.force_positive_on_intervention and self.config.n_step is None
        skip_mask = intervention_mask if can_skip else None
        values = self._compute_values(record, skip_mask=skip_mask)

        if self.config.n_step is None:
            advantages = mc_returns - values
        else:
            advantages = self._compute_n_step_advantages(mc_returns, values, record, n=self.config.n_step)

        return advantages, intervention_mask

    def _compute_values(self, record: EpisodeRecord, skip_mask: np.ndarray | None = None) -> np.ndarray:
        """Run frozen VF over the frames of an episode to get V(s_t) predictions.

        Args:
            record: Episode data.
            skip_mask: Optional boolean mask [num_frames]. Frames where True are
                skipped (left as 0.0) to avoid unnecessary inference.

        Returns:
            float32 array [num_frames]; all zeros when no image column exists.
        """
        df = record.frames_df()
        num_frames = len(df)
        values = np.zeros(num_frames, dtype=np.float32)

        image_key = self._resolve_image_key(df)
        if image_key is None:
            logger.warning("No image key found for episode %d; returning zero values.", record.episode_index)
            return values

        # Determine which frame indices actually need inference
        infer_indices = np.flatnonzero(~skip_mask) if skip_mask is not None else np.arange(num_frames)
        if len(infer_indices) == 0:
            return values

        task_text = record.episode_task
        image_column = df[image_key]  # select the column once instead of per row
        step = self.config.batch_size
        num_placeholders = 0

        for batch_start in range(0, len(infer_indices), step):
            batch_indices = infer_indices[batch_start : batch_start + step]

            batch_images = []
            for idx in batch_indices:
                img_val = image_column.iloc[idx]
                if isinstance(img_val, np.ndarray):
                    img_tensor = torch.from_numpy(img_val).float()
                elif isinstance(img_val, torch.Tensor):
                    img_tensor = img_val.float()
                else:
                    # Best-effort fallback for cells that hold no decoded image.
                    img_tensor = torch.zeros(3, 224, 224)
                    num_placeholders += 1
                batch_images.append(img_tensor)

            batch_images_tensor = torch.stack(batch_images)
            raw_batch = {
                image_key: batch_images_tensor,
                "task": [task_text] * len(batch_indices),
            }

            processed = self._preprocessor(raw_batch)
            with torch.no_grad():
                v_values = self._model.compute_reward(processed)

            # Flatten and cast so a trailing singleton dim or a reduced-precision
            # dtype from the model does not break the numpy assignment.
            values[batch_indices] = v_values.detach().float().cpu().numpy().reshape(-1)

        if num_placeholders:
            logger.warning(
                "Episode %d: %d frame(s) in '%s' were not arrays/tensors; scored with zero images.",
                record.episode_index,
                num_placeholders,
                image_key,
            )

        return values

    def _compute_n_step_advantages(
        self, mc_returns: np.ndarray, values: np.ndarray, record: EpisodeRecord, n: int
    ) -> np.ndarray:
        """Compute N-step advantage: A_t = Σ r_{t:t+N-1} + V(s_{t+N}) - V(s_t).

        The reward sum is obtained as ``mc_returns[t] - mc_returns[t + n]``.
        When t+N reaches or exceeds the episode length, falls back to MC
        (``mc_returns[t] - values[t]``).

        Args:
            mc_returns: Per-frame MC returns, shape [num_frames].
            values: Per-frame V(s_t), shape [num_frames].
            record: Unused; kept for signature compatibility.
            n: Lookahead in frames, must be >= 1.

        Raises:
            ValueError: If ``n`` is smaller than 1.
        """
        if n < 1:
            # A negative n would silently index from the end of the arrays.
            raise ValueError(f"n must be >= 1, got {n}")

        num_frames = len(values)
        # MC advantage everywhere; overwritten below wherever t + n is in range.
        advantages = (mc_returns - values).astype(np.float32)
        head = num_frames - n
        if head > 0:
            advantages[:head] = (mc_returns[:head] - mc_returns[n:]) + values[n:] - values[:head]
        return advantages

    def _resolve_image_key(self, df) -> str | None:
        """Find the first image observation key in the dataframe columns."""
        for col in df.columns:
            if col.startswith("observation.images."):
                return col
        return None

    def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
        """Score one episode and write advantage rows to staging.

        Does nothing (after a warning) when no value function path is configured.
        Each frame is dropped with probability ``dropout_rate``; kept frames
        produce one ``style="advantage"`` row with content ``positive`` or
        ``negative``.
        """
        if not self.config.value_function_path:
            logger.warning("No value_function_path configured; skipping advantage scoring.")
            return

        advantages, intervention_mask = self.compute_advantages_for_episode(record)
        num_frames = len(advantages)

        threshold = self._compute_threshold(advantages, intervention_mask)

        # Seeded per episode so the dropout pattern is repeatable between runs.
        rng = np.random.default_rng(seed=hash((record.episode_index, 42)) & 0xFFFFFFFF)

        force_positive = self.config.force_positive_on_intervention
        dropout_rate = self.config.dropout_rate
        timestamps = record.frame_timestamps

        rows: list[dict[str, Any]] = []
        num_positive = 0
        for t in range(num_frames):
            # One draw per frame, taken before any other decision.
            if rng.random() < dropout_rate:
                continue

            is_positive = (force_positive and intervention_mask[t]) or advantages[t] > threshold
            if is_positive:
                num_positive += 1

            timestamp = float(timestamps[t]) if t < len(timestamps) else 0.0
            rows.append(
                {
                    "role": "user",
                    "content": "positive" if is_positive else "negative",
                    "style": "advantage",
                    "timestamp": timestamp,
                    "camera": None,
                    "tool_calls": None,
                }
            )

        staging.write("advantage", rows)
        logger.debug(
            "Episode %d: %d/%d frames scored (threshold=%.4f, %d positive, %d negative)",
            record.episode_index,
            len(rows),
            num_frames,
            threshold,
            num_positive,
            len(rows) - num_positive,
        )

    def _compute_threshold(self, advantages: np.ndarray, intervention_mask: np.ndarray) -> float:
        """Compute the binarization threshold as the configured percentile of advantages.

        Intervention frames are excluded; returns 0.0 when no frame remains.
        """
        candidates = advantages[~intervention_mask]
        if candidates.size == 0:
            return 0.0
        return float(np.percentile(candidates, self.config.threshold_percentile * 100))
@@ -39,6 +39,7 @@ _MODULES: tuple[ModuleName, ...] = (
"plan",
"interjections",
"vqa",
"advantage",
)
+1
View File
@@ -32,6 +32,7 @@ DEFAULT_BINDINGS = {
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
"advantage": "active_at(t, style=advantage)",
}
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
@@ -0,0 +1,30 @@
# RECAP advantage-conditioned recipe.
#
# Composes task + advantage indicator into the prompt for conditional SFT.
# The advantage binding resolves to "positive" or "negative" from the
# language_persistent column (written by lerobot-annotate --advantage).
# When advantage is absent (30% dropout), the advantage turn is skipped
# entirely via if_present, training the unconditional branch for CFG.
#
# This recipe is policy-agnostic: any VLA that consumes chat-style messages
# can use it. Override bindings or add blend components for task-specific needs.
#
# Paper: pi*0.6, Section IV-B (conditional policy training with I_t).
bindings:
advantage: "active_at(t, style=advantage)"
messages:
- role: user
content: "${task}"
stream: high_level
- role: user
content: "Advantage: ${advantage}"
stream: high_level
if_present: advantage
- role: assistant
content: "${subtask}"
stream: low_level
target: true
@@ -0,0 +1,41 @@
# RECAP full recipe with advantage conditioning and subtask blending.
#
# Blend of two training modes:
# 1. advantage_conditioned (70%): Task + advantage indicator → action
# 2. unconditional (30%): Task only → action (no advantage, trains CFG baseline)
#
# This achieves the same effect as per-frame dropout in the annotation module
# but at the recipe level, giving explicit control over the conditioning ratio.
# Use this instead of annotation-level dropout if you want a fixed split.
#
# Paper: pi*0.6, Appendix E (classifier-free guidance requires both branches).
blend:
advantage_conditioned:
weight: 0.7
messages:
# Rendered only when the `advantage` binding resolves (if_present).
- role: user
content: "${task}\nAdvantage: ${advantage}"
stream: high_level
if_present: advantage
# NOTE(review): this user turn carries no condition, so when `advantage`
# is present it presumably renders in addition to the turn above and the
# task text appears twice in the prompt. If it is meant purely as a
# fallback for frames without an advantage row, confirm the recipe
# renderer skips it in the other case.
- role: user
content: "${task}"
stream: high_level
- role: assistant
content: "${subtask}"
stream: low_level
target: true
unconditional:
weight: 0.3
messages:
- role: user
content: "${task}"
stream: high_level
- role: assistant
content: "${subtask}"
stream: low_level
target: true
+2 -2
View File
@@ -43,10 +43,10 @@ CORE_STYLES = {
# validation. Empty by default — populate from a downstream module that
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
# the new style's column.
EXTENDED_STYLES: set[str] = set()
EXTENDED_STYLES: set[str] = {"advantage"}
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug", "advantage"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
@@ -87,6 +87,17 @@ class PI05Config(PreTrainedConfig):
freeze_vision_encoder: bool = False # Freeze only the vision encoder
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
# Language conditioning (e.g. RECAP advantage). When set, RenderMessagesStep
# is inserted into the preprocessor to resolve language_persistent rows via
# the recipe YAML before prompt construction.
recipe_path: str | None = None
# Classifier-Free Guidance (CFG) scale for inference (Eq. 13 in RECAP paper).
# 1.0 = no guidance (default). >1.0 enables dual-path denoising where:
# v = v_uncond + cfg_beta * (v_cond - v_uncond)
# VLM runs twice (cond + uncond prompts), action expert runs 2x per step.
cfg_beta: float = 1.0
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
+141 -2
View File
@@ -52,6 +52,8 @@ from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
OBS_LANGUAGE_UNCOND_TOKENS,
OPENPI_ATTENTION_MASK_VALUE,
)
@@ -148,6 +150,20 @@ def clone_past_key_values(past_key_values):
)
def cat_past_key_values(kv_a, kv_b):
    """Join two DynamicCaches along the batch dimension for batched CFG.

    Both caches are walked in lockstep (``strict=True``, so differing lengths
    raise). Every entry is unpacked into three items; the first two are
    concatenated on dim 0, ``kv_a`` first. The third item of ``kv_a``'s entry is
    carried over unchanged and ``kv_b``'s third item is discarded.

    Returns:
        A new ``DynamicCache`` built from the merged per-entry tuples.
    """
    merged_layers = []
    for layer_a, layer_b in zip(kv_a, kv_b, strict=True):
        keys_a, values_a, third_a = layer_a
        keys_b, values_b, _third_b = layer_b
        merged_keys = torch.cat([keys_a, keys_b], dim=0)
        merged_values = torch.cat([values_a, values_b], dim=0)
        merged_layers.append((merged_keys, merged_values, third_a))
    return DynamicCache(tuple(merged_layers))
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -797,9 +813,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
masks,
noise=None,
num_steps=None,
uncond_tokens=None,
uncond_masks=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
"""Do a full inference forward and compute the action.
When cfg_beta > 1.0 and uncond_tokens/uncond_masks are provided, performs
Classifier-Free Guidance: VLM runs twice (conditioned + unconditional), action
expert runs twice per denoising step, and velocities are interpolated via
v = v_uncond + cfg_beta * (v_cond - v_uncond).
"""
if num_steps is None:
num_steps = self.config.num_inference_steps
@@ -815,6 +839,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device)
cfg_enabled = self.config.cfg_beta > 1.0 and uncond_tokens is not None and uncond_masks is not None
# Prefill VLM for conditioned prompt
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
@@ -830,6 +857,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_cache=True,
)
# Prefill VLM for unconditional prompt (CFG)
if cfg_enabled:
uncond_prefix_embs, uncond_prefix_pad_masks, uncond_prefix_att_masks = self.embed_prefix(
images, img_masks, uncond_tokens, uncond_masks
)
uncond_prefix_att_2d_masks = make_att_2d_masks(uncond_prefix_pad_masks, uncond_prefix_att_masks)
uncond_prefix_position_ids = torch.cumsum(uncond_prefix_pad_masks, dim=1) - 1
uncond_prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(uncond_prefix_att_2d_masks)
_, uncond_past_key_values = self.paligemma_with_expert.forward(
attention_mask=uncond_prefix_att_2d_masks_4d,
position_ids=uncond_prefix_position_ids,
past_key_values=None,
inputs_embeds=[uncond_prefix_embs, None],
use_cache=True,
)
dt = -1.0 / num_steps
x_t = noise
@@ -838,6 +882,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
if cfg_enabled:
return self.denoise_step_cfg_batched(
cond_prefix_pad_masks=prefix_pad_masks,
cond_past_key_values=past_key_values,
uncond_prefix_pad_masks=uncond_prefix_pad_masks,
uncond_past_key_values=uncond_past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
@@ -907,6 +960,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)
def denoise_step_cfg_batched(
    self,
    cond_prefix_pad_masks,
    cond_past_key_values,
    uncond_prefix_pad_masks,
    uncond_past_key_values,
    x_t,
    timestep,
):
    """Batched CFG denoising: runs cond + uncond in a single forward pass.

    Concatenates cond and uncond inputs along the batch dimension, runs one
    action expert forward (2x batch), then splits the result and applies the
    CFG interpolation ``v = v_uncond + cfg_beta * (v_cond - v_uncond)``.

    Because the two branches are stacked with ``torch.cat`` on dim 0, the
    conditioned and unconditional prefixes must have the same padded length.

    Args:
        cond_prefix_pad_masks: Prefix padding mask of the conditioned prompt.
        cond_past_key_values: KV cache from the conditioned prefill.
        uncond_prefix_pad_masks: Prefix padding mask of the unconditional prompt.
        uncond_past_key_values: KV cache from the unconditional prefill.
        x_t: Current noisy actions (shared by both branches).
        timestep: Current denoising time (shared by both branches).

    Returns:
        The guided velocity, with the batch size of a single branch.
    """
    # Embed suffix once (same x_t and timestep for both branches)
    suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)

    bsize = cond_prefix_pad_masks.shape[0]
    suffix_len = suffix_pad_masks.shape[1]

    # Suffix-to-suffix attention and the suffix position increments depend only
    # on the suffix, so they are computed once and shared by both branches.
    suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
    suffix_positions = torch.cumsum(suffix_pad_masks, dim=1) - 1

    def _branch_inputs(prefix_pad_masks):
        """Build the suffix->(prefix+suffix) mask and position ids for one branch."""
        prefix_len = prefix_pad_masks.shape[1]
        prefix_2d = prefix_pad_masks[:, None, :].expand(bsize, suffix_len, prefix_len)
        full_att = torch.cat([prefix_2d, suffix_att_2d], dim=2)
        # Suffix positions continue after the number of valid prefix tokens.
        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        return full_att, prefix_offsets + suffix_positions

    cond_full_att, cond_position_ids = _branch_inputs(cond_prefix_pad_masks)
    uncond_full_att, uncond_position_ids = _branch_inputs(uncond_prefix_pad_masks)

    # Concatenate on batch dim: [cond_batch; uncond_batch]
    batched_full_att = torch.cat([cond_full_att, uncond_full_att], dim=0)
    batched_full_att_4d = self._prepare_attention_masks_4d(batched_full_att)
    batched_position_ids = torch.cat([cond_position_ids, uncond_position_ids], dim=0)
    batched_suffix_embs = torch.cat([suffix_embs, suffix_embs], dim=0)
    batched_adarms_cond = torch.cat([adarms_cond, adarms_cond], dim=0)

    # Concatenate KV caches on batch dim. torch.cat allocates fresh tensors, so
    # the prefill caches are never aliased by the batched cache and do not need
    # to be deep-copied first on every denoising step.
    batched_past_kv = cat_past_key_values(cond_past_key_values, uncond_past_key_values)

    self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

    # Single forward pass for both branches
    outputs_embeds, _ = self.paligemma_with_expert.forward(
        attention_mask=batched_full_att_4d,
        position_ids=batched_position_ids,
        past_key_values=batched_past_kv,
        inputs_embeds=[None, batched_suffix_embs],
        use_cache=False,
        adarms_cond=[None, batched_adarms_cond],
    )

    # Index 1 is the output matching the second inputs_embeds slot (the suffix);
    # only the trailing chunk_size positions are projected to actions.
    suffix_out = outputs_embeds[1]
    suffix_out = suffix_out[:, -self.config.chunk_size :]
    suffix_out = suffix_out.to(dtype=torch.float32)
    v_all = self.action_out_proj(suffix_out)

    # Split: first half = cond, second half = uncond
    v_cond, v_uncond = v_all.chunk(2, dim=0)

    # CFG interpolation: v = v_uncond + beta * (v_cond - v_uncond)
    return v_uncond + self.config.cfg_beta * (v_cond - v_uncond)
class PI05Policy(PreTrainedPolicy):
"""PI05 Policy for LeRobot."""
@@ -1243,8 +1370,20 @@ class PI05Policy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# CFG: pass unconditional tokens if available
uncond_tokens = batch.get(f"{OBS_LANGUAGE_UNCOND_TOKENS}")
uncond_masks = batch.get(f"{OBS_LANGUAGE_UNCOND_ATTENTION_MASK}")
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
actions = self.model.sample_actions(
images,
img_masks,
tokens,
masks,
uncond_tokens=uncond_tokens,
uncond_masks=uncond_masks,
**kwargs,
)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
+69 -15
View File
@@ -40,6 +40,8 @@ from lerobot.processor import (
)
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
OBS_LANGUAGE_UNCOND_TOKENS,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
@@ -57,6 +59,7 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
max_state_dim: int = 32
task_key: str = "task"
cfg_enabled: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
@@ -84,8 +87,25 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
# Build unconditional prompts for CFG (same state but original task without advantage)
if self.cfg_enabled:
base_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("base_task")
if base_tasks is None:
base_tasks = tasks
if isinstance(base_tasks, str):
base_tasks = [base_tasks] * len(tasks)
uncond_prompts = []
for i, base_task in enumerate(base_tasks):
cleaned_text = base_task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
uncond_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
uncond_prompts.append(uncond_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA]["uncond_task"] = uncond_prompts
return transition
def transform_features(
@@ -111,9 +131,10 @@ def make_pi05_pre_post_processors(
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
4. (Optional) Rendering language annotations via recipe YAML.
5. (Optional) Flattening rendered messages into the task string.
6. Tokenizing the text prompt using the PaliGemma tokenizer.
7. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
@@ -122,8 +143,6 @@ def make_pi05_pre_post_processors(
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
@@ -147,16 +166,51 @@ def make_pi05_pre_post_processors(
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
]
# Insert language rendering steps when a recipe is configured (e.g. RECAP advantage)
if config.recipe_path is not None:
from lerobot.configs.recipe import load_recipe
from lerobot.processor.render_messages_processor import RenderMessagesStep
from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep
recipe = load_recipe(config.recipe_path)
input_steps.append(RenderMessagesStep(recipe=recipe))
input_steps.append(RenderedMessagesToTaskStep())
cfg_enabled = config.cfg_beta > 1.0
input_steps.extend(
[
Pi05PrepareStateTokenizerProcessorStep(
max_state_dim=config.max_state_dim,
cfg_enabled=cfg_enabled,
),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
]
)
# Add unconditional prompt tokenizer for CFG inference
if cfg_enabled:
input_steps.append(
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
task_key="uncond_task",
output_tokens_key=OBS_LANGUAGE_UNCOND_TOKENS,
output_mask_key=OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
)
)
input_steps.append(DeviceProcessorStep(device=config.device))
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
@@ -0,0 +1,86 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adapter step that flattens rendered chat messages back into a task string.
Bridges RenderMessagesStep (which outputs structured messages) to policies
that expect a plain task string in complementary_data["task"] (e.g. PI05).
"""
from __future__ import annotations
from lerobot.configs import PipelineFeatureType, PolicyFeature
from .pipeline import ComplementaryDataProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register(name="rendered_messages_to_task")
class RenderedMessagesToTaskStep(ComplementaryDataProcessorStep):
    """Collapse rendered chat messages into the plain ``task`` string.

    Reads ``complementary_data["messages"]``, gathers the text of every
    message whose role is ``"user"`` (plain string content, or the ``text``
    field of ``{"type": "text"}`` blocks when the content is a list), joins
    the pieces with newlines and stores the result under ``"task"``. The
    previous task value is kept under ``"base_task"``.

    When ``"messages"`` is missing the input dict is returned untouched.
    """

    def complementary_data(self, complementary_data: dict) -> dict:
        """Return complementary data with ``task`` rebuilt from user messages.

        Args:
            complementary_data: Mapping that may hold ``"messages"`` and ``"task"``.

        Returns:
            The same object when ``"messages"`` is absent; otherwise a shallow
            copy without ``"messages"``, ``"message_streams"`` and
            ``"target_message_indices"``, and, if any user text was found, with
            ``"task"`` replaced and ``"base_task"`` set to the previous task.
        """
        rendered = complementary_data.get("messages")
        if rendered is None:
            return complementary_data

        fragments = []
        for message in rendered:
            if message.get("role") != "user":
                continue
            body = message.get("content", "")
            if isinstance(body, str):
                # Empty strings contribute nothing to the prompt.
                if body:
                    fragments.append(body)
            elif isinstance(body, list):
                # Multimodal content: only the text blocks are kept.
                for part in body:
                    if not (isinstance(part, dict) and part.get("type") == "text"):
                        continue
                    piece = part.get("text", "")
                    if piece:
                        fragments.append(piece)

        updated = dict(complementary_data)
        if fragments:
            previous_task = complementary_data.get("task")
            # Keep the un-rendered task around (used for the CFG unconditional prompt).
            updated["base_task"] = previous_task
            prompt = "\n".join(fragments)
            # A list-valued task means a batch: repeat the prompt once per entry.
            updated["task"] = (
                [prompt] * len(previous_task) if isinstance(previous_task, list) else prompt
            )

        # The rendering outputs have been consumed; drop them from the copy.
        for consumed_key in ("messages", "message_streams", "target_message_indices"):
            updated.pop(consumed_key, None)
        return updated

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged; this step adds no policy features."""
        return features
+8 -6
View File
@@ -81,6 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
output_tokens_key: str = OBS_LANGUAGE_TOKENS
output_mask_key: str = OBS_LANGUAGE_ATTENTION_MASK
# Internal tokenizer instance (not part of the config)
input_tokenizer: Any = field(default=None, init=False, repr=False)
@@ -201,8 +203,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
new_observation = dict(observation)
# Add tokenized data to the observation
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
new_observation[self.output_tokens_key] = tokenized_prompt["input_ids"]
new_observation[self.output_mask_key] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
# Tokenize subtask if available
subtask = self.get_subtask(self.transition)
@@ -309,14 +311,14 @@ class TokenizerProcessorStep(ObservationProcessorStep):
The updated dictionary of policy features.
"""
# Add a feature for the token IDs if it doesn't already exist
if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature(
if self.output_tokens_key not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][self.output_tokens_key] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
# Add a feature for the attention mask if it doesn't already exist
if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
if self.output_mask_key not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][self.output_mask_key] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
+4
View File
@@ -13,6 +13,9 @@
# limitations under the License.
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import (
DistributionalVFConfig as DistributionalVFConfig,
)
from .factory import (
get_reward_model_class as get_reward_model_class,
make_reward_model as make_reward_model,
@@ -26,6 +29,7 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
__all__ = [
# Configuration classes
"DistributionalVFConfig",
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
@@ -0,0 +1,23 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_distributional_value_function import DistributionalVFConfig
from .modeling_distributional_value_function import DistributionalVFRewardModel
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
__all__ = [
"DistributionalVFConfig",
"DistributionalVFRewardModel",
"make_distributional_vf_pre_post_processors",
]
@@ -0,0 +1,108 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
with optional one-hot targets for terminal states; MC returns normalized per task.
Weights initialized from a pre-trained PI05 actor checkpoint.
"""
from dataclasses import dataclass, field
from lerobot.configs import FeatureType, NormalizationMode
from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
@RewardModelConfig.register_subclass("distributional_value_function")
@dataclass
class DistributionalVFConfig(RewardModelConfig):
    """Settings for RECAP's distributional value function.

    The model predicts V^{pi_ref}(o_t, l) as a categorical distribution over
    ``num_value_bins`` bins covering [``value_support_min``, ``value_support_max``].
    Training uses cross-entropy against HL-Gauss soft targets or a Dirac delta
    projection built from Monte Carlo returns (Eq. 1 in the paper).

    Architecture: the paper's value function is a 670M Gemma 3 VLM (actor: 4B
    Gemma 3). Here a truncated PaliGemma (``num_hidden_layers=6``,
    ``num_vision_layers=13``) reaches roughly 670M params and is initialized
    from the PI05 actor checkpoint.
    """

    # --- Backbone -----------------------------------------------------------
    paligemma_variant: str = "gemma_2b"
    num_hidden_layers: int = 6
    num_vision_layers: int = 13

    # --- Distributional head ------------------------------------------------
    num_value_bins: int = 201
    value_support_min: float = -1.0
    value_support_max: float = 0.0
    hl_gauss_sigma_ratio: float = 5.0

    # How targets are built: "hl_gauss" (soft, default) or "dirac_delta" (C51, hard).
    target_method: str = "hl_gauss"

    # True: terminal states get a one-hot target (exact return, no smoothing).
    # False: terminal states go through ``target_method`` like every other state.
    use_one_hot_terminal: bool = True

    # --- Image --------------------------------------------------------------
    image_resolution: tuple[int, int] = (224, 224)

    # --- Tokenizer ----------------------------------------------------------
    tokenizer_max_length: int = 64

    # PI05 checkpoint path or Hub repo_id used to seed the SigLIP vision tower and
    # Gemma embeddings on the first training run. Once trained, load the value
    # function with RewardModel.from_pretrained() instead.
    init_from_actor_path: str = ""

    # --- Normalization ------------------------------------------------------
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.IDENTITY,
        }
    )

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the default AdamW settings for training the value function."""
        preset = AdamWConfig(lr=3e-4, weight_decay=1e-4, grad_clip_norm=1.0)
        return preset

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Return the default cosine-decay-with-warmup schedule settings."""
        preset = CosineDecayWithWarmupSchedulerConfig(num_warmup_steps=500, num_decay_steps=50000)
        return preset

    def validate_features(self) -> None:
        """Require a VISUAL input feature whenever input features are declared.

        Raises:
            ValueError: If ``input_features`` is non-empty and none is VISUAL.
        """
        if not self.input_features:
            # Nothing declared yet: nothing to validate.
            return
        for feature in self.input_features.values():
            if feature.type == FeatureType.VISUAL:
                return
        raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
@@ -0,0 +1,567 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modeling for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
Inputs: single image observation + task text prompt ("Task: {task}.")
Outputs: softmax distribution over value bins; expected value E[V] for inference.
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
with optional one-hot targets for terminal states; MC returns normalized per task.
Weight initialization: vision tower, multi-modal projector, token embeddings, and
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
"""
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.import_utils import _transformers_available, require_package
from .configuration_distributional_value_function import DistributionalVFConfig
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaRMSNorm,
_gated_residual,
_get_pi_gemma_decoder_layer_base,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
PaliGemmaForConditionalGenerationWithPiGemma = None
PiGemmaRMSNorm = None
_gated_residual = None
_get_pi_gemma_decoder_layer_base = None
PALIGEMMA_VOCAB_SIZE = 257152
class DistributionalVFRewardModel(PreTrainedRewardModel):
"""Distributional value function model for RECAP.
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
per-task normalized Monte Carlo returns.
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
causal attention, [CLS] token, and Linear(D, num_bins) value head.
The expected value is E[V] = sum(softmax(logits) * bin_centers).
"""
name = "distributional_value_function"
config_class = DistributionalVFConfig
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
    """Build the truncated PaliGemma backbone, [CLS] token and value head.

    Construction order: Gemma text config, learned [CLS] embedding, value head,
    decoder layers + final norm, vision tower / projector / token embedding
    (taken from a freshly constructed PaliGemma wrapper), bin support buffer,
    and finally an optional weight copy from a PI05 actor checkpoint.

    Args:
        config: Value-function configuration.
        **kwargs: Accepted for signature compatibility; not used in this body.
    """
    # Fail early if the optional `transformers` dependency is missing.
    require_package("transformers", extra="recap")
    super().__init__(config)
    self.config = config
    from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
    from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
    # Get base dimensions from the paligemma variant (OpenPI config format)
    base_config = get_gemma_config(config.paligemma_variant)
    hidden_dim = base_config.width
    mlp_dim = base_config.mlp_dim
    # Depth comes from the value-function config, not from the variant (truncation).
    num_layers = config.num_hidden_layers
    # HuggingFace GemmaConfig for transformer layers
    gemma_config = CONFIG_MAPPING["gemma"](
        head_dim=base_config.head_dim,
        hidden_size=hidden_dim,
        intermediate_size=mlp_dim,
        num_attention_heads=base_config.num_heads,
        num_hidden_layers=num_layers,
        num_key_value_heads=base_config.num_kv_heads,
        vocab_size=PALIGEMMA_VOCAB_SIZE,
        hidden_activation="gelu_pytorch_tanh",
    )
    self.gemma_config = gemma_config
    self.hidden_dim = hidden_dim
    self.num_value_bins = config.num_value_bins
    # Single learned [CLS] token for value prediction (small random init, std 0.02)
    self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
    # Value projection head: Linear(hidden_dim, num_bins)
    self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
    # Transformer layers (overwritten by _initialize_from_actor on first run)
    self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
    pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
    self.layers = nn.ModuleList(
        [pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
    )
    self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
    # Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
    # PaliGemmaConfig wraps both vision and text configs into a single model
    paligemma_config = CONFIG_MAPPING["paligemma"]()
    paligemma_config.text_config = gemma_config
    # Only the first entry of image_resolution is used for the vision image size.
    paligemma_config.vision_config.image_size = config.image_resolution[0]
    # NOTE(review): these constants presumably match the SigLIP tower shipped with
    # PaliGemma — confirm against the actor checkpoint's vision config.
    paligemma_config.vision_config.intermediate_size = 4304
    paligemma_config.vision_config.projection_dim = 2048
    paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
    # Build the full wrapper only to borrow three sub-modules, then drop it.
    paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
    self.vision_tower = paligemma_full.model.vision_tower
    self.multi_modal_projector = paligemma_full.model.multi_modal_projector
    self.token_embedding = paligemma_full.model.language_model.embed_tokens
    del paligemma_full
    # Truncate vision tower to num_vision_layers
    if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
        vision_encoder = self.vision_tower.vision_model.encoder
        vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
    # Bin support: evenly spaced centers from value_support_min to value_support_max
    bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
    # Registered as a non-persistent buffer: follows the module's device/dtype moves.
    self.register_buffer("bin_centers", bin_centers, persistent=False)
    # Spacing between adjacent bin centers; divides by (num_value_bins - 1).
    bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
    # HL-Gauss standard deviation expressed as a multiple of the bin width.
    self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
    # Overwrite with pre-trained PI05 actor weights (first training run only)
    if config.init_from_actor_path:
        self._initialize_from_actor()
def _initialize_from_actor(self) -> None:
    """Copy backbone weights from the PI05 actor at ``config.init_from_actor_path``.

    Loads the actor policy, then copies into this model: rotary embedding, the
    first ``gemma_config.num_hidden_layers`` decoder layers, the final norm, the
    vision tower (after truncating the source encoder to ``num_vision_layers``),
    the multi-modal projector and the token embedding table.
    """
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    actor = PI05Policy.from_pretrained(self.config.init_from_actor_path)
    actor_paligemma = actor.model.paligemma_with_expert.paligemma
    actor_lm = actor_paligemma.model.language_model

    # Language-model side: rotary embedding, leading decoder layers, final norm.
    self.rotary_emb.load_state_dict(actor_lm.rotary_emb.state_dict())
    for layer_idx in range(self.gemma_config.num_hidden_layers):
        self.layers[layer_idx].load_state_dict(actor_lm.layers[layer_idx].state_dict())
    self.norm.load_state_dict(actor_lm.norm.state_dict())

    # Vision side: shrink the source encoder so its state dict matches the
    # truncated tower held by this model, then copy.
    actor_vision = actor_paligemma.model.vision_tower
    if hasattr(actor_vision, "vision_model"):
        if hasattr(actor_vision.vision_model, "encoder"):
            actor_encoder = actor_vision.vision_model.encoder
            actor_encoder.layers = actor_encoder.layers[: self.config.num_vision_layers]
    self.vision_tower.load_state_dict(actor_vision.state_dict())

    # Projector from vision features into the language hidden size.
    self.multi_modal_projector.load_state_dict(actor_paligemma.model.multi_modal_projector.state_dict())

    # Token embedding table.
    self.token_embedding.load_state_dict(actor_paligemma.model.language_model.embed_tokens.state_dict())

    # The actor is only needed as a weight source.
    del actor
def embed_image(self, image: Tensor) -> Tensor:
    """Encode images with the vision tower and project them to ``hidden_dim``.

    The tower always runs in float32; the result is divided by
    ``sqrt(hidden_dim)`` and cast back to the dtype of the incoming image.

    Args:
        image: [batch_size, channels, height, width] preprocessed images in [-1, 1].

    Returns:
        [batch_size, num_patches, hidden_dim] projected image features.
    """
    caller_dtype = image.dtype
    pixels = image if caller_dtype == torch.float32 else image.to(torch.float32)

    encoded = self.vision_tower(pixels, return_dict=True)
    projected = self.multi_modal_projector(encoded.last_hidden_state)
    projected = projected / (self.hidden_dim**0.5)

    if projected.dtype == caller_dtype:
        return projected
    return projected.to(caller_dtype)
def embed_text(self, token_ids: Tensor) -> Tensor:
    """Look up token embeddings in the value function's embedding table.

    Args:
        token_ids: [batch_size, seq_len] integer token IDs.

    Returns:
        [batch_size, seq_len, hidden_dim] text embeddings, exactly as returned
        by ``self.token_embedding``.
    """
    # NOTE(review): no extra scaling is applied here, while embed_image divides by
    # sqrt(hidden_dim) — confirm this matches how the actor used for weight
    # initialization scales its text embeddings.
    embeddings = self.token_embedding(token_ids)
    return embeddings
def _get_cls_embedding(self, batch_size: int) -> Tensor:
"""Get [CLS] token embedding expanded to batch size.
Args:
batch_size: number of samples in the batch.
Returns:
[batch_size, 1, hidden_dim] learned [CLS] embedding.
"""
return self.cls_embedding.expand(batch_size, -1, -1)
def forward_value(
    self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
) -> dict[str, Tensor]:
    """Run the value transformer over ``[vision, text, CLS]`` and read out the value.

    The three inputs are concatenated along the sequence axis with the learned
    [CLS] token last. Attention is causal and additionally masks padded text
    keys; vision and [CLS] positions are always treated as valid. Position ids
    are the running count of valid positions minus one. The final hidden state
    of the [CLS] position is fed to the value head.

    Args:
        vision_features: [batch_size, num_patches, hidden_dim]
        text_embeddings: [batch_size, seq_len, hidden_dim]
        text_padding_mask: [batch_size, seq_len] boolean mask, True for real tokens.

    Returns:
        Dict with:
            logits: [batch_size, num_value_bins]
            probs: [batch_size, num_value_bins] softmax of the logits
            value: [batch_size, 1] expectation of ``bin_centers`` under ``probs``
    """
    from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE

    batch_size = text_embeddings.shape[0]
    device = text_embeddings.device

    # Build sequence: [vision, text, CLS]
    cls_embedding = self._get_cls_embedding(batch_size)
    hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)

    # Validity mask over the full sequence (vision and CLS are never padding).
    vision_len = vision_features.shape[1]
    vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
    cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
    full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
    full_seq_len = full_padding_mask.shape[1]

    # Causal mask: position i may attend to positions <= i.
    causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))

    # Combine causal mask with key-side padding mask, then convert to an additive bias.
    padding_mask_4d = full_padding_mask[:, None, None, :].expand(
        batch_size, 1, full_seq_len, full_seq_len
    )
    attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
    attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)

    # Padded positions repeat the position id of the last valid token before them.
    position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
    cos, sin = self.rotary_emb(hidden_states, position_ids)

    for layer in self.layers:
        # The norm may return (normed, gate); with cond=None a plain tensor is also handled.
        norm_output = layer.input_layernorm(hidden_states, cond=None)
        if isinstance(norm_output, tuple):
            hidden_states_normed, gate = norm_output
        else:
            hidden_states_normed, gate = norm_output, None

        input_shape = hidden_states_normed.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
        key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
        value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
        query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
            query_states, key_states, cos, sin, unsqueeze_dim=1
        )
        attention_output, _ = modeling_gemma.eager_attention_forward(
            layer.self_attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            layer.self_attn.scaling,
        )
        # Merge the head axes back into one feature axis. Using the leading
        # (batch, seq) shape instead of hidden_size keeps this correct when
        # num_heads * head_dim differs from hidden_size; reshaping to hidden_size
        # would then fold sequence positions into the feature axis.
        attention_output = attention_output.reshape(*input_shape, -1)
        if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
            attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
        projected_attention = layer.self_attn.o_proj(attention_output)

        # First residual connection (gated when the norm supplied a gate).
        if gate is not None:
            projected_attention = _gated_residual(hidden_states, projected_attention, gate)
        else:
            projected_attention = hidden_states + projected_attention
        after_attention_residual = projected_attention.clone()

        norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
        if isinstance(norm_output, tuple):
            mlp_input, gate = norm_output
        else:
            mlp_input, gate = norm_output, None
        mlp_output = layer.mlp(mlp_input)

        # Second residual connection around the MLP.
        if gate is not None:
            hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
        else:
            hidden_states = after_attention_residual + mlp_output

    hidden_states = self.norm(hidden_states)
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]

    # Extract [CLS] token (last position in the sequence)
    cls_hidden_state = hidden_states[:, -1, :]  # [batch_size, hidden_dim]

    # Value head: Linear(hidden_dim, num_bins) -> logits
    value_logits = self.value_head(cls_hidden_state)  # [batch_size, num_value_bins]
    value_probs = F.softmax(value_logits, dim=-1)
    # Expected value under the predicted categorical distribution.
    predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
        dim=-1, keepdim=True
    )
    return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
    """Build HL-Gauss soft targets over the value bins.

    A Gaussian N(target, sigma^2) with ``sigma = self.hl_gauss_sigma`` is
    integrated over each bin (CDF difference between the bin's two edges) and
    the result is renormalized by the mass that falls inside the support.

    Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
    Functions via Classification for Scalable Deep RL", Section 3.1.
    arXiv:2403.03950

    Args:
        target_value: [batch_size] or [batch_size, 1] target values.

    Returns:
        [batch_size, num_value_bins] target probability distribution.
    """
    means = target_value.squeeze(-1) if target_value.ndim == 2 else target_value
    means = means.to(dtype=self.bin_centers.dtype)

    support_lo = self.config.value_support_min
    support_hi = self.config.value_support_max
    bin_width = (support_hi - support_lo) / (self.num_value_bins - 1)

    # num_value_bins + 1 edges, extending half a bin beyond the outermost centers.
    edges = torch.linspace(
        support_lo - bin_width / 2,
        support_hi + bin_width / 2,
        self.num_value_bins + 1,
        device=means.device,
        dtype=means.dtype,
    )

    # Gaussian CDF at every edge, per sample: [batch_size, num_bins + 1]
    standardized = (edges.unsqueeze(0) - means.unsqueeze(-1)) / (self.hl_gauss_sigma * math.sqrt(2))
    edge_cdf = 0.5 * (1.0 + torch.erf(standardized))

    # Mass inside the support; clamped so the division below stays finite.
    in_support_mass = (edge_cdf[:, -1] - edge_cdf[:, 0]).unsqueeze(-1).clamp(min=1e-10)

    per_bin_mass = edge_cdf[:, 1:] - edge_cdf[:, :-1]
    return per_bin_mass / in_support_mass
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
    """Project each target onto its two neighbouring bins (C51-style).

    Targets are clamped to the value support; the probability mass is split
    linearly between the bin below and the bin above. A target that lands
    exactly on a bin puts all mass on that bin.

    Standard distributional RL projection from Bellemare et al. 2017.
    "A Distributional Perspective on Reinforcement Learning"
    arXiv:1707.06887

    Args:
        target_value: [batch_size] or [batch_size, 1] target values.

    Returns:
        [batch_size, num_value_bins] target probability distribution.
    """
    if target_value.ndim == 2:
        target_value = target_value.squeeze(-1)

    support_lo = self.config.value_support_min
    support_hi = self.config.value_support_max
    clipped = target_value.clamp(support_lo, support_hi).to(dtype=self.bin_centers.dtype)

    # Fractional bin coordinate of every target.
    spacing = self.bin_centers[1] - self.bin_centers[0]
    position = (clipped - support_lo) / spacing

    last_bin = self.num_value_bins - 1
    below = position.floor().long().clamp(0, last_bin)
    above = position.ceil().long().clamp(0, last_bin)

    # When both neighbours coincide, the lower bin takes the full mass.
    coincide = below == above
    mass_above = (position - below.float()).masked_fill(coincide, 0.0)
    mass_below = (above.float() - position).masked_fill(coincide, 1.0)

    num_samples = clipped.shape[0]
    distribution = torch.zeros(num_samples, self.num_value_bins, device=clipped.device)
    rows = torch.arange(num_samples, device=clipped.device)
    distribution[rows, below] += mass_below
    distribution[rows, above] += mass_above
    return distribution
def one_hot_target(self, target_value: Tensor) -> Tensor:
    """Put all probability mass on the bin whose center is closest to the target.

    Used for terminal states (exact return, no smoothing).

    Args:
        target_value: [batch_size] or [batch_size, 1] target values.

    Returns:
        [batch_size, num_value_bins] one-hot distribution at the nearest bin,
        in the dtype of ``self.bin_centers``.
    """
    values = target_value.squeeze(-1) if target_value.ndim == 2 else target_value
    values = values.to(dtype=self.bin_centers.dtype)

    # Distance from every target to every bin center: [batch_size, num_value_bins]
    distances = (self.bin_centers.unsqueeze(0) - values.unsqueeze(-1)).abs()
    closest_bin = distances.argmin(dim=-1)

    encoded = F.one_hot(closest_bin, num_classes=self.num_value_bins)
    return encoded.to(dtype=self.bin_centers.dtype)
def compute_target_distribution(
    self,
    target_value: Tensor,
    is_terminal: Tensor,
    method: str = "hl_gauss",
    use_one_hot_terminal: bool = True,
) -> Tensor:
    """Build the training target distribution for a batch of returns.

    Non-terminal samples use ``method``. When ``use_one_hot_terminal`` is True,
    samples flagged terminal use a one-hot target at the nearest bin instead.

    Args:
        target_value: [batch_size] or [batch_size, 1] scalar return targets.
        is_terminal: [batch_size] or [batch_size, 1] terminal flags (any dtype
            that converts to bool).
        method: "hl_gauss" or "dirac_delta".
        use_one_hot_terminal: if True, terminal states get one-hot targets
            (exact return, no smoothing). If False, all states use ``method``.

    Returns:
        [batch_size, num_value_bins] target probability distribution.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "hl_gauss":
        base_distribution = self.hl_gauss_target(target_value)
    elif method == "dirac_delta":
        base_distribution = self.dirac_delta_target(target_value)
    else:
        raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")

    if not use_one_hot_terminal:
        return base_distribution

    terminal_distribution = self.one_hot_target(target_value)
    # Force a [batch_size, 1] selector. Indexing with [:, None] on a [batch_size, 1]
    # flag tensor would yield [batch_size, 1, 1] and silently broadcast the result
    # to [batch_size, batch_size, num_value_bins].
    terminal_mask = is_terminal.reshape(-1, 1).bool()
    return torch.where(terminal_mask, terminal_distribution, base_distribution)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
    """Training step: cross-entropy between predicted and target value distributions.

    The batch is expected to come out of the processor pipeline. Only the
    first image key found in the batch is used.

    Keys read from ``batch``:
        - observation.images.*: [B, C, H, W] preprocessed images
        - observation.language_tokens: [B, seq_len] tokenized task prompt
        - observation.language_attention_mask: [B, seq_len] padding mask
        - mc_return: [B] normalized Monte Carlo return targets
        - is_terminal: [B] boolean terminal flags

    Returns:
        ``(loss, output_dict)``: the scalar cross-entropy loss and a dict of
        Python floats (``loss``, ``predicted_value_mean``, ``mc_return_mean``).

    Raises:
        KeyError: If the batch contains no image key.
    """
    from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    # First key that is either the image namespace itself or lives under it.
    image_prefix = f"{OBS_IMAGES}."
    first_image_key = next(
        (key for key in batch if key.startswith(image_prefix) or key == OBS_IMAGES), None
    )
    if first_image_key is None:
        raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")

    pixels = batch[first_image_key]
    prompt_tokens = batch[OBS_LANGUAGE_TOKENS]
    prompt_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
    returns = batch["mc_return"]
    terminal_flags = batch["is_terminal"]

    # Encode both modalities, then run the value transformer.
    prediction = self.forward_value(
        self.embed_image(pixels),
        self.embed_text(prompt_tokens),
        prompt_mask,
    )
    logits = prediction["logits"]
    expected_value = prediction["value"]

    targets = self.compute_target_distribution(
        returns,
        terminal_flags,
        method=self.config.target_method,
        use_one_hot_terminal=self.config.use_one_hot_terminal,
    )

    # Cross-entropy against the (soft) target distribution (Eq. 1 in pi*0.6 paper).
    log_probs = F.log_softmax(logits, dim=-1)
    loss = -(targets * log_probs).sum(dim=-1).mean()

    metrics = {
        "loss": loss.item(),
        "predicted_value_mean": expected_value.mean().item(),
        "mc_return_mean": returns.mean().item(),
    }
    return loss, metrics
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
    """Predict V(s) for a batch of observations (used for advantage scoring).

    Only the first image key found in the batch is used, together with the
    tokenized prompt and its attention mask.

    Args:
        batch: preprocessed batch with images and tokenized text.

    Returns:
        [batch_size] tensor of predicted values V(s).

    Raises:
        KeyError: If the batch contains no image key.
    """
    from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    image_prefix = f"{OBS_IMAGES}."
    first_image_key = next(
        (key for key in batch if key.startswith(image_prefix) or key == OBS_IMAGES), None
    )
    if first_image_key is None:
        raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")

    pixels = batch[first_image_key]
    prompt_tokens = batch[OBS_LANGUAGE_TOKENS]
    prompt_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()

    prediction = self.forward_value(
        self.embed_image(pixels),
        self.embed_text(prompt_tokens),
        prompt_mask,
    )
    # "value" is [batch_size, 1]; drop the trailing axis.
    return prediction["value"].squeeze(-1)
@@ -0,0 +1,235 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
Training targets (mc_return, is_terminal) are NOT routed through the processor.
They are dataset columns read directly from the batch in the model's forward().
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
)
from lerobot.processor.converters import to_tensor
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_distributional_value_function import DistributionalVFConfig
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
@dataclass
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
    """Turn raw task strings into the value function's text prompt.

    Every task is stripped, has underscores and newlines replaced by spaces,
    and is wrapped as ``"Task: {task}."``. The result replaces the entry under
    ``task_key`` in the transition's complementary data.
    """

    # Complementary-data key that holds the task string(s).
    task_key: str = "task"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a shallow copy of ``transition`` whose task entry is a list of prompts.

        Raises:
            ValueError: if no task is stored under ``task_key``.
        """
        updated = transition.copy()
        extras = updated.get(TransitionKey.COMPLEMENTARY_DATA, {})
        raw_tasks = extras.get(self.task_key)
        if raw_tasks is None:
            raise ValueError("No task found in complementary data")

        # A single string is treated as a batch of one.
        task_list = [raw_tasks] if isinstance(raw_tasks, str) else raw_tasks

        prompts = []
        for raw_task in task_list:
            text = raw_task.strip().replace("_", " ").replace("\n", " ")
            prompts.append(f"Task: {text}.")

        # Write into a fresh dict so the caller's complementary data is untouched.
        updated[TransitionKey.COMPLEMENTARY_DATA] = {**extras, self.task_key: prompts}
        return updated

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Feature specs are unaffected by prompt formatting; return them unchanged."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {"task_key": self.task_key}
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
@dataclass
class DistributionalVFImagePreprocessorStep(ProcessorStep):
    """Resize and normalize images for the value function's SigLIP vision tower.

    Expects float images in [0, 1].

    - Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
    - Scale to [-1, 1] for SigLIP
    """

    # Target (height, width) handed to ``resize_with_pad_torch``.
    image_resolution: tuple[int, int] = (224, 224)
    # Observation keys to process; ``None`` auto-detects keys under OBS_IMAGES.
    image_keys: tuple[str, ...] | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` whose image observations are resized and rescaled.

        Raises:
            ValueError: if the transition's observation is not a dict.
            KeyError: if no image key is configured or found in the observation.
        """
        from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch

        observation = transition.get(TransitionKey.OBSERVATION)
        if not isinstance(observation, dict):
            raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")

        image_keys = self.image_keys or tuple(
            key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
        )
        if not image_keys:
            raise KeyError(
                f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
            )

        # Compare resolutions as plain tuples: a tuple never equals a list, so a
        # list-valued ``image_resolution`` would otherwise force a resize on
        # every call, even for images that already have the target size.
        target_resolution = tuple(self.image_resolution)

        new_observation = dict(observation)
        for image_key in image_keys:
            image = new_observation[image_key]
            if not isinstance(image, Tensor):
                image = to_tensor(image)
            if image.dtype != torch.float32:
                image = image.to(torch.float32)

            # Work in channels-last layout; restore the input layout afterwards.
            is_channels_first = image.ndim == 4 and image.shape[1] == 3
            if is_channels_first:
                image = image.permute(0, 2, 3, 1)

            # NOTE(review): resizing happens before the [-1, 1] rescale; confirm the
            # pad value used by resize_with_pad_torch suits [0, 1] inputs.
            if tuple(image.shape[1:3]) != target_resolution:
                image = resize_with_pad_torch(image, *target_resolution)

            image = image * 2.0 - 1.0

            if is_channels_first:
                image = image.permute(0, 3, 1, 2)
            new_observation[image_key] = image

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return the feature specs unchanged."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {
            "image_resolution": self.image_resolution,
            "image_keys": list(self.image_keys) if self.image_keys is not None else None,
        }
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
    """Return the names of all input features whose type is ``FeatureType.VISUAL``."""
    visual_names = []
    for name, spec in config.input_features.items():
        if spec.type == FeatureType.VISUAL:
            visual_names.append(name)
    return tuple(visual_names)
def make_distributional_vf_pre_post_processors(
    config: DistributionalVFConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build the pre- and post-processing pipelines for the distributional value function.

    The preprocessor runs these steps in order:
        1. Rename observations (empty rename map)
        2. Add a batch dimension
        3. Normalize features according to ``config.normalization_mapping``
        4. Format the task prompt: "Task: {task}."
        5. Tokenize with the PaliGemma tokenizer (right-padded to max length)
        6. Resize-with-pad and scale images to [-1, 1] for SigLIP
        7. Move tensors to ``config.device`` (falls back to "cpu")

    Training targets (mc_return, is_terminal) are not processed here; the
    model reads them directly from the batch in forward().

    The postprocessor is built without any steps because the value function
    does not need action postprocessing.

    Args:
        config: value-function configuration supplying features, normalization
            mapping, tokenizer length, image resolution and device.
        dataset_stats: optional statistics forwarded to the normalizer step.

    Returns:
        ``(preprocessor, postprocessor)`` pipelines.
    """
    image_keys = _visual_image_keys(config)

    rename_step = RenameObservationsProcessorStep(rename_map={})
    batch_step = AddBatchDimensionProcessorStep()
    normalize_step = NormalizerProcessorStep(
        features={**config.input_features, **config.output_features},
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )
    prompt_step = DistributionalVFPrepareTaskPromptStep()
    tokenize_step = TokenizerProcessorStep(
        tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
        max_length=config.tokenizer_max_length,
        padding_side="right",
        padding="max_length",
    )
    image_step = DistributionalVFImagePreprocessorStep(
        image_resolution=config.image_resolution,
        # An empty tuple means "auto-detect", which the step expresses as None.
        image_keys=image_keys or None,
    )
    device_step = DeviceProcessorStep(device=config.device or "cpu")

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=[
            rename_step,
            batch_step,
            normalize_step,
            prompt_step,
            tokenize_step,
            image_step,
            device_step,
        ],
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        to_transition=batch_to_transition,
        to_output=transition_to_batch,
    )
    postprocessor = PolicyProcessorPipeline(
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
    )
    return preprocessor, postprocessor
+19
View File
@@ -24,6 +24,7 @@ from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
@@ -63,6 +64,12 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
return TOPRewardModel
elif name == "distributional_value_function":
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
return DistributionalVFRewardModel
else:
try:
return _get_reward_model_cls_from_name(name=name)
@@ -96,6 +103,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RobometerConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
elif reward_type == "distributional_value_function":
return DistributionalVFConfig(**kwargs)
else:
try:
config_cls = RewardModelConfig.get_choice_class(reward_type)
@@ -191,6 +200,16 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, DistributionalVFConfig):
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
make_distributional_vf_pre_post_processors,
)
return make_distributional_vf_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_reward_model_config(
+8
View File
@@ -106,6 +106,8 @@ class DAggerKeyboardConfig:
pause_resume: str = "space"
correction: str = "tab"
upload: str = "enter"
success: str = "s"
failure: str = "f"
@dataclass
@@ -119,6 +121,8 @@ class DAggerPedalConfig:
pause_resume: str = "KEY_A"
correction: str = "KEY_B"
upload: str = "KEY_C"
success: str = "KEY_D"
failure: str = "KEY_E"
@RolloutStrategyConfig.register_subclass("episodic")
@@ -165,6 +169,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
2. **correction** — toggle human correction recording.
3. **upload** — push dataset to hub on demand (corrections-only mode).
Episode success labeling:
4. **success** — mark current episode as successful.
5. **failure** — mark current episode as failed.
When ``record_autonomous=False`` (default) only human-correction windows
are recorded — each correction becomes its own episode. Set to ``True``
to record both autonomous and correction frames with size-based episode
+5
View File
@@ -347,6 +347,11 @@ def build_rollout_context(
"shape": (1,),
"names": None,
}
dataset_features["next.success"] = {
"dtype": "bool",
"shape": (1,),
"names": None,
}
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
if not repo_name.startswith("rollout_"):
+69 -1
View File
@@ -129,6 +129,9 @@ class DAggerEvents:
self.stop_recording = Event()
self.upload_requested = Event()
# Episode success labeling
self._episode_success: bool | None = None
# -- Thread-safe phase access ------------------------------------------
@property
@@ -171,8 +174,26 @@ class DAggerEvents:
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self._episode_success = None
self.upload_requested.clear()
def mark_success(self) -> None:
    """Record a "success" label for the current episode (called from input threads)."""
    label = True
    with self._lock:
        self._episode_success = label
def mark_failure(self) -> None:
    """Record a "failure" label for the current episode (called from input threads)."""
    label = False
    with self._lock:
        self._episode_success = label
def consume_episode_success(self) -> bool | None:
    """Return the pending episode label and clear it.

    Returns:
        ``True`` / ``False`` if the episode was labeled, ``None`` if unlabeled.
    """
    with self._lock:
        # Read and reset in one step while holding the lock.
        label, self._episode_success = self._episode_success, None
    return label
# ---------------------------------------------------------------------------
# Input device handlers
@@ -226,16 +247,25 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
events.request_transition(key_to_event[resolved])
if resolved == cfg.upload:
events.upload_requested.set()
if resolved == cfg.success:
events.mark_success()
logger.info("Episode marked as SUCCESS")
if resolved == cfg.failure:
events.mark_failure()
logger.info("Episode marked as FAILURE")
except Exception as e:
logger.debug("Key error: %s", e)
listener = keyboard.Listener(on_press=on_press)
listener.start()
logger.info(
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
"DAgger keyboard listener started (pause_resume='%s', correction='%s', "
"upload='%s', success='%s', failure='%s', ESC=stop)",
cfg.pause_resume,
cfg.correction,
cfg.upload,
cfg.success,
cfg.failure,
)
return listener
@@ -255,6 +285,12 @@ def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
events.request_transition(code_to_event[code])
if code == cfg.upload:
events.upload_requested.set()
if code == cfg.success:
events.mark_success()
logger.info("Episode marked as SUCCESS (pedal)")
if code == cfg.failure:
events.mark_failure()
logger.info("Episode marked as FAILURE (pedal)")
logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
return start_pedal_listener(on_press, device_path=cfg.device_path)
@@ -357,6 +393,31 @@ class DAggerStrategy(RolloutStrategy):
)
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
# Episode success labeling
# ------------------------------------------------------------------
def _stamp_episode_success(self, dataset) -> None:
    """Set next.success on the terminal frame based on the operator label.

    Called just before save_episode(). If the operator pressed the success
    key during this episode, the last buffered frame's next.success is set to
    True. Otherwise all frames keep their recorded value (unlabeled or
    failure = left as False by the recording loops).

    The pending label is always consumed, even when there is no buffer or no
    buffered frame to stamp, so a label pressed for this episode can never
    carry over and mark a later episode as successful.

    Args:
        dataset: dataset whose ``writer.episode_buffer`` holds the frames of
            the episode about to be saved.
    """
    # Consume first: the label belongs to the episode being closed right now.
    label = self._events.consume_episode_success()

    buf = dataset.writer.episode_buffer
    if buf is None:
        return
    success_buf = buf.get("next.success")
    if not success_buf:
        return

    if label:
        success_buf[-1] = np.array([True], dtype=bool)
        logger.info("Terminal frame stamped next.success=True")
# ------------------------------------------------------------------
# Continuous recording mode (record_autonomous=True)
# ------------------------------------------------------------------
@@ -443,6 +504,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
@@ -471,6 +533,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([False], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
@@ -481,6 +544,7 @@ class DAggerStrategy(RolloutStrategy):
elapsed = time.perf_counter() - episode_start
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
@@ -510,6 +574,7 @@ class DAggerStrategy(RolloutStrategy):
engine.pause()
with contextlib.suppress(Exception):
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
@@ -584,6 +649,7 @@ class DAggerStrategy(RolloutStrategy):
# Correction ended -> save episode (blocking if not streaming)
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
recorded += 1
self._needs_push.set()
@@ -625,6 +691,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
)
record_tick += 1
@@ -659,6 +726,7 @@ class DAggerStrategy(RolloutStrategy):
engine.pause()
with contextlib.suppress(Exception):
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
+3
View File
@@ -34,6 +34,7 @@ from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConf
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
from lerobot.annotations.steerable_pipeline.modules import (
AdvantageModule,
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
@@ -86,6 +87,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
)
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
advantage = AdvantageModule(config=cfg.advantage)
writer = LanguageColumnsWriter()
validator = StagingValidator(
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
@@ -96,6 +98,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
plan=plan,
interjections=interjections,
vqa=vqa,
advantage=advantage,
writer=writer,
validator=validator,
)
@@ -0,0 +1,382 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute per-frame ``is_terminal`` and ``mc_return`` for a LeRobot dataset.
Implements the sparse reward function from pi*0.6 / RECAP (Eq. 5):
r_t = -1 for non-terminal steps
r_T = 0 for terminal success
r_T = -C_fail for terminal failure
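Monte Carlo returns are the (optionally discounted) cumulative sum of rewards
from each step to the end of the episode. Every reward is divided by
``max_episode_length`` (H), so successful episodes yield returns in roughly
(-1, 0]; failed episodes are shifted further down by ``C_fail / H``.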
The columns are written directly into the dataset's parquet data shards as
flat per-frame scalars. These serve as training targets for the distributional
value function.
Usage:
# Compute returns using the default "next.success" column (from lerobot-eval/rollout)
lerobot-compute-returns \\
--dataset-repo-id lerobot/aloha_sim_insertion_human_image
# Override: treat all episodes as successful (demo-only datasets)
lerobot-compute-returns \\
--dataset-repo-id lerobot/aloha_sim_insertion_human_image \\
--default-success true
# Custom success key, failure penalty, and discount
lerobot-compute-returns \\
--dataset-repo-id my_org/my_dataset \\
--success-key episode_success \\
--c-fail 100 \\
--gamma 0.99
"""
from __future__ import annotations
import argparse
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
logger = logging.getLogger(__name__)
IS_TERMINAL_COL = "is_terminal"
MC_RETURN_COL = "mc_return"
@dataclass
class ComputeReturnsConfig:
    """Settings for computing ``is_terminal`` / ``mc_return`` columns on a dataset."""

    # Dataset repo id handed to ``LeRobotDataset``.
    dataset_repo_id: str = ""
    # Optional local root directory for the dataset.
    root: str | None = None
    # Parquet column that carries the per-frame success flag.
    success_key: str = "next.success"
    # When not None, overrides the success of every episode.
    default_success: bool | None = None
    # Normalization horizon H; None means "longest selected episode".
    max_episode_length: int | None = None
    # Failure penalty constant C_fail.
    c_fail: float = 50.0
    # Discount factor (1.0 = undiscounted).
    gamma: float = 1.0
    # Episode indices to process; an empty list means all episodes.
    episodes: list[int] = field(default_factory=list)
    # Recompute shards that already contain the ``is_terminal`` column.
    force: bool = False
def _get_episode_success(
    episode_table: pa.Table,
    success_key: str,
    default_success: bool | None,
) -> bool:
    """Determine whether an episode was successful.

    Priority:
        1. If ``default_success`` is set, use it unconditionally.
        2. If ``success_key`` is a column of ``episode_table``, the episode is
           successful when any frame carries a truthy flag.
        3. Otherwise fall back to True (assume success for demo datasets).

    A cell counts as truthy when it is a non-zero bool/int/float, or a
    list/tuple containing such a value (e.g. a shape-(1,) flag stored as
    ``[True]``). NaN, None and any other type are never treated as success.

    Args:
        episode_table: rows of a single episode.
        success_key: name of the success column.
        default_success: optional override applied to every episode.

    Returns:
        True if the episode is considered successful.
    """
    if default_success is not None:
        return default_success
    if success_key not in episode_table.column_names:
        return True

    def _is_success(value) -> bool:
        # List-valued cells: success if any element is a success flag.
        if isinstance(value, (list, tuple)):
            return any(_is_success(item) for item in value)
        if isinstance(value, float):
            # ``value == value`` is False only for NaN, which is truthy in
            # Python but must not be read as a success flag.
            return value == value and value != 0.0
        # bool is a subclass of int, so this covers both.
        if isinstance(value, int):
            return bool(value)
        return False

    return any(_is_success(cell.as_py()) for cell in episode_table.column(success_key))
def compute_episode_returns(
    num_frames: int,
    success: bool,
    c_fail: float,
    gamma: float,
    max_episode_length: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute is_terminal and mc_return arrays for a single episode.

    Every non-terminal step is rewarded ``-1 / H``; the last step is rewarded
    ``0`` on success and ``-c_fail / H`` on failure, where ``H`` is
    ``max_episode_length``. ``mc_return[t]`` is the (discounted) sum of
    rewards from step ``t`` to the end of the episode.

    Args:
        num_frames: Number of frames in the episode.
        success: Whether the episode ended successfully.
        c_fail: Failure penalty constant.
        gamma: Discount factor (1.0 = undiscounted).
        max_episode_length: Normalization horizon H; must be positive.

    Returns:
        Tuple of (is_terminal, mc_return) arrays, each of length num_frames
        (bool and float32 respectively). Both are empty when
        ``num_frames <= 0``.

    Raises:
        ValueError: if ``max_episode_length`` is not positive.
    """
    # A zero horizon would divide by zero and a negative one would flip the
    # sign of every reward, so reject both explicitly.
    if max_episode_length <= 0:
        raise ValueError(f"max_episode_length must be positive, got {max_episode_length}")
    # An empty episode has no terminal frame to index.
    if num_frames <= 0:
        return np.zeros(0, dtype=bool), np.zeros(0, dtype=np.float32)

    horizon = max_episode_length
    rewards = np.full(num_frames, -1.0 / horizon, dtype=np.float64)
    rewards[-1] = 0.0 if success else -c_fail / horizon

    is_terminal = np.zeros(num_frames, dtype=bool)
    is_terminal[-1] = True

    if gamma == 1.0:
        # Undiscounted: reverse cumulative sum.
        mc_return = np.cumsum(rewards[::-1])[::-1]
    else:
        # Discounted: backward recursion G_t = r_t + gamma * G_{t+1}.
        mc_return = np.zeros(num_frames, dtype=np.float64)
        mc_return[-1] = rewards[-1]
        for t in range(num_frames - 2, -1, -1):
            mc_return[t] = rewards[t] + gamma * mc_return[t + 1]

    return is_terminal, mc_return.astype(np.float32)
def compute_returns(config: ComputeReturnsConfig) -> Path:
    """Compute returns and write them into parquet shards.

    For every parquet shard holding at least one selected episode, the shard
    is read, ``is_terminal`` and ``mc_return`` are computed per episode, and
    the shard is rewritten in place with both columns appended. Shards that
    already contain ``is_terminal`` are skipped unless ``config.force`` is
    set. Finally ``meta/info.json`` is updated with the two new features.

    When a shard is rewritten for only a subset of its episodes, values
    already stored for the other episodes are kept; rows that never had a
    value are filled with ``False`` / ``0.0``.

    Args:
        config: script configuration.

    Returns:
        The dataset root directory.
    """
    from lerobot.datasets import LeRobotDataset

    logger.info(f"Loading dataset: {config.dataset_repo_id}")
    kwargs = {"repo_id": config.dataset_repo_id, "download_videos": False}
    if config.root:
        kwargs["root"] = config.root
    dataset = LeRobotDataset(**kwargs)
    meta = dataset.meta
    root = Path(meta.root)

    logger.info(f"Dataset root: {root}")
    logger.info(f"Episodes: {meta.total_episodes}, Frames: {meta.total_frames}")

    episode_indices = config.episodes if config.episodes else list(range(meta.total_episodes))

    if config.max_episode_length is not None:
        max_ep_len = config.max_episode_length
    else:
        max_ep_len = max(int(meta.episodes[i]["length"]) for i in episode_indices)
    logger.info(f"Normalization horizon (max_episode_length): {max_ep_len}")

    # Group the selected episodes by the shard that stores them.
    parquet_files_to_rewrite: dict[Path, list[int]] = {}
    for ep_idx in episode_indices:
        rel_path = meta.get_data_file_path(ep_idx)
        abs_path = root / rel_path
        parquet_files_to_rewrite.setdefault(abs_path, []).append(ep_idx)
    logger.info(f"Parquet shards to rewrite: {len(parquet_files_to_rewrite)}")

    for parquet_path, ep_indices_in_file in tqdm(parquet_files_to_rewrite.items(), desc="Processing shards"):
        table = pq.read_table(parquet_path)
        has_is_terminal = IS_TERMINAL_COL in table.column_names
        has_mc_return = MC_RETURN_COL in table.column_names

        if not config.force and has_is_terminal:
            logger.info(f"Skipping {parquet_path.name} (already has {IS_TERMINAL_COL})")
            continue

        # Seed from existing columns so that rewriting a subset of episodes
        # does not wipe the values of the other episodes in this shard.
        if has_is_terminal:
            all_is_terminal = np.array(table.column(IS_TERMINAL_COL).to_pylist(), dtype=bool)
        else:
            all_is_terminal = np.zeros(len(table), dtype=bool)
        if has_mc_return:
            all_mc_return = np.array(table.column(MC_RETURN_COL).to_pylist(), dtype=np.float32)
        else:
            all_mc_return = np.zeros(len(table), dtype=np.float32)

        # Convert once per shard; per-episode masks are then vectorized.
        episode_col = np.asarray(table.column("episode_index").to_pylist())

        for ep_idx in ep_indices_in_file:
            ep_info = meta.episodes[ep_idx]
            ep_from = int(ep_info["dataset_from_index"])
            ep_to = int(ep_info["dataset_to_index"])
            ep_len = ep_to - ep_from

            mask = episode_col == ep_idx
            local_indices = np.flatnonzero(mask)

            if len(local_indices) != ep_len:
                logger.warning(
                    f"Episode {ep_idx}: expected {ep_len} frames in shard, "
                    f"found {len(local_indices)}. Using found count."
                )
                ep_len = len(local_indices)
            if ep_len == 0:
                continue

            ep_subtable = table.filter(mask)
            success = _get_episode_success(ep_subtable, config.success_key, config.default_success)
            is_terminal, mc_return = compute_episode_returns(
                num_frames=ep_len,
                success=success,
                c_fail=config.c_fail,
                gamma=config.gamma,
                max_episode_length=max_ep_len,
            )
            all_is_terminal[local_indices] = is_terminal
            all_mc_return[local_indices] = mc_return

        # Replace any stale columns with the freshly computed ones.
        if has_is_terminal:
            table = table.drop(IS_TERMINAL_COL)
        if has_mc_return:
            table = table.drop(MC_RETURN_COL)
        table = table.append_column(IS_TERMINAL_COL, pa.array(all_is_terminal))
        table = table.append_column(MC_RETURN_COL, pa.array(all_mc_return))
        pq.write_table(table, parquet_path)

    _update_info_json(root, meta)
    logger.info("Done. Columns written: is_terminal, mc_return")
    return root
def _update_info_json(root: Path, meta) -> None:
    """Register ``is_terminal`` and ``mc_return`` in ``<root>/meta/info.json``.

    Features already listed are left untouched; the file is rewritten only if
    at least one feature had to be added. A missing ``info.json`` is logged
    and skipped.

    Args:
        root: dataset root directory.
        meta: dataset metadata object (not used by this function).
    """
    info_path = root / "meta" / "info.json"
    if not info_path.exists():
        logger.warning(f"info.json not found at {info_path}, skipping metadata update.")
        return

    info = json.loads(info_path.read_text())
    features = info.get("features", {})

    wanted = ((IS_TERMINAL_COL, "bool"), (MC_RETURN_COL, "float32"))
    missing = [(column, dtype) for column, dtype in wanted if column not in features]
    if not missing:
        return

    for column, dtype in missing:
        features[column] = {"dtype": dtype, "shape": [1], "names": None}

    info["features"] = features
    info_path.write_text(json.dumps(info, indent=2) + "\n")
    logger.info("Updated meta/info.json with is_terminal and mc_return features.")
def main():
    """Command-line entry point: parse arguments, compute returns, log a summary."""
    parser = argparse.ArgumentParser(
        description="Compute per-frame is_terminal and mc_return for a LeRobot dataset.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Use the default 'next.success' column from the dataset
lerobot-compute-returns --dataset-repo-id lerobot/aloha_sim_insertion_human_image
# Override all episodes as successful (demo-only data)
lerobot-compute-returns --dataset-repo-id my_org/my_dataset --default-success true
# Custom failure penalty
lerobot-compute-returns --dataset-repo-id my_org/my_dataset --c-fail 100
""",
    )
    parser.add_argument(
        "--dataset-repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repo id or local path.",
    )
    parser.add_argument(
        "--root",
        type=str,
        default=None,
        help="Local root directory override for the dataset.",
    )
    parser.add_argument(
        "--success-key",
        type=str,
        default="next.success",
        help="Column name in parquet that indicates episode success (default: 'next.success').",
    )
    parser.add_argument(
        "--default-success",
        type=str,
        default=None,
        choices=["true", "false"],
        help="Override success for all episodes ('true' or 'false').",
    )
    parser.add_argument(
        "--max-episode-length",
        type=int,
        default=None,
        help="Normalization horizon H. If not set, uses max episode length in dataset.",
    )
    parser.add_argument(
        "--c-fail",
        type=float,
        default=50.0,
        help="Failure penalty constant (default: 50.0).",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=1.0,
        help="Discount factor (default: 1.0, undiscounted).",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Process only these episode indices (default: all).",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Overwrite existing is_terminal/mc_return columns.",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

    # ``choices`` already restricts the value to the exact strings "true" / "false";
    # an omitted flag stays None so per-episode detection is used.
    default_success = None
    if args.default_success is not None:
        default_success = args.default_success == "true"

    config = ComputeReturnsConfig(
        dataset_repo_id=args.dataset_repo_id,
        root=args.root,
        success_key=args.success_key,
        default_success=default_success,
        max_episode_length=args.max_episode_length,
        c_fail=args.c_fail,
        gamma=args.gamma,
        episodes=args.episodes or [],
        force=args.force,
    )
    root = compute_returns(config)

    logger.info(f"Returns computed and written to: {root}")
    logger.info(f" Columns added: {IS_TERMINAL_COL}, {MC_RETURN_COL}")
    logger.info("To train the distributional value function, these columns")
    logger.info("will be read as flat batch keys during training.")
if __name__ == "__main__":
main()
+3
View File
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
OBS_LANGUAGE_UNCOND = OBS_STR + ".language_uncond"
OBS_LANGUAGE_UNCOND_TOKENS = OBS_LANGUAGE_UNCOND + ".tokens"
OBS_LANGUAGE_UNCOND_ATTENTION_MASK = OBS_LANGUAGE_UNCOND + ".attention_mask"
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask"
+3 -1
View File
@@ -28,9 +28,10 @@ import sys
import tempfile
from pathlib import Path
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.config import AdvantageConfig, AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.modules import (
AdvantageModule,
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
@@ -85,6 +86,7 @@ def main() -> int:
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
advantage=AdvantageModule(config=AdvantageConfig(enabled=False)),
writer=LanguageColumnsWriter(),
validator=StagingValidator(),
)
+305
View File
@@ -0,0 +1,305 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the advantage scoring annotation module."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.annotations.steerable_pipeline.config import AdvantageConfig
from lerobot.annotations.steerable_pipeline.modules.advantage import AdvantageModule
from lerobot.annotations.steerable_pipeline.reader import EpisodeRecord
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
def _make_record(
    episode_index: int = 0,
    num_frames: int = 20,
    task: str = "pick up the cup",
    mc_returns: np.ndarray | None = None,
    intervention_mask: np.ndarray | None = None,
    fps: float = 10.0,
) -> EpisodeRecord:
    """Create a small ``EpisodeRecord`` whose frame table is pre-populated in memory.

    The returned record has ``_frames_df_cache`` set to a DataFrame with
    ``episode_index``, ``frame_index``, ``timestamp`` and ``mc_return``
    columns, plus ``intervention`` when a mask is given. When ``mc_returns``
    is omitted, a linear ramp from -0.9 to -0.1 is used.
    """
    import pandas as pd

    frame_ids = range(num_frames)
    timestamps = tuple(round(idx / fps, 6) for idx in frame_ids)
    if mc_returns is None:
        mc_returns = np.linspace(-0.9, -0.1, num_frames).astype(np.float32)

    columns = {
        "episode_index": [episode_index] * num_frames,
        "frame_index": list(frame_ids),
        "timestamp": list(timestamps),
        "mc_return": mc_returns,
    }
    if intervention_mask is not None:
        columns["intervention"] = intervention_mask.astype(bool)

    record = EpisodeRecord(
        episode_index=episode_index,
        episode_task=task,
        frame_timestamps=timestamps,
        frame_indices=tuple(frame_ids),
        data_path=Path("/fake/data.parquet"),
        row_offset=0,
        row_count=num_frames,
    )
    # Inject the frame table so the fake parquet path is never read.
    record._frames_df_cache = pd.DataFrame(columns)
    return record
@pytest.fixture
def staging(tmp_path: Path) -> EpisodeStaging:
    """Provide an ``EpisodeStaging`` for episode 0 rooted at pytest's ``tmp_path``."""
    episode_staging = EpisodeStaging(tmp_path, episode_index=0)
    return episode_staging
def test_advantage_module_disabled():
    """A module built from a disabled config reports ``enabled`` as falsy."""
    module = AdvantageModule(config=AdvantageConfig(enabled=False))
    assert not module.enabled
def test_advantage_module_enabled_by_default():
    """A module built from a default config reports ``enabled`` as truthy."""
    module = AdvantageModule(config=AdvantageConfig())
    assert module.enabled
def test_run_episode_skips_without_value_function_path(staging: EpisodeStaging):
    """With an empty ``value_function_path`` the module writes no advantage rows."""
    module = AdvantageModule(config=AdvantageConfig(value_function_path=""))

    module.run_episode(_make_record(), staging)

    assert staging.read("advantage") == []
def test_binarization_with_mock_values(staging: EpisodeStaging):
    """Frames are labelled positive/negative by comparing advantages against the percentile threshold.

    With ``A_t = mc_return_t - value_t`` the advantages are
    ``[-0.1, 0.0, 0.1, 0.2, 0.3, -0.1, -0.2, -0.3, -0.4, -0.5]`` and their median
    (50th percentile) is ``-0.1``. Frames 1-4 lie strictly above it and are expected to be
    positive; the remaining six (advantage <= -0.1) are expected to be negative.
    """
    num_frames = 10
    mc_returns = np.array([-0.5, -0.4, -0.3, -0.2, -0.1, -0.5, -0.6, -0.7, -0.8, -0.9], dtype=np.float32)
    mock_values = np.full(num_frames, -0.4, dtype=np.float32)

    module = AdvantageModule(
        config=AdvantageConfig(
            value_function_path="/fake/vf",
            dropout_rate=0.0,
            threshold_percentile=0.5,
        )
    )
    record = _make_record(num_frames=num_frames, mc_returns=mc_returns)

    with patch.object(module, "_ensure_model_loaded"):
        with patch.object(module, "_compute_values", return_value=mock_values):
            module.run_episode(record, staging)

    rows = staging.read("advantage")
    assert len(rows) == num_frames

    labels = [row["content"] for row in rows]
    assert labels.count("positive") == 4
    assert labels.count("negative") == 6
def test_intervention_frames_forced_positive(staging: EpisodeStaging):
    """An intervention frame is staged as positive even though its raw advantage is negative."""
    num_frames = 5
    mc_returns = np.full(num_frames, -0.9, dtype=np.float32)
    mock_values = np.full(num_frames, -0.1, dtype=np.float32)
    # Only the middle frame is marked as a human intervention.
    intervention = np.array([False, False, True, False, False])

    module = AdvantageModule(
        config=AdvantageConfig(
            value_function_path="/fake/vf",
            dropout_rate=0.0,
            force_positive_on_intervention=True,
        )
    )
    record = _make_record(num_frames=num_frames, mc_returns=mc_returns, intervention_mask=intervention)

    with patch.object(module, "_ensure_model_loaded"):
        with patch.object(module, "_compute_values", return_value=mock_values):
            module.run_episode(record, staging)

    staged = staging.read("advantage")
    # mc_return - value is -0.8 for every frame, so frame 2 can only be positive via the override.
    assert staged[2]["content"] == "positive"
def test_dropout_reduces_output_rows(staging: EpisodeStaging):
    """A dropout rate of 0.3 leaves only part of the 100 frames staged."""
    num_frames = 100
    mock_values = np.full(num_frames, -0.5, dtype=np.float32)
    mc_returns = np.linspace(-0.9, -0.1, num_frames).astype(np.float32)

    module = AdvantageModule(config=AdvantageConfig(value_function_path="/fake/vf", dropout_rate=0.3))
    record = _make_record(num_frames=num_frames, mc_returns=mc_returns)

    with patch.object(module, "_ensure_model_loaded"):
        with patch.object(module, "_compute_values", return_value=mock_values):
            module.run_episode(record, staging)

    kept = len(staging.read("advantage"))
    # Roughly 70 rows are expected; the wide band tolerates sampling variance.
    assert 50 < kept < 90
def test_staged_row_format(staging: EpisodeStaging):
    """Staged rows have the correct schema for language_persistent.

    Every staged row must be a ``user`` turn with ``style == "advantage"``, a
    ``"positive"``/``"negative"`` content, a float timestamp, and no camera or tool calls.
    """
    num_frames = 5
    mc_returns = np.array([-0.5, -0.4, -0.3, -0.2, -0.1], dtype=np.float32)
    mock_values = np.full(num_frames, -0.3, dtype=np.float32)
    cfg = AdvantageConfig(
        value_function_path="/fake/vf",
        dropout_rate=0.0,
    )
    module = AdvantageModule(config=cfg)
    record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
    with (
        patch.object(module, "_ensure_model_loaded"),
        patch.object(module, "_compute_values", return_value=mock_values),
    ):
        module.run_episode(record, staging)
    rows = staging.read("advantage")

    # Without this check the schema loop below passes vacuously when nothing was staged.
    # Dropout is disabled, so one row per frame is expected.
    assert len(rows) == num_frames

    for row in rows:
        assert row["role"] == "user"
        assert row["content"] in ("positive", "negative")
        assert row["style"] == "advantage"
        assert isinstance(row["timestamp"], float)
        assert row["camera"] is None
        assert row["tool_calls"] is None
def test_n_step_advantage():
    """N-step advantage uses partial returns + bootstrapped value.

    Expected values, with ``n = n_step``:

    * ``t + n < num_frames``:
      ``A_t = mc_return[t] - mc_return[t + n] + values[t + n] - values[t]``
      (the value terms cancel here because the mocked values are constant);
    * otherwise: ``A_t = mc_return[t] - values[t]``.
    """
    n_step = 3
    num_frames = 10
    mc_returns = np.linspace(-0.9, 0.0, num_frames).astype(np.float32)
    mock_values = np.full(num_frames, -0.45, dtype=np.float32)
    cfg = AdvantageConfig(
        value_function_path="/fake/vf",
        n_step=n_step,
        dropout_rate=0.0,
    )
    module = AdvantageModule(config=cfg)
    record = _make_record(num_frames=num_frames, mc_returns=mc_returns)

    # Stand-ins for the loaded model and preprocessor.
    # NOTE(review): presumably this stops compute_advantages_for_episode from loading a
    # real checkpoint - confirm against the module.
    module._model = MagicMock()
    module._preprocessor = MagicMock()
    with patch.object(module, "_compute_values", return_value=mock_values):
        advantages, _ = module.compute_advantages_for_episode(record)

    for t in range(num_frames):
        if t + n_step < num_frames:
            expected = mc_returns[t] - mc_returns[t + n_step] + mock_values[t + n_step] - mock_values[t]
        else:
            expected = mc_returns[t] - mock_values[t]
        np.testing.assert_almost_equal(advantages[t], expected, decimal=5)
def test_compute_threshold():
    """With no interventions the threshold equals the configured percentile of all advantages."""
    module = AdvantageModule(config=AdvantageConfig(threshold_percentile=0.3))
    advantages = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
    no_interventions = np.zeros(advantages.shape, dtype=bool)

    threshold = module._compute_threshold(advantages, no_interventions)

    # The config expresses the percentile as a fraction; np.percentile takes 0-100.
    reference = float(np.percentile(advantages, 30))
    assert abs(threshold - reference) < 1e-6
def test_compute_threshold_excludes_intervention():
    """Advantages on intervention frames do not contribute to the threshold."""
    module = AdvantageModule(config=AdvantageConfig(threshold_percentile=0.5))
    # The two outliers sit on intervention frames and must be ignored.
    advantages = np.array([100.0, -1.0, 0.0, 1.0, 100.0], dtype=np.float32)
    intervention_mask = np.array([True, False, False, False, True])

    threshold = module._compute_threshold(advantages, intervention_mask)

    # Remaining values are [-1.0, 0.0, 1.0], whose median is 0.0.
    reference = float(np.percentile([-1.0, 0.0, 1.0], 50))
    assert abs(threshold - reference) < 1e-6
def test_missing_mc_return_raises():
    """compute_advantages_for_episode raises KeyError mentioning ``mc_return`` when the column is absent."""
    import pandas as pd

    module = AdvantageModule(config=AdvantageConfig(value_function_path="/fake/vf"))
    module._model = MagicMock()
    module._preprocessor = MagicMock()

    record = EpisodeRecord(
        episode_index=0,
        episode_task="test",
        frame_timestamps=(0.0, 0.1),
        frame_indices=(0, 1),
        data_path=Path("/fake/data.parquet"),
        row_offset=0,
        row_count=2,
    )
    # Frames dataframe deliberately lacks the mc_return column.
    frames_without_returns = pd.DataFrame({"episode_index": [0, 0], "frame_index": [0, 1]})
    record._frames_df_cache = frames_without_returns

    with pytest.raises(KeyError, match="mc_return"):
        module.compute_advantages_for_episode(record)
@@ -30,6 +30,7 @@ pytest.importorskip("pandas", reason="pandas is required (install lerobot[datase
import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
AdvantageConfig,
AnnotationPipelineConfig,
InterjectionsConfig,
PlanConfig,
@@ -37,6 +38,7 @@ from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
)
from lerobot.annotations.steerable_pipeline.executor import Executor # noqa: E402
from lerobot.annotations.steerable_pipeline.modules import ( # noqa: E402
AdvantageModule,
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
@@ -132,6 +134,7 @@ def _build_executor() -> Executor:
plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
advantage=AdvantageModule(config=AdvantageConfig(enabled=False)),
writer=LanguageColumnsWriter(),
validator=StagingValidator(),
)
+145
View File
@@ -0,0 +1,145 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RECAP advantage conditioning recipes."""
from __future__ import annotations
from pathlib import Path
from lerobot.configs.recipe import load_recipe
from lerobot.datasets.language_render import render_sample
RECIPES_DIR = Path(__file__).resolve().parents[2] / "src" / "lerobot" / "configs" / "recipes"
def _persistent_rows(advantage: str | None = None):
"""Build minimal persistent rows with optional advantage."""
rows = [
{
"role": "user",
"content": "pick up the cup",
"style": "task_aug",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
},
{
"role": "assistant",
"content": "reaching for the cup",
"style": "subtask",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
},
]
if advantage is not None:
rows.append(
{
"role": "user",
"content": advantage,
"style": "advantage",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
}
)
return rows
def test_recap_advantage_recipe_loads():
    """recap_advantage.yaml loads into a recipe with three message turns and the advantage binding."""
    recipe_file = RECIPES_DIR / "recap_advantage.yaml"
    loaded = load_recipe(recipe_file)

    assert loaded.messages is not None
    assert len(loaded.messages) == 3
    expected_bindings = {"advantage": "active_at(t, style=advantage)"}
    assert loaded.bindings == expected_bindings
def test_advantage_present_renders_indicator():
    """A positive advantage row renders as the second message, 'Advantage: positive'."""
    rendered = render_sample(
        recipe=load_recipe(RECIPES_DIR / "recap_advantage.yaml"),
        persistent=_persistent_rows(advantage="positive"),
        events=[],
        t=0.5,
        sample_idx=0,
        task="pick up the cup",
    )

    assert rendered is not None
    turns = rendered["messages"]
    assert len(turns) == 3
    assert turns[1]["content"] == "Advantage: positive"
def test_advantage_negative_renders_indicator():
    """A negative advantage row renders as the second message, 'Advantage: negative'."""
    rendered = render_sample(
        recipe=load_recipe(RECIPES_DIR / "recap_advantage.yaml"),
        persistent=_persistent_rows(advantage="negative"),
        events=[],
        t=0.5,
        sample_idx=0,
        task="pick up the cup",
    )

    assert rendered is not None
    turns = rendered["messages"]
    assert turns[1]["content"] == "Advantage: negative"
def test_advantage_absent_skips_turn():
    """Without an advantage row (e.g. after dropout) only the task and subtask turns are rendered."""
    rendered = render_sample(
        recipe=load_recipe(RECIPES_DIR / "recap_advantage.yaml"),
        persistent=_persistent_rows(advantage=None),
        events=[],
        t=0.5,
        sample_idx=0,
        task="pick up the cup",
    )

    assert rendered is not None
    turns = rendered["messages"]
    # The advantage turn is dropped, leaving task followed by subtask.
    assert len(turns) == 2
    task_turn, subtask_turn = turns[0], turns[1]
    assert task_turn["content"] == "pick up the cup"
    assert subtask_turn["content"] == "reaching for the cup"
def test_advantage_absent_still_has_target():
    """With the advantage turn skipped, the target index points at the subtask message (index 1)."""
    rendered = render_sample(
        recipe=load_recipe(RECIPES_DIR / "recap_advantage.yaml"),
        persistent=_persistent_rows(advantage=None),
        events=[],
        t=0.5,
        sample_idx=0,
        task="pick up the cup",
    )

    assert rendered is not None
    target_indices = rendered["target_message_indices"]
    assert target_indices == [1]
def test_blend_recipe_loads():
    """recap_advantage_blend.yaml defines conditioned (0.7) and unconditional (0.3) components."""
    blend = load_recipe(RECIPES_DIR / "recap_advantage_blend.yaml").blend

    assert blend is not None
    for component in ("advantage_conditioned", "unconditional"):
        assert component in blend
    assert blend["advantage_conditioned"].weight == 0.7
    assert blend["unconditional"].weight == 0.3
+224
View File
@@ -0,0 +1,224 @@
#!/usr/bin/env python
"""Tests for PI05 Classifier-Free Guidance (CFG) inference."""
import pytest
pytest.importorskip("transformers", reason="transformers is required for PI05")
import torch # noqa: E402
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.pi05 import PI05Config, make_pi05_pre_post_processors # noqa: E402
from lerobot.processor.converters import create_transition # noqa: E402
from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep # noqa: E402
from lerobot.types import TransitionKey # noqa: E402
from lerobot.utils.constants import ( # noqa: E402
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
OBS_LANGUAGE_UNCOND_TOKENS,
)
class TestRenderedMessagesToTaskBaseTaskPreservation:
    """Checks that RenderedMessagesToTaskStep keeps the original task under ``base_task`` for CFG."""

    def test_preserves_string_base_task(self):
        """A string task is kept as ``base_task`` while ``task`` becomes the rendered user text."""
        rendered_text = "pick up the cup, Advantage: positive"
        payload = {
            "task": "pick up the cup",
            "messages": [
                {"role": "user", "content": rendered_text},
            ],
        }
        transition = create_transition(complementary_data=payload)

        result = RenderedMessagesToTaskStep()(transition)

        data = result[TransitionKey.COMPLEMENTARY_DATA]
        assert data["base_task"] == "pick up the cup"
        assert data["task"] == rendered_text

    def test_preserves_list_base_task(self):
        """A list-valued task is kept unchanged under ``base_task``."""
        payload = {
            "task": ["task1", "task2"],
            "messages": [
                {"role": "user", "content": "rendered with advantage"},
            ],
        }
        transition = create_transition(complementary_data=payload)

        result = RenderedMessagesToTaskStep()(transition)

        data = result[TransitionKey.COMPLEMENTARY_DATA]
        assert data["base_task"] == ["task1", "task2"]

    def test_no_base_task_when_messages_absent(self):
        """Without a ``messages`` entry the step adds no ``base_task`` key."""
        transition = create_transition(complementary_data={"task": "pick up the cup"})

        result = RenderedMessagesToTaskStep()(transition)

        assert "base_task" not in result[TransitionKey.COMPLEMENTARY_DATA]
class TestPi05PrepareStateTokenizerCfg:
    """Exercises Pi05PrepareStateTokenizerProcessorStep with ``cfg_enabled`` on and off."""

    def _make_transition(self, task, base_task=None):
        """Build a transition with a zero (1, 14) state and the given task / optional base_task."""
        extras = {"task": task}
        if base_task is not None:
            extras["base_task"] = base_task
        zero_state = {"observation.state": torch.zeros(1, 14)}
        return create_transition(observation=zero_state, complementary_data=extras)

    def test_cfg_disabled_no_uncond_task(self):
        """With CFG disabled the step does not emit an ``uncond_task`` entry."""
        from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep

        step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=False)
        result = step(self._make_transition(task=["pick up the cup, Advantage: positive"]))

        assert "uncond_task" not in result[TransitionKey.COMPLEMENTARY_DATA]

    def test_cfg_enabled_produces_uncond_task_from_base(self):
        """With CFG enabled the unconditional prompt is built from ``base_task`` (no advantage text)."""
        from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep

        step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=True)
        transition = self._make_transition(
            task=["pick up the cup, Advantage: positive"],
            base_task=["pick up the cup"],
        )
        data = step(transition)[TransitionKey.COMPLEMENTARY_DATA]

        assert "uncond_task" in data
        assert len(data["uncond_task"]) == 1
        uncond_prompt = data["uncond_task"][0]
        # The prompt carries the base task and state, but no advantage indicator.
        assert "Advantage" not in uncond_prompt
        assert "pick up the cup" in uncond_prompt
        assert "State:" in uncond_prompt

    def test_cfg_enabled_falls_back_to_task_when_no_base(self):
        """Without ``base_task`` the unconditional prompt falls back to the task itself."""
        from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep

        step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=True)
        data = step(self._make_transition(task=["pick up the cup"]))[TransitionKey.COMPLEMENTARY_DATA]

        assert "uncond_task" in data
        assert "pick up the cup" in data["uncond_task"][0]
class TestCfgPipelineConstruction:
    """Checks how the PI05 preprocessor pipeline is assembled with and without CFG."""

    def _make_config(self, cfg_beta=1.0, recipe_path=None):
        """Return a CPU PI05Config with a 14-dim state, one 224x224 RGB camera and a 7-dim action."""
        config = PI05Config(
            max_action_dim=7,
            max_state_dim=14,
            cfg_beta=cfg_beta,
            recipe_path=recipe_path,
            device="cpu",
        )
        state_feature = PolicyFeature(type=FeatureType.STATE, shape=(14,))
        image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
        config.input_features = {
            "observation.state": state_feature,
            "observation.images.base_0_rgb": image_feature,
        }
        config.output_features = {
            "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
        }
        return config

    def _make_dataset_stats(self):
        """Return placeholder dataset stats: zeros for mean/min/q01, ones for std/max/q99."""
        fillers = {
            "mean": torch.zeros,
            "std": torch.ones,
            "min": torch.zeros,
            "max": torch.ones,
            "q01": torch.zeros,
            "q99": torch.ones,
        }

        def _stats(shape, keys):
            return {key: fillers[key](*shape) for key in keys}

        vector_keys = ("mean", "std", "min", "max", "q01", "q99")
        # The image entry carries no min/max statistics.
        image_keys = ("mean", "std", "q01", "q99")
        return {
            "observation.state": _stats((14,), vector_keys),
            "action": _stats((7,), vector_keys),
            "observation.images.base_0_rgb": _stats((3, 224, 224), image_keys),
        }

    def test_no_uncond_tokenizer_when_cfg_disabled(self):
        """With cfg_beta=1.0 the preprocessor contains exactly one tokenizer step."""
        from lerobot.processor import TokenizerProcessorStep

        preprocessor, _ = make_pi05_pre_post_processors(
            self._make_config(cfg_beta=1.0), self._make_dataset_stats()
        )
        tokenizers = [step for step in preprocessor.steps if isinstance(step, TokenizerProcessorStep)]
        assert len(tokenizers) == 1

    def test_uncond_tokenizer_added_when_cfg_enabled(self):
        """With cfg_beta=2.0 a second tokenizer reads ``uncond_task`` and writes the uncond keys."""
        from lerobot.processor import TokenizerProcessorStep

        preprocessor, _ = make_pi05_pre_post_processors(
            self._make_config(cfg_beta=2.0), self._make_dataset_stats()
        )
        tokenizers = [step for step in preprocessor.steps if isinstance(step, TokenizerProcessorStep)]
        assert len(tokenizers) == 2

        second = tokenizers[1]
        assert second.task_key == "uncond_task"
        assert second.output_tokens_key == OBS_LANGUAGE_UNCOND_TOKENS
        assert second.output_mask_key == OBS_LANGUAGE_UNCOND_ATTENTION_MASK

    def test_cfg_pipeline_produces_both_token_sets(self):
        """With CFG enabled the output holds conditioned and unconditional tokens of equal shape."""
        preprocessor, _ = make_pi05_pre_post_processors(
            self._make_config(cfg_beta=2.0), self._make_dataset_stats()
        )
        sample = {
            "observation.state": torch.randn(14),
            "observation.images.base_0_rgb": torch.rand(3, 224, 224),
            "task": "pick up the cup",
        }
        processed = preprocessor(sample)

        for key in (
            OBS_LANGUAGE_TOKENS,
            OBS_LANGUAGE_ATTENTION_MASK,
            OBS_LANGUAGE_UNCOND_TOKENS,
            OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
        ):
            assert key in processed

        # Conditioned and unconditional outputs must agree in shape.
        cond_tokens, uncond_tokens = processed[OBS_LANGUAGE_TOKENS], processed[OBS_LANGUAGE_UNCOND_TOKENS]
        assert cond_tokens.shape == uncond_tokens.shape
        cond_mask = processed[OBS_LANGUAGE_ATTENTION_MASK]
        uncond_mask = processed[OBS_LANGUAGE_UNCOND_ATTENTION_MASK]
        assert cond_mask.shape == uncond_mask.shape

    def test_cfg_beta_1_no_uncond_tokens_in_output(self):
        """With cfg_beta=1.0 only the conditioned language tokens appear in the output."""
        preprocessor, _ = make_pi05_pre_post_processors(
            self._make_config(cfg_beta=1.0), self._make_dataset_stats()
        )
        sample = {
            "observation.state": torch.randn(14),
            "observation.images.base_0_rgb": torch.rand(3, 224, 224),
            "task": "pick up the cup",
        }
        processed = preprocessor(sample)

        assert OBS_LANGUAGE_TOKENS in processed
        assert OBS_LANGUAGE_UNCOND_TOKENS not in processed
@@ -0,0 +1,186 @@
#!/usr/bin/env python
"""Tests for RenderedMessagesToTaskStep and PI05 pipeline integration with advantage."""
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
from lerobot.processor.converters import create_transition # noqa: E402
from lerobot.processor.render_messages_processor import RenderMessagesStep # noqa: E402
from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep # noqa: E402
from lerobot.types import TransitionKey # noqa: E402
def test_rendered_messages_to_task_noops_without_messages():
    """When complementary data has no ``messages`` key the task passes through unchanged."""
    transition = create_transition(complementary_data={"task": "pick up the cup"})

    result = RenderedMessagesToTaskStep()(transition)

    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the cup"
def test_rendered_messages_to_task_extracts_user_content():
    """User-role contents are joined with a newline into ``task`` and the render keys are removed."""
    chat = [
        {"role": "user", "content": "pick up the cup"},
        {"role": "user", "content": "Advantage: positive"},
        {"role": "assistant", "content": "reach for cup"},
    ]
    transition = create_transition(
        complementary_data={
            "task": "original task",
            "messages": chat,
            "message_streams": ["high_level", "high_level", "low_level"],
            "target_message_indices": [2],
        }
    )

    data = RenderedMessagesToTaskStep()(transition)[TransitionKey.COMPLEMENTARY_DATA]

    assert data["task"] == "pick up the cup\nAdvantage: positive"
    for consumed_key in ("messages", "message_streams", "target_message_indices"):
        assert consumed_key not in data
def test_rendered_messages_to_task_handles_multimodal_blocks():
    """For list-of-blocks user content only the text block ends up in ``task``."""
    multimodal_user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": "placeholder"},
            {"type": "text", "text": "describe this"},
        ],
    }
    assistant_turn = {"role": "assistant", "content": "a cup on a table"}
    transition = create_transition(
        complementary_data={
            "task": "original",
            "messages": [multimodal_user_turn, assistant_turn],
            "message_streams": ["high_level", "low_level"],
            "target_message_indices": [1],
        }
    )

    data = RenderedMessagesToTaskStep()(transition)[TransitionKey.COMPLEMENTARY_DATA]

    assert data["task"] == "describe this"
def test_rendered_messages_to_task_preserves_list_task_format():
    """A list-valued (batched) task yields a list of the same length filled with the rendered text."""
    chat = [
        {"role": "user", "content": "rendered task"},
        {"role": "assistant", "content": "do it", "target": True},
    ]
    transition = create_transition(
        complementary_data={
            "task": ["task1", "task2"],
            "messages": chat,
            "message_streams": ["high_level", "low_level"],
            "target_message_indices": [1],
        }
    )

    data = RenderedMessagesToTaskStep()(transition)[TransitionKey.COMPLEMENTARY_DATA]

    assert data["task"] == ["rendered task", "rendered task"]
def test_full_render_then_flatten_pipeline():
    """Rendering a recipe and then flattening it yields a task containing both task and advantage text."""
    task_turn = MessageTurn(role="user", content="${task}", stream="high_level")
    advantage_turn = MessageTurn(
        role="user",
        content="Advantage: ${advantage}",
        stream="high_level",
        if_present="advantage",
    )
    subtask_turn = MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True)
    recipe = TrainingRecipe(messages=[task_turn, advantage_turn, subtask_turn])

    subtask_row = {
        "role": "assistant",
        "content": "reach for the cup",
        "style": "subtask",
        "timestamp": 0.0,
        "camera": None,
        "tool_calls": None,
    }
    advantage_row = {
        "role": "user",
        "content": "positive",
        "style": "advantage",
        "timestamp": 0.1,
        "camera": None,
        "tool_calls": None,
    }
    transition = create_transition(
        complementary_data={
            "task": "pick up the cup",
            "timestamp": torch.tensor(0.5),
            "index": torch.tensor(0),
            "language_persistent": [subtask_row, advantage_row],
            "language_events": [],
        }
    )

    # First render the recipe into chat messages, then collapse them into a task string.
    render_step = RenderMessagesStep(recipe=recipe)
    rendered = render_step(transition)
    flattened = RenderedMessagesToTaskStep()(rendered)

    task_text = flattened[TransitionKey.COMPLEMENTARY_DATA]["task"]
    assert "pick up the cup" in task_text
    assert "Advantage: positive" in task_text
def test_full_render_advantage_absent_skips_turn():
    """With no advantage row the ``if_present`` turn is dropped and the task is just the task text."""
    task_turn = MessageTurn(role="user", content="${task}", stream="high_level")
    advantage_turn = MessageTurn(
        role="user",
        content="Advantage: ${advantage}",
        stream="high_level",
        if_present="advantage",
    )
    subtask_turn = MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True)
    recipe = TrainingRecipe(messages=[task_turn, advantage_turn, subtask_turn])

    # Only a subtask row is present; there is nothing with style "advantage".
    subtask_row = {
        "role": "assistant",
        "content": "reach for the cup",
        "style": "subtask",
        "timestamp": 0.0,
        "camera": None,
        "tool_calls": None,
    }
    transition = create_transition(
        complementary_data={
            "task": "pick up the cup",
            "timestamp": torch.tensor(0.5),
            "index": torch.tensor(0),
            "language_persistent": [subtask_row],
            "language_events": [],
        }
    )

    render_step = RenderMessagesStep(recipe=recipe)
    rendered = render_step(transition)
    flattened = RenderedMessagesToTaskStep()(rendered)

    task_text = flattened[TransitionKey.COMPLEMENTARY_DATA]["task"]
    assert task_text == "pick up the cup"
    assert "Advantage" not in task_text
@@ -0,0 +1,518 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RECAP's distributional value function."""
from __future__ import annotations
import pytest
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
DistributionalVFConfig,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_IMAGES
from tests.utils import skip_if_package_missing
BATCH_SIZE = 4
NUM_BINS = 201
IMAGE_KEY = f"{OBS_IMAGES}.top"
def _make_config(**overrides) -> DistributionalVFConfig:
    """Build a CPU DistributionalVFConfig with one 224x224 camera; keyword ``overrides`` win over defaults."""
    kwargs = {
        "init_from_actor_path": "",
        "device": "cpu",
        "image_resolution": (224, 224),
        **overrides,
    }
    config = DistributionalVFConfig(**kwargs)

    camera = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
    config.input_features = {IMAGE_KEY: camera}
    config.output_features = {}
    config.normalization_mapping = {"VISUAL": NormalizationMode.IDENTITY}
    return config
def _make_model():
    """Instantiate a DistributionalVFRewardModel from the default test config."""
    # Imported inside the function so the modeling module is only loaded when a test needs it.
    from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
        DistributionalVFRewardModel,
    )

    config = _make_config()
    return DistributionalVFRewardModel(config)
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
    """Create a random batch: one image, 16 language tokens with a full mask, returns in (-1, 0], no terminals."""
    from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    token_shape = (batch_size, 16)
    # Random draws are made in the same order as the keys appear in the returned dict.
    images = torch.rand(batch_size, 3, 224, 224, device=device)
    tokens = torch.randint(0, 1000, token_shape, device=device)
    attention = torch.ones(*token_shape, dtype=torch.bool, device=device)
    returns = torch.rand(batch_size, device=device) * -1.0
    terminal = torch.zeros(batch_size, dtype=torch.bool, device=device)

    return {
        IMAGE_KEY: images,
        OBS_LANGUAGE_TOKENS: tokens,
        OBS_LANGUAGE_ATTENTION_MASK: attention,
        "mc_return": returns,
        "is_terminal": terminal,
    }
def test_config_registered_in_reward_model_registry():
    """The ``distributional_value_function`` choice is listed by the RewardModelConfig registry."""
    assert "distributional_value_function" in RewardModelConfig.get_known_choices()
def test_factory_returns_correct_class():
    """get_reward_model_class resolves the registered name to DistributionalVFRewardModel."""
    from lerobot.rewards.factory import get_reward_model_class

    resolved = get_reward_model_class("distributional_value_function")

    from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
        DistributionalVFRewardModel,
    )

    assert resolved is DistributionalVFRewardModel
def test_make_reward_model_config_factory():
    """make_reward_model_config returns a DistributionalVFConfig carrying the overridden bin count."""
    from lerobot.rewards.factory import make_reward_model_config

    built = make_reward_model_config("distributional_value_function", num_value_bins=101)

    assert isinstance(built, DistributionalVFConfig)
    assert built.num_value_bins == 101
@skip_if_package_missing("transformers")
def test_hl_gauss_sums_to_one():
    """Each row of the HL-Gauss target distribution sums to 1."""
    model = _make_model()
    targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])

    probs = model.hl_gauss_target(targets)

    assert probs.shape == (4, NUM_BINS)
    row_totals = probs.sum(dim=-1)
    torch.testing.assert_close(row_totals, torch.ones(4), atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_hl_gauss_non_negative():
    """No HL-Gauss target probability is negative across targets spanning [-1, 0]."""
    model = _make_model()
    targets = torch.linspace(-1.0, 0.0, 10)

    probs = model.hl_gauss_target(targets)

    assert torch.all(probs >= 0)
@skip_if_package_missing("transformers")
def test_hl_gauss_expected_value_matches():
    """The bin-center-weighted mean of the HL-Gauss distribution reproduces the target value."""
    model = _make_model()
    targets = torch.tensor([-0.5, -0.1, -0.9])

    probs = model.hl_gauss_target(targets)
    mean_value = (probs * model.bin_centers).sum(dim=-1)

    torch.testing.assert_close(mean_value, targets, atol=1e-4, rtol=0)
@skip_if_package_missing("transformers")
def test_hl_gauss_handles_2d_input():
    """Targets shaped [batch_size, 1] still produce a normalized (batch_size, NUM_BINS) distribution."""
    model = _make_model()
    column_targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)

    probs = model.hl_gauss_target(column_targets)

    assert probs.shape == (2, NUM_BINS)
    row_totals = probs.sum(dim=-1)
    torch.testing.assert_close(row_totals, torch.ones(2), atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_sums_to_one():
    """Each row of the Dirac delta target distribution sums to 1, including at the support edges."""
    model = _make_model()
    targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])

    probs = model.dirac_delta_target(targets)

    assert probs.shape == (5, NUM_BINS)
    row_totals = probs.sum(dim=-1)
    torch.testing.assert_close(row_totals, torch.ones(5), atol=1e-6, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_at_most_two_nonzero():
    """For off-grid targets the Dirac delta distribution has at most two positive bins per sample."""
    model = _make_model()
    targets = torch.tensor([-0.7523, -0.0013])

    probs = model.dirac_delta_target(targets)

    for sample_idx in (0, 1):
        positive_bins = (probs[sample_idx] > 0).sum()
        assert positive_bins <= 2
@skip_if_package_missing("transformers")
def test_dirac_delta_expected_value_matches():
    """The bin-center-weighted mean of the Dirac delta distribution reproduces the target value."""
    model = _make_model()
    targets = torch.tensor([-0.5, -0.1, -0.9])

    probs = model.dirac_delta_target(targets)
    mean_value = (probs * model.bin_centers).sum(dim=-1)

    torch.testing.assert_close(mean_value, targets, atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_boundary_values_clamped():
    """Targets below/above the support put all mass on the first/last bin respectively."""
    model = _make_model()
    out_of_range = torch.tensor([-1.5, 0.5])

    probs = model.dirac_delta_target(out_of_range)

    row_totals = probs.sum(dim=-1)
    torch.testing.assert_close(row_totals, torch.ones(2), atol=1e-6, rtol=0)
    below, above = probs[0], probs[1]
    assert below[0] == 1.0
    assert above[-1] == 1.0
@skip_if_package_missing("transformers")
def test_one_hot_single_nonzero():
    """Every one-hot target row has exactly one positive bin and sums to exactly 1."""
    model = _make_model()
    targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])

    probs = model.one_hot_target(targets)

    assert probs.shape == (4, NUM_BINS)
    for sample_idx in range(4):
        row = probs[sample_idx]
        assert (row > 0).sum() == 1
        assert row.sum() == 1.0
@skip_if_package_missing("transformers")
def test_one_hot_nearest_bin():
    """The hot bin's center lies within 0.003 of the target value."""
    model = _make_model()
    target = torch.tensor([-0.5])

    probs = model.one_hot_target(target)

    hot_bin = probs[0].argmax()
    hot_center = model.bin_centers[hot_bin].item()
    assert hot_center == pytest.approx(-0.5, abs=0.003)
@skip_if_package_missing("transformers")
def test_terminal_gets_one_hot():
    """With the terminal override on, terminal rows are one-hot and the others are spread out."""
    model = _make_model()
    targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
    is_terminal = torch.tensor([False, True, False, True])

    probs = model.compute_target_distribution(
        targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
    )

    # Every row is a valid distribution regardless of which target type produced it.
    for sample_idx in range(4):
        assert probs[sample_idx].sum().item() == pytest.approx(1.0, abs=1e-5)

    # Terminal samples (1, 3) have a single positive bin.
    for terminal_idx in (1, 3):
        assert (probs[terminal_idx] > 0).sum() == 1
    # Non-terminal samples (0, 2) have mass on more than two bins.
    for regular_idx in (0, 2):
        assert (probs[regular_idx] > 0).sum() > 2
@skip_if_package_missing("transformers")
def test_no_terminal_override_when_disabled():
    """With ``use_one_hot_terminal=False`` a terminal sample still gets a spread-out distribution."""
    model = _make_model()
    targets = torch.tensor([-0.5, -0.3])
    is_terminal = torch.tensor([False, True])

    probs = model.compute_target_distribution(
        targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
    )

    terminal_row = probs[1]
    assert (terminal_row > 0).sum() > 2
@skip_if_package_missing("transformers")
def test_model_has_expected_components():
    """The model exposes every attribute the architecture is expected to define."""
    vf_model = _make_model()
    expected_attrs = (
        "vision_tower",
        "multi_modal_projector",
        "token_embedding",
        "layers",
        "value_head",
        "cls_embedding",
        "norm",
        "rotary_emb",
        "bin_centers",
    )
    for attr in expected_attrs:
        assert hasattr(vf_model, attr)
@skip_if_package_missing("transformers")
def test_model_bin_centers_shape():
    """The bin_centers buffer is one-dimensional with NUM_BINS entries."""
    centers = _make_model().bin_centers
    assert centers.shape == (NUM_BINS,)
@skip_if_package_missing("transformers")
def test_model_layer_count():
    """The test model is built with exactly 6 transformer layers."""
    layer_stack = _make_model().layers
    assert len(layer_stack) == 6
@skip_if_package_missing("transformers")
def test_model_value_head_output_dim():
    """The value head emits one logit per value bin."""
    head = _make_model().value_head
    assert head.out_features == NUM_BINS
@skip_if_package_missing("transformers")
def test_forward_returns_loss_and_dict():
    """forward() yields a finite scalar loss plus a dict carrying the logged metrics."""
    vf_model = _make_model()
    loss, metrics = vf_model.forward(_make_batch())
    # A 0-d tensor has an empty shape.
    assert loss.shape == ()
    assert torch.isfinite(loss)
    for key in ("loss", "predicted_value_mean", "mc_return_mean"):
        assert key in metrics
@skip_if_package_missing("transformers")
def test_forward_loss_is_positive():
    """The training loss of a freshly built model is strictly greater than zero."""
    vf_model = _make_model()
    loss, _metrics = vf_model.forward(_make_batch())
    loss_value = loss.item()
    assert loss_value > 0
@skip_if_package_missing("transformers")
def test_compute_reward_returns_correct_shape():
    """compute_reward yields one finite float32 value per batch element."""
    batch_size = 3
    vf_model = _make_model()
    vf_model.eval()
    batch = _make_batch(batch_size=batch_size)
    with torch.no_grad():
        rewards = vf_model.compute_reward(batch)
    assert rewards.shape == (batch_size,)
    assert rewards.dtype == torch.float32
    assert torch.isfinite(rewards).all()
@skip_if_package_missing("transformers")
def test_compute_reward_values_in_support_range():
    """Predicted values stay inside [-1, 0], allowing a 0.01 slack on both ends."""
    slack = 0.01
    vf_model = _make_model()
    vf_model.eval()
    batch = _make_batch(batch_size=8)
    with torch.no_grad():
        rewards = vf_model.compute_reward(batch)
    lower_bound = -1.0 - slack
    upper_bound = 0.0 + slack
    assert (rewards >= lower_bound).all()
    assert (rewards <= upper_bound).all()
@skip_if_package_missing("transformers")
def test_processor_pipeline_produces_expected_keys():
    """The preprocessor output contains the language tokens, their mask, and the image."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        make_distributional_vf_pre_post_processors,
    )
    from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    # Only the preprocessor half of the pair is exercised here.
    preprocessor, _postprocessor = make_distributional_vf_pre_post_processors(_make_config())
    processed = preprocessor({IMAGE_KEY: torch.rand(3, 224, 224), "task": "pick up the cup"})
    for key in (OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, IMAGE_KEY):
        assert key in processed
@skip_if_package_missing("transformers")
def test_gradient_flows_through_value_head():
    """After backward(), the value head weight holds a gradient with a non-zero entry."""
    vf_model = _make_model()
    vf_model.train()
    loss, _metrics = vf_model.forward(_make_batch())
    loss.backward()
    head_grad = vf_model.value_head.weight.grad
    assert head_grad is not None
    assert torch.any(head_grad != 0)
@skip_if_package_missing("transformers")
def test_gradient_flows_through_cls_embedding():
    """After backward(), the cls_embedding parameter holds a gradient with a non-zero entry."""
    vf_model = _make_model()
    vf_model.train()
    loss, _metrics = vf_model.forward(_make_batch())
    loss.backward()
    cls_grad = vf_model.cls_embedding.grad
    assert cls_grad is not None
    assert torch.any(cls_grad != 0)
def test_config_requires_visual_feature():
    """validate_features rejects a config whose only input feature is a STATE feature."""
    cfg = DistributionalVFConfig(init_from_actor_path="")
    state_only = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,))}
    cfg.input_features = state_only
    # The error message is expected to mention the missing VISUAL feature.
    with pytest.raises(ValueError, match="VISUAL"):
        cfg.validate_features()
def test_config_passes_with_visual_feature():
    """validate_features does not raise for the default test config."""
    cfg = _make_config()
    # Passing means no exception is raised; there is nothing to assert on.
    cfg.validate_features()
@skip_if_package_missing("transformers")
def test_save_load_pretrained_roundtrip(tmp_path):
    """Saving then reloading a model reproduces the same state-dict keys and tensors."""
    from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
        DistributionalVFRewardModel,
    )

    original = _make_model()
    original._save_pretrained(tmp_path)
    restored = DistributionalVFRewardModel.from_pretrained(str(tmp_path))

    saved_state = original.state_dict()
    restored_state = restored.state_dict()
    assert set(saved_state) == set(restored_state)
    for name, tensor in saved_state.items():
        torch.testing.assert_close(tensor, restored_state[name], msg=f"Mismatch in {name}")
@skip_if_package_missing("transformers")
def test_image_preprocessor_normalizes_to_minus_one_one():
    """A [0, 1] float image comes out of the preprocessor step within [-1, 1]."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFImagePreprocessorStep,
    )

    step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
    unit_range_image = torch.rand(1, 224, 224, 3)
    result = step({TransitionKey.OBSERVATION: {IMAGE_KEY: unit_range_image}})
    processed = result[TransitionKey.OBSERVATION][IMAGE_KEY]
    tolerance = 1e-5
    assert processed.min() >= -1.0 - tolerance
    assert processed.max() <= 1.0 + tolerance
@skip_if_package_missing("transformers")
def test_image_preprocessor_resizes_with_pad():
    """A 480x640 input image is brought to the configured 224x224 resolution."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFImagePreprocessorStep,
    )

    target_resolution = (224, 224)
    step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
    non_square_image = torch.rand(1, 480, 640, 3)
    result = step({TransitionKey.OBSERVATION: {IMAGE_KEY: non_square_image}})
    processed = result[TransitionKey.OBSERVATION][IMAGE_KEY]
    # Dimensions 1 and 2 of the output are checked against the target resolution.
    assert processed.shape[1:3] == target_resolution
def test_task_prompt_formats_correctly():
    """An underscored task name in a list is rendered as 'Task: <words>.'."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFPrepareTaskPromptStep,
    )

    prompt_step = DistributionalVFPrepareTaskPromptStep()
    result = prompt_step({TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]}})
    rendered = result[TransitionKey.COMPLEMENTARY_DATA]["task"]
    assert rendered[0] == "Task: pick up the cup."
def test_task_prompt_handles_string_input():
    """A bare string task is accepted and the formatted prompt is readable at index 0."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFPrepareTaskPromptStep,
    )

    prompt_step = DistributionalVFPrepareTaskPromptStep()
    result = prompt_step({TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"}})
    rendered = result[TransitionKey.COMPLEMENTARY_DATA]["task"]
    assert rendered[0] == "Task: open drawer."
def test_task_prompt_raises_on_missing_task():
    """Empty complementary data makes the prompt step raise ValueError('No task found...')."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFPrepareTaskPromptStep,
    )

    prompt_step = DistributionalVFPrepareTaskPromptStep()
    taskless_transition = {TransitionKey.COMPLEMENTARY_DATA: {}}
    with pytest.raises(ValueError, match="No task found"):
        prompt_step(taskless_transition)
+514
View File
@@ -0,0 +1,514 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lerobot-compute-returns script."""
import json
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from lerobot.scripts.lerobot_compute_returns import (
IS_TERMINAL_COL,
MC_RETURN_COL,
ComputeReturnsConfig,
_get_episode_success,
compute_episode_returns,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def parquet_dataset(tmp_path):
    """Create a tiny on-disk dataset: one parquet shard plus ``meta/info.json``.

    The shard holds 3 episodes of 10 frames each. ``next.success`` is True only
    on the last frame of even-numbered episodes; odd-numbered episodes never
    set it, so they count as failures.

    Returns:
        Tuple ``(root, parquet_path, episodes_meta)`` where ``episodes_meta`` is
        a list of per-episode dicts (index, length, from/to row offsets).
    """
    episode_count = 3
    episode_len = 10
    root = tmp_path / "test_dataset"
    data_dir = root / "data" / "chunk-000"
    meta_dir = root / "meta"
    for directory in (data_dir, meta_dir):
        directory.mkdir(parents=True)

    # Columns are accumulated directly; "index" doubles as the running row counter.
    columns = {"episode_index": [], "frame_index": [], "index": [], "next.success": []}
    episodes_meta = []
    for ep in range(episode_count):
        first_row = len(columns["index"])
        succeeded = ep % 2 == 0
        for frame in range(episode_len):
            columns["episode_index"].append(ep)
            columns["frame_index"].append(frame)
            columns["index"].append(len(columns["index"]))
            columns["next.success"].append(succeeded and frame == episode_len - 1)
        episodes_meta.append(
            {
                "episode_index": ep,
                "length": episode_len,
                "dataset_from_index": first_row,
                "dataset_to_index": len(columns["index"]),
            }
        )

    parquet_path = data_dir / "episode_000000.parquet"
    pq.write_table(pa.table(columns), parquet_path)

    features = {
        name: {"dtype": "int64", "shape": [1], "names": None}
        for name in ("episode_index", "frame_index", "index")
    }
    features["next.success"] = {"dtype": "bool", "shape": [1], "names": None}
    info = {
        "codebase_version": "v3.0",
        "total_episodes": episode_count,
        "total_frames": len(columns["index"]),
        "fps": 30,
        "features": features,
    }
    (meta_dir / "info.json").write_text(json.dumps(info, indent=2))
    return root, parquet_path, episodes_meta
def _rewrite_shard(parquet_path: Path, episodes_meta: list[dict], config: ComputeReturnsConfig) -> None:
    """Add or refresh the terminal-flag and MC-return columns of one parquet shard, in place.

    Args:
        parquet_path: Shard to read and overwrite.
        episodes_meta: Per-episode dicts; only ``"episode_index"`` and ``"length"`` are read.
        config: Supplies ``force``, ``success_key``, ``default_success``, ``c_fail``,
            ``gamma`` and ``max_episode_length``.

    The shard is left untouched when ``config.force`` is false and the terminal
    column is already present.
    """
    table = pq.read_table(parquet_path)
    if not config.force and IS_TERMINAL_COL in table.column_names:
        return

    num_rows = len(table)
    all_is_terminal = np.zeros(num_rows, dtype=bool)
    all_mc_return = np.zeros(num_rows, dtype=np.float32)
    # Materialize the episode ids once so each episode mask is a vectorized comparison
    # instead of a Python-level scan over the whole column.
    episode_ids = table.column("episode_index").to_numpy()

    for ep_info in episodes_meta:
        ep_len = ep_info["length"]
        mask = episode_ids == ep_info["episode_index"]
        success = _get_episode_success(table.filter(mask), config.success_key, config.default_success)
        is_terminal, mc_return = compute_episode_returns(
            num_frames=ep_len,
            success=success,
            c_fail=config.c_fail,
            gamma=config.gamma,
            # Fall back to the episode's own length when no global horizon is configured.
            max_episode_length=config.max_episode_length or ep_len,
        )
        # Boolean-mask assignment writes the per-episode values back at the episode's rows.
        all_is_terminal[mask] = is_terminal
        all_mc_return[mask] = mc_return

    # Drop stale columns first so a forced re-run replaces them rather than duplicating them.
    for stale_col in (IS_TERMINAL_COL, MC_RETURN_COL):
        if stale_col in table.column_names:
            table = table.drop([stale_col])
    table = table.append_column(IS_TERMINAL_COL, pa.array(all_is_terminal))
    table = table.append_column(MC_RETURN_COL, pa.array(all_mc_return))
    pq.write_table(table, parquet_path)
# ---------------------------------------------------------------------------
# Tests: compute_episode_returns (pure math, no I/O)
# ---------------------------------------------------------------------------
def test_successful_episode_terminal_reward_is_zero():
    """The last-frame MC return of a successful episode is 0 (within 1e-6)."""
    _flags, returns = compute_episode_returns(
        num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=10
    )
    terminal_return = returns[-1]
    assert terminal_return == pytest.approx(0.0, abs=1e-6)
def test_failed_episode_terminal_reward_reflects_cfail():
    """The last-frame MC return of a failed episode equals -c_fail / max_episode_length."""
    c_fail = 50.0
    horizon = 100
    _flags, returns = compute_episode_returns(
        num_frames=10, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
    )
    expected_terminal = -c_fail / horizon
    assert returns[-1] == pytest.approx(expected_terminal, abs=1e-5)
def test_is_terminal_only_last_frame():
    """Exactly the final frame carries the terminal flag; every earlier frame does not."""
    flags, _returns = compute_episode_returns(
        num_frames=20, success=True, c_fail=50.0, gamma=1.0, max_episode_length=20
    )
    earlier_flags = flags[:-1]
    assert flags[-1] == True  # noqa: E712
    assert not any(earlier_flags)
def test_mc_return_monotonically_increases_for_success():
    """Undiscounted returns of a successful episode never decrease from frame to frame."""
    _flags, returns = compute_episode_returns(
        num_frames=50, success=True, c_fail=50.0, gamma=1.0, max_episode_length=50
    )
    # Walk consecutive (current, next) pairs.
    for current, following in zip(returns, returns[1:]):
        assert current <= following
def test_mc_return_bounded_negative_to_zero():
    """Returns of a full-horizon successful episode stay within [-1, 0] and end at 0."""
    _flags, returns = compute_episode_returns(
        num_frames=100, success=True, c_fail=50.0, gamma=1.0, max_episode_length=100
    )
    lower_bound = -1.0 - 1e-6
    assert returns[-1] == pytest.approx(0.0, abs=1e-6)
    assert all(value <= 0.0 for value in returns)
    assert all(value >= lower_bound for value in returns)
def test_first_frame_return_success():
    """For a successful undiscounted episode the first return is -(N-1)/H."""
    frame_count = 10
    max_len = 10
    _flags, returns = compute_episode_returns(
        num_frames=frame_count, success=True, c_fail=50.0, gamma=1.0, max_episode_length=max_len
    )
    assert returns[0] == pytest.approx(-(frame_count - 1) / max_len, abs=1e-5)
def test_first_frame_return_failure():
    """For a failed undiscounted episode the first return is -(N-1)/H - c_fail/H."""
    frame_count = 10
    max_len = 100
    penalty = 50.0
    _flags, returns = compute_episode_returns(
        num_frames=frame_count, success=False, c_fail=penalty, gamma=1.0, max_episode_length=max_len
    )
    step_cost = -(frame_count - 1) / max_len
    failure_cost = -penalty / max_len
    assert returns[0] == pytest.approx(step_cost + failure_cost, abs=1e-5)
def test_discount_factor_less_than_one():
    """With gamma=0.99 the first-frame return has smaller magnitude than with gamma=1."""
    shared = {"num_frames": 20, "success": True, "c_fail": 50.0, "max_episode_length": 20}
    _flags, undiscounted = compute_episode_returns(gamma=1.0, **shared)
    _flags, discounted = compute_episode_returns(gamma=0.99, **shared)
    assert abs(discounted[0]) < abs(undiscounted[0])
def test_single_frame_episode_success():
    """A one-frame successful episode has return 0 and its only frame is terminal."""
    flags, returns = compute_episode_returns(
        num_frames=1, success=True, c_fail=50.0, gamma=1.0, max_episode_length=1
    )
    only_return, only_flag = returns[0], flags[0]
    assert only_return == pytest.approx(0.0, abs=1e-6)
    assert only_flag == True  # noqa: E712
def test_single_frame_episode_failure():
    """A one-frame failed episode has return -c_fail/H and its only frame is terminal."""
    penalty = 50.0
    max_len = 100
    flags, returns = compute_episode_returns(
        num_frames=1, success=False, c_fail=penalty, gamma=1.0, max_episode_length=max_len
    )
    only_return, only_flag = returns[0], flags[0]
    assert only_return == pytest.approx(-penalty / max_len, abs=1e-5)
    assert only_flag == True  # noqa: E712
def test_horizon_normalization_scales_returns():
    """Raising max_episode_length from 10 to 100 shrinks the first-frame return magnitude."""
    shared = {"num_frames": 10, "success": True, "c_fail": 50.0, "gamma": 1.0}
    _flags, short_horizon = compute_episode_returns(max_episode_length=10, **shared)
    _flags, long_horizon = compute_episode_returns(max_episode_length=100, **shared)
    assert abs(long_horizon[0]) < abs(short_horizon[0])
# ---------------------------------------------------------------------------
# Tests: _get_episode_success (in-memory PyArrow tables)
# ---------------------------------------------------------------------------
def test_default_success_overrides_column():
    """An explicit default_success=False wins over an all-True success column."""
    all_true = pa.table({"next.success": [True, True, True]})
    outcome = _get_episode_success(all_true, "next.success", default_success=False)
    assert outcome is False
def test_reads_bool_column():
    """One True in a bool column means success; an all-False column means failure."""
    succeeded = pa.table({"next.success": [False, False, True]})
    failed = pa.table({"next.success": [False, False, False]})
    success_outcome = _get_episode_success(succeeded, "next.success", None)
    assert success_outcome is True
    failure_outcome = _get_episode_success(failed, "next.success", None)
    assert failure_outcome is False
def test_reads_int_column():
    """A 0/1 integer column under a custom key is read as success when it contains a 1."""
    int_flags = pa.table({"task_success": [0, 0, 1]})
    outcome = _get_episode_success(int_flags, "task_success", None)
    assert outcome is True
def test_all_zeros_means_failure():
    """An integer success column holding only zeros yields failure."""
    zero_flags = pa.table({"next.success": [0, 0, 0]})
    outcome = _get_episode_success(zero_flags, "next.success", None)
    assert outcome is False
def test_missing_column_defaults_to_true():
    """Without the success column and without a default, the episode counts as a success."""
    no_success_column = pa.table({"frame_index": [0, 1, 2]})
    outcome = _get_episode_success(no_success_column, "next.success", None)
    assert outcome is True
# ---------------------------------------------------------------------------
# Tests: parquet rewriting (integration, writes to disk)
# ---------------------------------------------------------------------------
def test_writes_columns_to_parquet(parquet_dataset):
    """Rewriting a fresh shard adds both the terminal-flag and the MC-return column."""
    _root, parquet_path, episodes_meta = parquet_dataset
    new_columns = (IS_TERMINAL_COL, MC_RETURN_COL)

    columns_before = pq.read_table(parquet_path).column_names
    for col in new_columns:
        assert col not in columns_before

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)

    columns_after = pq.read_table(parquet_path).column_names
    for col in new_columns:
        assert col in columns_after
def test_terminal_frames_correct(parquet_dataset):
    """In the 3x10-frame fixture the terminal flag is set on rows 9, 19 and 29 only."""
    _root, parquet_path, episodes_meta = parquet_dataset
    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)

    flags = pq.read_table(parquet_path).column(IS_TERMINAL_COL).to_pylist()
    flagged_rows = [row for row, is_set in enumerate(flags) if is_set]
    assert flagged_rows == [9, 19, 29]
def test_success_episodes_return_zero_at_terminal(tmp_path):
    """Episode 0 (successful) ends with mc_return == 0 at its last row (row 4).

    Builds a 2-episode, 5-frames-per-episode shard where only even episodes set
    ``next.success`` on their last frame, then rewrites it.
    """
    episode_count = 2
    episode_len = 5
    root = tmp_path / "test_dataset"
    data_dir = root / "data" / "chunk-000"
    meta_dir = root / "meta"
    for directory in (data_dir, meta_dir):
        directory.mkdir(parents=True)

    # Columns are accumulated directly; "index" doubles as the running row counter.
    columns = {"episode_index": [], "frame_index": [], "index": [], "next.success": []}
    episodes_meta = []
    for ep in range(episode_count):
        first_row = len(columns["index"])
        succeeded = ep % 2 == 0
        for frame in range(episode_len):
            columns["episode_index"].append(ep)
            columns["frame_index"].append(frame)
            columns["index"].append(len(columns["index"]))
            columns["next.success"].append(succeeded and frame == episode_len - 1)
        episodes_meta.append(
            {
                "episode_index": ep,
                "length": episode_len,
                "dataset_from_index": first_row,
                "dataset_to_index": len(columns["index"]),
            }
        )

    parquet_path = data_dir / "episode_000000.parquet"
    pq.write_table(pa.table(columns), parquet_path)

    features = {
        name: {"dtype": "int64", "shape": [1], "names": None}
        for name in ("episode_index", "frame_index", "index")
    }
    features["next.success"] = {"dtype": "bool", "shape": [1], "names": None}
    info = {
        "codebase_version": "v3.0",
        "total_episodes": episode_count,
        "total_frames": len(columns["index"]),
        "fps": 30,
        "features": features,
    }
    (meta_dir / "info.json").write_text(json.dumps(info, indent=2))

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=5, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)

    returns = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()
    # Row 4 is the last frame of episode 0.
    assert returns[4] == pytest.approx(0.0, abs=1e-5)
def test_failed_episodes_have_negative_terminal(tmp_path):
    """Episode 1 (failed) ends with a strictly negative mc_return at its last row (row 9).

    Builds a 2-episode, 5-frames-per-episode shard where only even episodes set
    ``next.success`` on their last frame, then rewrites it with c_fail=50.
    """
    episode_count = 2
    episode_len = 5
    root = tmp_path / "test_dataset"
    data_dir = root / "data" / "chunk-000"
    meta_dir = root / "meta"
    for directory in (data_dir, meta_dir):
        directory.mkdir(parents=True)

    # Columns are accumulated directly; "index" doubles as the running row counter.
    columns = {"episode_index": [], "frame_index": [], "index": [], "next.success": []}
    episodes_meta = []
    for ep in range(episode_count):
        first_row = len(columns["index"])
        succeeded = ep % 2 == 0
        for frame in range(episode_len):
            columns["episode_index"].append(ep)
            columns["frame_index"].append(frame)
            columns["index"].append(len(columns["index"]))
            columns["next.success"].append(succeeded and frame == episode_len - 1)
        episodes_meta.append(
            {
                "episode_index": ep,
                "length": episode_len,
                "dataset_from_index": first_row,
                "dataset_to_index": len(columns["index"]),
            }
        )

    parquet_path = data_dir / "episode_000000.parquet"
    pq.write_table(pa.table(columns), parquet_path)

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=5, c_fail=50.0, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)

    returns = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()
    # Row 9 is the last frame of episode 1.
    assert returns[9] < 0.0
def test_idempotent_with_force_flag(parquet_dataset):
    """Two forced rewrites with the same config yield the same MC-return column."""
    _root, parquet_path, episodes_meta = parquet_dataset
    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)

    snapshots = []
    for _ in range(2):
        _rewrite_shard(parquet_path, episodes_meta, config)
        snapshots.append(pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist())

    first_pass, second_pass = snapshots
    assert first_pass == second_pass
def test_skips_if_columns_exist_without_force(parquet_dataset):
    """A non-forced rewrite leaves already-present columns untouched.

    The second config uses a different horizon (20 instead of 10), so an actual
    rewrite would change the stored returns.
    """
    _root, parquet_path, episodes_meta = parquet_dataset

    forced = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
    _rewrite_shard(parquet_path, episodes_meta, forced)
    baseline = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()

    unforced = ComputeReturnsConfig(success_key="next.success", max_episode_length=20, force=False)
    _rewrite_shard(parquet_path, episodes_meta, unforced)
    after_skip = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()

    assert after_skip == baseline
def test_updates_info_json(parquet_dataset):
    """_update_info_json registers both new columns in info.json with the right dtypes."""
    from lerobot.scripts.lerobot_compute_returns import _update_info_json

    root, _parquet_path, _episodes_meta = parquet_dataset
    _update_info_json(root, None)

    features = json.loads((root / "meta" / "info.json").read_text())["features"]
    assert IS_TERMINAL_COL in features
    assert MC_RETURN_COL in features
    assert features[IS_TERMINAL_COL]["dtype"] == "bool"
    assert features[MC_RETURN_COL]["dtype"] == "float32"
+97
View File
@@ -338,6 +338,103 @@ def test_dagger_events_reset():
assert not events.upload_requested.is_set()
def test_dagger_mark_success():
    """mark_success makes the next consume return True; consuming clears the label."""
    from lerobot.rollout.strategies import DAggerEvents

    dagger_events = DAggerEvents()
    # Nothing has been labelled yet.
    assert dagger_events.consume_episode_success() is None
    dagger_events.mark_success()
    label = dagger_events.consume_episode_success()
    assert label is True
    # The label is one-shot: a second read yields None again.
    assert dagger_events.consume_episode_success() is None
def test_dagger_mark_failure():
    """mark_failure makes the next consume return False."""
    from lerobot.rollout.strategies import DAggerEvents

    dagger_events = DAggerEvents()
    dagger_events.mark_failure()
    label = dagger_events.consume_episode_success()
    assert label is False
def test_dagger_success_overrides_failure():
    """When failure is marked and then success, the consumed label is True."""
    from lerobot.rollout.strategies import DAggerEvents

    dagger_events = DAggerEvents()
    # The later call is expected to win.
    dagger_events.mark_failure()
    dagger_events.mark_success()
    label = dagger_events.consume_episode_success()
    assert label is True
def test_dagger_reset_clears_success_label():
    """A label set before reset() is gone afterwards: consume returns None."""
    from lerobot.rollout.strategies import DAggerEvents

    dagger_events = DAggerEvents()
    dagger_events.mark_success()
    dagger_events.reset()
    label = dagger_events.consume_episode_success()
    assert label is None
def test_stamp_episode_success_labels_terminal_frame():
    """With a pending success label, only the last buffered next.success entry turns True."""
    import numpy as np

    from lerobot.rollout.strategies import DAggerEvents
    from lerobot.rollout.strategies.dagger import DAggerStrategy

    # Build the strategy without running __init__; only the attributes set here exist.
    strategy = DAggerStrategy.__new__(DAggerStrategy)
    strategy.config = MagicMock()
    strategy._events = DAggerEvents()
    strategy._events.mark_success()

    dataset = MagicMock()
    dataset.writer.episode_buffer = {
        "next.success": [np.array([False], dtype=bool) for _ in range(3)],
    }
    strategy._stamp_episode_success(dataset)

    stamped = dataset.writer.episode_buffer["next.success"]
    assert stamped[-1].item() is True
    assert stamped[0].item() is False
def test_stamp_episode_success_no_label_stays_false():
    """With no pending label, every buffered next.success entry stays False."""
    import numpy as np

    from lerobot.rollout.strategies import DAggerEvents
    from lerobot.rollout.strategies.dagger import DAggerStrategy

    # Build the strategy without running __init__; only the attributes set here exist.
    strategy = DAggerStrategy.__new__(DAggerStrategy)
    strategy.config = MagicMock()
    strategy._events = DAggerEvents()

    dataset = MagicMock()
    dataset.writer.episode_buffer = {
        "next.success": [np.array([False], dtype=bool) for _ in range(2)],
    }
    strategy._stamp_episode_success(dataset)

    stamped = dataset.writer.episode_buffer["next.success"]
    assert all(entry.item() is False for entry in stamped)
# ---------------------------------------------------------------------------
# Context dataclass
# ---------------------------------------------------------------------------
Generated
+348 -343
View File
File diff suppressed because it is too large Load Diff