mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 19:57:27 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2d4be80425 | |||
| 7d1e1b0357 | |||
| 0d2ba54385 | |||
| 4b779b1e99 | |||
| ea908c0672 | |||
| e5c94c732f | |||
| c18b8277f1 | |||
| fa3eb9fce3 | |||
| 500c91ba92 |
@@ -220,6 +220,7 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
recap = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
@@ -317,6 +318,7 @@ all = [
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[recap]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
@@ -340,6 +342,7 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
lerobot-compute-returns="lerobot.scripts.lerobot_compute_returns:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
|
||||
@@ -169,6 +169,43 @@ class ExecutorConfig:
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdvantageConfig:
|
||||
"""``advantage`` module: RECAP advantage scoring via frozen value function."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Path or Hub repo ID of the trained distributional value function checkpoint.
|
||||
value_function_path: str = ""
|
||||
|
||||
# Device to run the value function on.
|
||||
device: str = "cuda"
|
||||
|
||||
# N-step lookahead for advantage estimation.
|
||||
# None = MC (N=T): A_t = R_t - V(s_t), using mc_return from dataset.
|
||||
# 50 = fine-tuning mode: A_t = Σ r_{t:t+N} + V(s_{t+N}) - V(s_t).
|
||||
n_step: int | None = None
|
||||
|
||||
# Per-task percentile for binarization threshold ε_ℓ.
|
||||
# Actions with advantage > ε_ℓ get I_t = True (positive).
|
||||
threshold_percentile: float = 0.3
|
||||
|
||||
# Fraction of frames to randomly omit advantage labels (enables CFG).
|
||||
dropout_rate: float = 0.3
|
||||
|
||||
# Force I_t = True for frames marked as human interventions.
|
||||
force_positive_on_intervention: bool = True
|
||||
|
||||
# Column name in dataset for intervention flag.
|
||||
intervention_key: str = "intervention"
|
||||
|
||||
# Column name for pre-computed MC returns (from lerobot-compute-returns).
|
||||
mc_return_key: str = "mc_return"
|
||||
|
||||
# Batch size for value function inference.
|
||||
batch_size: int = 32
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
|
||||
@@ -190,6 +227,7 @@ class AnnotationPipelineConfig:
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
advantage: AdvantageConfig = field(default_factory=AdvantageConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
@@ -15,20 +15,24 @@
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor runs **six phases** in dependency order:
|
||||
The executor runs **seven phases** in dependency order:
|
||||
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
phase 5: ``advantage`` module (advantage scoring via frozen VF)
|
||||
phase 6: validator
|
||||
phase 7: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Phase 5 (advantage) does not depend on the VLM modules, it uses a frozen
|
||||
distributional value function to compute per-frame advantage indicators.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
@@ -74,7 +78,7 @@ class PipelineRunSummary:
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
"""Run all seven phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
@@ -86,6 +90,7 @@ class Executor:
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
advantage: Any # AdvantageModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
@@ -112,6 +117,8 @@ class Executor:
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
# Phase 5: ``advantage`` module (advantage scoring via frozen VF)
|
||||
phases.append(self._run_module_phase("advantage", records, staging_dir, self.advantage))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
|
||||
@@ -14,11 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .advantage import AdvantageModule
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"AdvantageModule",
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
|
||||
@@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Advantage scoring module for RECAP.
|
||||
|
||||
Computes per-frame advantage values using a frozen distributional value function,
|
||||
binarizes them into improvement indicators (I_t), and emits ``style="advantage"``
|
||||
persistent rows for policy conditioning.
|
||||
|
||||
Paper reference: pi*0.6, Section IV-B and Appendix F.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..config import AdvantageConfig
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdvantageModule:
|
||||
"""Compute advantage indicators and emit persistent annotation rows.
|
||||
|
||||
The module loads a frozen distributional value function and scores each
|
||||
frame in an episode. Advantages are binarized into ``positive``/``negative``
|
||||
indicators using a per-task threshold, then written as ``style="advantage"``
|
||||
persistent rows into the staging area.
|
||||
|
||||
Requires ``mc_return`` column in the dataset (from lerobot-compute-returns).
|
||||
"""
|
||||
|
||||
config: AdvantageConfig
|
||||
_model: Any = field(default=None, init=False, repr=False)
|
||||
_preprocessor: Any = field(default=None, init=False, repr=False)
|
||||
_threshold: float | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def _ensure_model_loaded(self) -> None:
|
||||
"""Lazy-load the frozen value function on first use."""
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
from lerobot.rewards import (
|
||||
make_reward_model,
|
||||
make_reward_model_config,
|
||||
make_reward_pre_post_processors,
|
||||
)
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"distributional_value_function",
|
||||
pretrained_path=self.config.value_function_path,
|
||||
device=self.config.device,
|
||||
)
|
||||
self._model = make_reward_model(cfg)
|
||||
self._model.eval()
|
||||
for p in self._model.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
self._preprocessor, _ = make_reward_pre_post_processors(cfg)
|
||||
logger.info("Loaded frozen VF from %s on %s", self.config.value_function_path, self.config.device)
|
||||
|
||||
def compute_advantages_for_episode(self, record: EpisodeRecord) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Compute raw advantage values for all frames in an episode.
|
||||
|
||||
Returns:
|
||||
(advantages, intervention_mask) both shape [num_frames].
|
||||
advantages[t] = A_t, intervention_mask[t] = True if frame is intervention.
|
||||
"""
|
||||
self._ensure_model_loaded()
|
||||
|
||||
df = record.frames_df()
|
||||
num_frames = len(df)
|
||||
|
||||
mc_return_key = self.config.mc_return_key
|
||||
if mc_return_key not in df.columns:
|
||||
raise KeyError(
|
||||
f"Column '{mc_return_key}' not found in episode {record.episode_index}. "
|
||||
"Run lerobot-compute-returns first."
|
||||
)
|
||||
|
||||
mc_returns = df[mc_return_key].values.astype(np.float32)
|
||||
|
||||
intervention_mask = np.zeros(num_frames, dtype=bool)
|
||||
if self.config.intervention_key in df.columns:
|
||||
intervention_mask = df[self.config.intervention_key].values.astype(bool)
|
||||
|
||||
# Skip VF inference on intervention frames — they're always "positive"
|
||||
# regardless of advantage value, so V(s_t) is never used for them.
|
||||
skip_mask = intervention_mask if self.config.force_positive_on_intervention else None
|
||||
values = self._compute_values(record, skip_mask=skip_mask)
|
||||
|
||||
if self.config.n_step is None:
|
||||
advantages = mc_returns - values
|
||||
else:
|
||||
advantages = self._compute_n_step_advantages(mc_returns, values, record, n=self.config.n_step)
|
||||
|
||||
return advantages, intervention_mask
|
||||
|
||||
def _compute_values(self, record: EpisodeRecord, skip_mask: np.ndarray | None = None) -> np.ndarray:
|
||||
"""Run frozen VF over all frames to get V(s_t) predictions.
|
||||
|
||||
Args:
|
||||
record: Episode data.
|
||||
skip_mask: Optional boolean mask [num_frames]. Frames where True are
|
||||
skipped (left as 0.0) to avoid unnecessary inference.
|
||||
"""
|
||||
df = record.frames_df()
|
||||
num_frames = len(df)
|
||||
values = np.zeros(num_frames, dtype=np.float32)
|
||||
|
||||
image_key = self._resolve_image_key(df)
|
||||
if image_key is None:
|
||||
logger.warning("No image key found for episode %d; returning zero values.", record.episode_index)
|
||||
return values
|
||||
|
||||
# Determine which frame indices actually need inference
|
||||
infer_indices = np.where(~skip_mask)[0] if skip_mask is not None else np.arange(num_frames)
|
||||
|
||||
if len(infer_indices) == 0:
|
||||
return values
|
||||
|
||||
task_text = record.episode_task
|
||||
|
||||
for batch_start in range(0, len(infer_indices), self.config.batch_size):
|
||||
batch_end = min(batch_start + self.config.batch_size, len(infer_indices))
|
||||
batch_indices = infer_indices[batch_start:batch_end]
|
||||
batch_images = []
|
||||
|
||||
for idx in batch_indices:
|
||||
img_val = df.iloc[idx][image_key]
|
||||
if isinstance(img_val, np.ndarray):
|
||||
img_tensor = torch.from_numpy(img_val).float()
|
||||
elif isinstance(img_val, torch.Tensor):
|
||||
img_tensor = img_val.float()
|
||||
else:
|
||||
img_tensor = torch.zeros(3, 224, 224)
|
||||
batch_images.append(img_tensor)
|
||||
|
||||
batch_images_tensor = torch.stack(batch_images)
|
||||
batch_size = batch_images_tensor.shape[0]
|
||||
|
||||
raw_batch = {
|
||||
image_key: batch_images_tensor,
|
||||
"task": [task_text] * batch_size,
|
||||
}
|
||||
|
||||
processed = self._preprocessor(raw_batch)
|
||||
|
||||
with torch.no_grad():
|
||||
v_values = self._model.compute_reward(processed)
|
||||
|
||||
values[batch_indices] = v_values.cpu().numpy()
|
||||
|
||||
return values
|
||||
|
||||
def _compute_n_step_advantages(
|
||||
self, mc_returns: np.ndarray, values: np.ndarray, record: EpisodeRecord, n: int
|
||||
) -> np.ndarray:
|
||||
"""Compute N-step advantage: A_t = Σ r_{t:t+N-1} + V(s_{t+N}) - V(s_t).
|
||||
|
||||
When t+N exceeds episode length, truncates to MC (uses mc_return directly).
|
||||
"""
|
||||
num_frames = len(values)
|
||||
advantages = np.zeros(num_frames, dtype=np.float32)
|
||||
|
||||
for t in range(num_frames):
|
||||
if t + n >= num_frames:
|
||||
advantages[t] = mc_returns[t] - values[t]
|
||||
else:
|
||||
n_step_return = mc_returns[t] - mc_returns[t + n]
|
||||
advantages[t] = n_step_return + values[t + n] - values[t]
|
||||
|
||||
return advantages
|
||||
|
||||
def _resolve_image_key(self, df) -> str | None:
|
||||
"""Find the first image observation key in the dataframe columns."""
|
||||
for col in df.columns:
|
||||
if col.startswith("observation.images."):
|
||||
return col
|
||||
return None
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
"""Score one episode and write advantage rows to staging."""
|
||||
if not self.config.value_function_path:
|
||||
logger.warning("No value_function_path configured; skipping advantage scoring.")
|
||||
return
|
||||
|
||||
advantages, intervention_mask = self.compute_advantages_for_episode(record)
|
||||
num_frames = len(advantages)
|
||||
|
||||
threshold = self._compute_threshold(advantages, intervention_mask)
|
||||
|
||||
rng = np.random.default_rng(seed=hash((record.episode_index, 42)) & 0xFFFFFFFF)
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for t in range(num_frames):
|
||||
if rng.random() < self.config.dropout_rate:
|
||||
continue
|
||||
|
||||
if (
|
||||
self.config.force_positive_on_intervention
|
||||
and intervention_mask[t]
|
||||
or advantages[t] > threshold
|
||||
):
|
||||
indicator = "positive"
|
||||
else:
|
||||
indicator = "negative"
|
||||
|
||||
timestamp = float(record.frame_timestamps[t]) if t < len(record.frame_timestamps) else 0.0
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": indicator,
|
||||
"style": "advantage",
|
||||
"timestamp": timestamp,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
|
||||
staging.write("advantage", rows)
|
||||
logger.debug(
|
||||
"Episode %d: %d/%d frames scored (threshold=%.4f, %d positive, %d negative)",
|
||||
record.episode_index,
|
||||
len(rows),
|
||||
num_frames,
|
||||
threshold,
|
||||
sum(1 for r in rows if r["content"] == "positive"),
|
||||
sum(1 for r in rows if r["content"] == "negative"),
|
||||
)
|
||||
|
||||
def _compute_threshold(self, advantages: np.ndarray, intervention_mask: np.ndarray) -> float:
|
||||
"""Compute the binarization threshold as the configured percentile of advantages."""
|
||||
non_intervention = advantages[~intervention_mask] if intervention_mask.any() else advantages
|
||||
if len(non_intervention) == 0:
|
||||
return 0.0
|
||||
return float(np.percentile(non_intervention, self.config.threshold_percentile * 100))
|
||||
@@ -39,6 +39,7 @@ _MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
"advantage",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ DEFAULT_BINDINGS = {
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
"advantage": "active_at(t, style=advantage)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
# RECAP advantage-conditioned recipe.
|
||||
#
|
||||
# Composes task + advantage indicator into the prompt for conditional SFT.
|
||||
# The advantage binding resolves to "positive" or "negative" from the
|
||||
# language_persistent column (written by lerobot-annotate --advantage).
|
||||
# When advantage is absent (30% dropout), the advantage turn is skipped
|
||||
# entirely via if_present, training the unconditional branch for CFG.
|
||||
#
|
||||
# This recipe is policy-agnostic: any VLA that consumes chat-style messages
|
||||
# can use it. Override bindings or add blend components for task-specific needs.
|
||||
#
|
||||
# Paper: pi*0.6, Section IV-B (conditional policy training with I_t).
|
||||
|
||||
bindings:
|
||||
advantage: "active_at(t, style=advantage)"
|
||||
|
||||
messages:
|
||||
- role: user
|
||||
content: "${task}"
|
||||
stream: high_level
|
||||
|
||||
- role: user
|
||||
content: "Advantage: ${advantage}"
|
||||
stream: high_level
|
||||
if_present: advantage
|
||||
|
||||
- role: assistant
|
||||
content: "${subtask}"
|
||||
stream: low_level
|
||||
target: true
|
||||
@@ -0,0 +1,41 @@
|
||||
# RECAP full recipe with advantage conditioning and subtask blending.
|
||||
#
|
||||
# Blend of two training modes:
|
||||
# 1. advantage_conditioned (70%): Task + advantage indicator → action
|
||||
# 2. unconditional (30%): Task only → action (no advantage, trains CFG baseline)
|
||||
#
|
||||
# This achieves the same effect as per-frame dropout in the annotation module
|
||||
# but at the recipe level, giving explicit control over the conditioning ratio.
|
||||
# Use this instead of annotation-level dropout if you want a fixed split.
|
||||
#
|
||||
# Paper: pi*0.6, Appendix E (classifier-free guidance requires both branches).
|
||||
|
||||
blend:
|
||||
advantage_conditioned:
|
||||
weight: 0.7
|
||||
messages:
|
||||
- role: user
|
||||
content: "${task}\nAdvantage: ${advantage}"
|
||||
stream: high_level
|
||||
if_present: advantage
|
||||
|
||||
- role: user
|
||||
content: "${task}"
|
||||
stream: high_level
|
||||
|
||||
- role: assistant
|
||||
content: "${subtask}"
|
||||
stream: low_level
|
||||
target: true
|
||||
|
||||
unconditional:
|
||||
weight: 0.3
|
||||
messages:
|
||||
- role: user
|
||||
content: "${task}"
|
||||
stream: high_level
|
||||
|
||||
- role: assistant
|
||||
content: "${subtask}"
|
||||
stream: low_level
|
||||
target: true
|
||||
@@ -43,10 +43,10 @@ CORE_STYLES = {
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
EXTENDED_STYLES: set[str] = {"advantage"}
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug", "advantage"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
|
||||
@@ -87,6 +87,17 @@ class PI05Config(PreTrainedConfig):
|
||||
freeze_vision_encoder: bool = False # Freeze only the vision encoder
|
||||
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
|
||||
|
||||
# Language conditioning (e.g. RECAP advantage). When set, RenderMessagesStep
|
||||
# is inserted into the preprocessor to resolve language_persistent rows via
|
||||
# the recipe YAML before prompt construction.
|
||||
recipe_path: str | None = None
|
||||
|
||||
# Classifier-Free Guidance (CFG) scale for inference (Eq. 13 in RECAP paper).
|
||||
# 1.0 = no guidance (default). >1.0 enables dual-path denoising where:
|
||||
# v = v_uncond + cfg_beta * (v_cond - v_uncond)
|
||||
# VLM runs twice (cond + uncond prompts), action expert runs 2x per step.
|
||||
cfg_beta: float = 1.0
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
|
||||
@@ -52,6 +52,8 @@ from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_UNCOND_TOKENS,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -148,6 +150,20 @@ def clone_past_key_values(past_key_values):
|
||||
)
|
||||
|
||||
|
||||
def cat_past_key_values(kv_a, kv_b):
|
||||
"""Concatenate two DynamicCaches along the batch dimension for batched CFG."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(
|
||||
torch.cat([ka, kb], dim=0),
|
||||
torch.cat([va, vb], dim=0),
|
||||
sw_a,
|
||||
)
|
||||
for (ka, va, sw_a), (kb, vb, _sw_b) in zip(kv_a, kv_b, strict=True)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -797,9 +813,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
uncond_tokens=None,
|
||||
uncond_masks=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
"""Do a full inference forward and compute the action.
|
||||
|
||||
When cfg_beta > 1.0 and uncond_tokens/uncond_masks are provided, performs
|
||||
Classifier-Free Guidance: VLM runs twice (conditioned + unconditional), action
|
||||
expert runs twice per denoising step, and velocities are interpolated via
|
||||
v = v_uncond + cfg_beta * (v_cond - v_uncond).
|
||||
"""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
|
||||
@@ -815,6 +839,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
) # Use config max_action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
cfg_enabled = self.config.cfg_beta > 1.0 and uncond_tokens is not None and uncond_masks is not None
|
||||
|
||||
# Prefill VLM for conditioned prompt
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
@@ -830,6 +857,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Prefill VLM for unconditional prompt (CFG)
|
||||
if cfg_enabled:
|
||||
uncond_prefix_embs, uncond_prefix_pad_masks, uncond_prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, uncond_tokens, uncond_masks
|
||||
)
|
||||
uncond_prefix_att_2d_masks = make_att_2d_masks(uncond_prefix_pad_masks, uncond_prefix_att_masks)
|
||||
uncond_prefix_position_ids = torch.cumsum(uncond_prefix_pad_masks, dim=1) - 1
|
||||
uncond_prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(uncond_prefix_att_2d_masks)
|
||||
|
||||
_, uncond_past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=uncond_prefix_att_2d_masks_4d,
|
||||
position_ids=uncond_prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[uncond_prefix_embs, None],
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
x_t = noise
|
||||
@@ -838,6 +882,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
if cfg_enabled:
|
||||
return self.denoise_step_cfg_batched(
|
||||
cond_prefix_pad_masks=prefix_pad_masks,
|
||||
cond_past_key_values=past_key_values,
|
||||
uncond_prefix_pad_masks=uncond_prefix_pad_masks,
|
||||
uncond_past_key_values=uncond_past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
@@ -907,6 +960,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
def denoise_step_cfg_batched(
|
||||
self,
|
||||
cond_prefix_pad_masks,
|
||||
cond_past_key_values,
|
||||
uncond_prefix_pad_masks,
|
||||
uncond_past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Batched CFG denoising: runs cond + uncond in a single forward pass.
|
||||
|
||||
Concatenates cond and uncond inputs along the batch dimension, runs one
|
||||
action expert forward (2x batch), then splits and applies CFG interpolation.
|
||||
This is ~1.5x faster than two sequential denoise_step calls due to better
|
||||
GPU utilization (inspired by Qwen2.5-Omni DiT / diffusers batched CFG).
|
||||
"""
|
||||
# Embed suffix once (same x_t and timestep for both branches)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
|
||||
|
||||
bsize = cond_prefix_pad_masks.shape[0]
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
cond_prefix_len = cond_prefix_pad_masks.shape[1]
|
||||
uncond_prefix_len = uncond_prefix_pad_masks.shape[1]
|
||||
|
||||
# Build attention masks for cond branch
|
||||
cond_prefix_2d = cond_prefix_pad_masks[:, None, :].expand(bsize, suffix_len, cond_prefix_len)
|
||||
cond_suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
cond_full_att = torch.cat([cond_prefix_2d, cond_suffix_att_2d], dim=2)
|
||||
cond_prefix_offsets = torch.sum(cond_prefix_pad_masks, dim=-1)[:, None]
|
||||
cond_position_ids = cond_prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
# Build attention masks for uncond branch
|
||||
uncond_prefix_2d = uncond_prefix_pad_masks[:, None, :].expand(bsize, suffix_len, uncond_prefix_len)
|
||||
uncond_suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
uncond_full_att = torch.cat([uncond_prefix_2d, uncond_suffix_att_2d], dim=2)
|
||||
uncond_prefix_offsets = torch.sum(uncond_prefix_pad_masks, dim=-1)[:, None]
|
||||
uncond_position_ids = uncond_prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
# Concatenate on batch dim: [cond_batch; uncond_batch]
|
||||
batched_full_att = torch.cat([cond_full_att, uncond_full_att], dim=0)
|
||||
batched_full_att_4d = self._prepare_attention_masks_4d(batched_full_att)
|
||||
batched_position_ids = torch.cat([cond_position_ids, uncond_position_ids], dim=0)
|
||||
batched_suffix_embs = torch.cat([suffix_embs, suffix_embs], dim=0)
|
||||
batched_adarms_cond = torch.cat([adarms_cond, adarms_cond], dim=0)
|
||||
|
||||
# Concatenate KV caches on batch dim
|
||||
batched_past_kv = cat_past_key_values(
|
||||
clone_past_key_values(cond_past_key_values),
|
||||
clone_past_key_values(uncond_past_key_values),
|
||||
)
|
||||
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
# Single forward pass for both branches
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=batched_full_att_4d,
|
||||
position_ids=batched_position_ids,
|
||||
past_key_values=batched_past_kv,
|
||||
inputs_embeds=[None, batched_suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, batched_adarms_cond],
|
||||
)
|
||||
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_all = self.action_out_proj(suffix_out)
|
||||
|
||||
# Split: first half = cond, second half = uncond
|
||||
v_cond, v_uncond = v_all.chunk(2, dim=0)
|
||||
|
||||
# CFG interpolation: v = v_uncond + beta * (v_cond - v_uncond)
|
||||
return v_uncond + self.config.cfg_beta * (v_cond - v_uncond)
|
||||
|
||||
|
||||
class PI05Policy(PreTrainedPolicy):
|
||||
"""PI05 Policy for LeRobot."""
|
||||
@@ -1243,8 +1370,20 @@ class PI05Policy(PreTrainedPolicy):
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# CFG: pass unconditional tokens if available
|
||||
uncond_tokens = batch.get(f"{OBS_LANGUAGE_UNCOND_TOKENS}")
|
||||
uncond_masks = batch.get(f"{OBS_LANGUAGE_UNCOND_ATTENTION_MASK}")
|
||||
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||
actions = self.model.sample_actions(
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
uncond_tokens=uncond_tokens,
|
||||
uncond_masks=uncond_masks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -40,6 +40,8 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_UNCOND_TOKENS,
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
@@ -57,6 +59,7 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
cfg_enabled: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
@@ -84,8 +87,25 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
|
||||
# Build unconditional prompts for CFG (same state but original task without advantage)
|
||||
if self.cfg_enabled:
|
||||
base_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("base_task")
|
||||
if base_tasks is None:
|
||||
base_tasks = tasks
|
||||
|
||||
if isinstance(base_tasks, str):
|
||||
base_tasks = [base_tasks] * len(tasks)
|
||||
|
||||
uncond_prompts = []
|
||||
for i, base_task in enumerate(base_tasks):
|
||||
cleaned_text = base_task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
uncond_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
uncond_prompts.append(uncond_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA]["uncond_task"] = uncond_prompts
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
@@ -111,9 +131,10 @@ def make_pi05_pre_post_processors(
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
4. (Optional) Rendering language annotations via recipe YAML.
|
||||
5. (Optional) Flattening rendered messages into the task string.
|
||||
6. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
7. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
@@ -122,8 +143,6 @@ def make_pi05_pre_post_processors(
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
@@ -147,16 +166,51 @@ def make_pi05_pre_post_processors(
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
# Insert language rendering steps when a recipe is configured (e.g. RECAP advantage)
|
||||
if config.recipe_path is not None:
|
||||
from lerobot.configs.recipe import load_recipe
|
||||
from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep
|
||||
|
||||
recipe = load_recipe(config.recipe_path)
|
||||
input_steps.append(RenderMessagesStep(recipe=recipe))
|
||||
input_steps.append(RenderedMessagesToTaskStep())
|
||||
|
||||
cfg_enabled = config.cfg_beta > 1.0
|
||||
|
||||
input_steps.extend(
|
||||
[
|
||||
Pi05PrepareStateTokenizerProcessorStep(
|
||||
max_state_dim=config.max_state_dim,
|
||||
cfg_enabled=cfg_enabled,
|
||||
),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Add unconditional prompt tokenizer for CFG inference
|
||||
if cfg_enabled:
|
||||
input_steps.append(
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
task_key="uncond_task",
|
||||
output_tokens_key=OBS_LANGUAGE_UNCOND_TOKENS,
|
||||
output_mask_key=OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
|
||||
)
|
||||
)
|
||||
|
||||
input_steps.append(DeviceProcessorStep(device=config.device))
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Adapter step that flattens rendered chat messages back into a task string.
|
||||
|
||||
Bridges RenderMessagesStep (which outputs structured messages) to policies
|
||||
that expect a plain task string in complementary_data["task"] (e.g. PI05).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .pipeline import ComplementaryDataProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="rendered_messages_to_task")
|
||||
class RenderedMessagesToTaskStep(ComplementaryDataProcessorStep):
|
||||
"""Extract user-role message content from rendered messages into the task string.
|
||||
|
||||
After RenderMessagesStep renders a recipe into structured messages, this
|
||||
step extracts content from all user-role messages, joins them, and writes
|
||||
the result to complementary_data["task"]. This allows downstream steps
|
||||
(like Pi05PrepareStateTokenizerProcessorStep) to consume the
|
||||
advantage-conditioned prompt without modification.
|
||||
|
||||
No-ops when the "messages" key is absent (backward compatible with
|
||||
pipelines that don't use language annotations).
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
messages = complementary_data.get("messages")
|
||||
if messages is None:
|
||||
return complementary_data
|
||||
|
||||
user_parts = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str) and content:
|
||||
user_parts.append(content)
|
||||
elif isinstance(content, list):
|
||||
# HF multimodal blocks: extract text blocks
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
if text:
|
||||
user_parts.append(text)
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
if user_parts:
|
||||
task = complementary_data.get("task")
|
||||
# Preserve the original task for CFG unconditional prompt
|
||||
new_complementary_data["base_task"] = task
|
||||
# Wrap in list if the original task was a list (batched)
|
||||
joined = "\n".join(user_parts)
|
||||
if isinstance(task, list):
|
||||
new_complementary_data["task"] = [joined] * len(task)
|
||||
else:
|
||||
new_complementary_data["task"] = joined
|
||||
|
||||
# Remove consumed rendering outputs
|
||||
new_complementary_data.pop("messages", None)
|
||||
new_complementary_data.pop("message_streams", None)
|
||||
new_complementary_data.pop("target_message_indices", None)
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
@@ -81,6 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
output_tokens_key: str = OBS_LANGUAGE_TOKENS
|
||||
output_mask_key: str = OBS_LANGUAGE_ATTENTION_MASK
|
||||
|
||||
# Internal tokenizer instance (not part of the config)
|
||||
input_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
@@ -201,8 +203,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Add tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
new_observation[self.output_tokens_key] = tokenized_prompt["input_ids"]
|
||||
new_observation[self.output_mask_key] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
@@ -309,14 +311,14 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Add a feature for the token IDs if it doesn't already exist
|
||||
if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature(
|
||||
if self.output_tokens_key not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][self.output_tokens_key] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add a feature for the attention mask if it doesn't already exist
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
if self.output_mask_key not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][self.output_mask_key] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig as DistributionalVFConfig,
|
||||
)
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
@@ -26,6 +29,7 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"DistributionalVFConfig",
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .modeling_distributional_value_function import DistributionalVFRewardModel
|
||||
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"DistributionalVFConfig",
|
||||
"DistributionalVFRewardModel",
|
||||
"make_distributional_vf_pre_post_processors",
|
||||
]
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
Weights initialized from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("distributional_value_function")
|
||||
@dataclass
|
||||
class DistributionalVFConfig(RewardModelConfig):
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
|
||||
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
|
||||
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
|
||||
(Eq. 1 in the paper).
|
||||
|
||||
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
|
||||
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
|
||||
about 670M params and initialize from the PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
# Backbone
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
num_hidden_layers: int = 6
|
||||
num_vision_layers: int = 13
|
||||
|
||||
# Distributional head
|
||||
num_value_bins: int = 201
|
||||
value_support_min: float = -1.0
|
||||
value_support_max: float = 0.0
|
||||
hl_gauss_sigma_ratio: float = 5.0
|
||||
|
||||
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
|
||||
target_method: str = "hl_gauss"
|
||||
|
||||
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
|
||||
# When False, terminal states use the same target method as non-terminal states.
|
||||
use_one_hot_terminal: bool = True
|
||||
|
||||
# Image
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 64
|
||||
|
||||
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
|
||||
# Pass a PI05 checkpoint path or Hub repo_id here.
|
||||
# After training, load the value function with RewardModel.from_pretrained() instead.
|
||||
init_from_actor_path: str = ""
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=3e-4,
|
||||
weight_decay=1e-4,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.input_features:
|
||||
return
|
||||
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
|
||||
if not has_image:
|
||||
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
|
||||
+567
@@ -0,0 +1,567 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Modeling for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Inputs: single image observation + task text prompt ("Task: {task}.")
|
||||
Outputs: softmax distribution over value bins; expected value E[V] for inference.
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
|
||||
Weight initialization: vision tower, multi-modal projector, token embeddings, and
|
||||
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaRMSNorm,
|
||||
_gated_residual,
|
||||
_get_pi_gemma_decoder_layer_base,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
PiGemmaRMSNorm = None
|
||||
_gated_residual = None
|
||||
_get_pi_gemma_decoder_layer_base = None
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257152
|
||||
|
||||
|
||||
class DistributionalVFRewardModel(PreTrainedRewardModel):
|
||||
"""Distributional value function model for RECAP.
|
||||
|
||||
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
|
||||
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
|
||||
per-task normalized Monte Carlo returns.
|
||||
|
||||
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
|
||||
causal attention, [CLS] token, and Linear(D, num_bins) value head.
|
||||
The expected value is E[V] = sum(softmax(logits) * bin_centers).
|
||||
"""
|
||||
|
||||
name = "distributional_value_function"
|
||||
config_class = DistributionalVFConfig
|
||||
|
||||
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
|
||||
require_package("transformers", extra="recap")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
|
||||
|
||||
# Get base dimensions from the paligemma variant (OpenPI config format)
|
||||
base_config = get_gemma_config(config.paligemma_variant)
|
||||
hidden_dim = base_config.width
|
||||
mlp_dim = base_config.mlp_dim
|
||||
num_layers = config.num_hidden_layers
|
||||
|
||||
# HuggingFace GemmaConfig for transformer layers
|
||||
gemma_config = CONFIG_MAPPING["gemma"](
|
||||
head_dim=base_config.head_dim,
|
||||
hidden_size=hidden_dim,
|
||||
intermediate_size=mlp_dim,
|
||||
num_attention_heads=base_config.num_heads,
|
||||
num_hidden_layers=num_layers,
|
||||
num_key_value_heads=base_config.num_kv_heads,
|
||||
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
)
|
||||
self.gemma_config = gemma_config
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_value_bins = config.num_value_bins
|
||||
|
||||
# Single learned [CLS] token for value prediction
|
||||
self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
|
||||
|
||||
# Value projection head: Linear(hidden_dim, num_bins)
|
||||
self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
|
||||
|
||||
# Transformer layers (overwritten by _initialize_from_actor on first run)
|
||||
self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
|
||||
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||
self.layers = nn.ModuleList(
|
||||
[pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
|
||||
)
|
||||
self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
|
||||
|
||||
# Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
|
||||
# PaliGemmaConfig wraps both vision and text configs into a single model
|
||||
paligemma_config = CONFIG_MAPPING["paligemma"]()
|
||||
paligemma_config.text_config = gemma_config
|
||||
paligemma_config.vision_config.image_size = config.image_resolution[0]
|
||||
paligemma_config.vision_config.intermediate_size = 4304
|
||||
paligemma_config.vision_config.projection_dim = 2048
|
||||
paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
|
||||
|
||||
paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
|
||||
self.vision_tower = paligemma_full.model.vision_tower
|
||||
self.multi_modal_projector = paligemma_full.model.multi_modal_projector
|
||||
self.token_embedding = paligemma_full.model.language_model.embed_tokens
|
||||
del paligemma_full
|
||||
|
||||
# Truncate vision tower to num_vision_layers
|
||||
if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
|
||||
vision_encoder = self.vision_tower.vision_model.encoder
|
||||
vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
|
||||
|
||||
# Bin support: evenly spaced centers from value_support_min to value_support_max
|
||||
bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
|
||||
self.register_buffer("bin_centers", bin_centers, persistent=False)
|
||||
bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
|
||||
self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
|
||||
|
||||
# Overwrite with pre-trained PI05 actor weights (first training run only)
|
||||
if config.init_from_actor_path:
|
||||
self._initialize_from_actor()
|
||||
|
||||
def _initialize_from_actor(self) -> None:
|
||||
"""Overwrite weights from a pre-trained PI05 actor checkpoint.
|
||||
|
||||
Called on first training run only (when init_from_actor_path is set).
|
||||
"""
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
|
||||
actor_model = actor_policy.model
|
||||
|
||||
paligemma_model = actor_model.paligemma_with_expert.paligemma
|
||||
source_language_model = paligemma_model.model.language_model
|
||||
|
||||
# Transformer components
|
||||
self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
|
||||
num_layers = self.gemma_config.num_hidden_layers
|
||||
for i in range(num_layers):
|
||||
self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
|
||||
self.norm.load_state_dict(source_language_model.norm.state_dict())
|
||||
|
||||
# Vision tower (truncate source first, then copy)
|
||||
source_vision_tower = paligemma_model.model.vision_tower
|
||||
if hasattr(source_vision_tower, "vision_model") and hasattr(
|
||||
source_vision_tower.vision_model, "encoder"
|
||||
):
|
||||
source_encoder = source_vision_tower.vision_model.encoder
|
||||
source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
|
||||
self.vision_tower.load_state_dict(source_vision_tower.state_dict())
|
||||
|
||||
# Multi-modal projector
|
||||
self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())
|
||||
|
||||
# Token embedding table
|
||||
self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())
|
||||
|
||||
del actor_policy
|
||||
|
||||
def embed_image(self, image: Tensor) -> Tensor:
|
||||
"""Embed images using the value function's SigLIP vision tower.
|
||||
|
||||
Args:
|
||||
image: [batch_size, channels, height, width] preprocessed images in [-1, 1].
|
||||
|
||||
Returns:
|
||||
[batch_size, num_patches, hidden_dim] projected image features.
|
||||
"""
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
image_outputs = self.vision_tower(image, return_dict=True)
|
||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||
image_features = image_features / (self.hidden_dim**0.5)
|
||||
|
||||
if image_features.dtype != out_dtype:
|
||||
image_features = image_features.to(out_dtype)
|
||||
return image_features
|
||||
|
||||
def embed_text(self, token_ids: Tensor) -> Tensor:
|
||||
"""Embed text token IDs using the value function's token embedding table.
|
||||
|
||||
Args:
|
||||
token_ids: [batch_size, seq_len] integer token IDs
|
||||
|
||||
Returns:
|
||||
[batch_size, seq_len, hidden_dim] text embeddings
|
||||
"""
|
||||
return self.token_embedding(token_ids)
|
||||
|
||||
def _get_cls_embedding(self, batch_size: int) -> Tensor:
|
||||
"""Get [CLS] token embedding expanded to batch size.
|
||||
|
||||
Args:
|
||||
batch_size: number of samples in the batch.
|
||||
|
||||
Returns:
|
||||
[batch_size, 1, hidden_dim] learned [CLS] embedding.
|
||||
"""
|
||||
return self.cls_embedding.expand(batch_size, -1, -1)
|
||||
|
||||
def forward_value(
|
||||
self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
|
||||
) -> dict[str, Tensor]:
|
||||
"""Core forward pass through the distributional value function.
|
||||
|
||||
Args:
|
||||
vision_features: [batch_size, num_patches, hidden_dim]
|
||||
text_embeddings: [batch_size, seq_len, hidden_dim]
|
||||
text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
|
||||
|
||||
Returns:
|
||||
logits: [batch_size, num_value_bins]
|
||||
probs: [batch_size, num_value_bins]
|
||||
value: [batch_size, 1]
|
||||
"""
|
||||
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE
|
||||
|
||||
batch_size = text_embeddings.shape[0]
|
||||
device = text_embeddings.device
|
||||
|
||||
# Build sequence: [vision, text, CLS]
|
||||
cls_embedding = self._get_cls_embedding(batch_size)
|
||||
hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)
|
||||
|
||||
# Build causal attention mask
|
||||
vision_len = vision_features.shape[1]
|
||||
vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
|
||||
cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
|
||||
full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
|
||||
|
||||
full_seq_len = full_padding_mask.shape[1]
|
||||
|
||||
# Causal mask
|
||||
causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
|
||||
# Combine causal mask with padding mask
|
||||
padding_mask_4d = full_padding_mask[:, None, None, :].expand(
|
||||
batch_size, 1, full_seq_len, full_seq_len
|
||||
)
|
||||
attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
|
||||
attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
|
||||
position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
|
||||
cos, sin = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for layer in self.layers:
|
||||
norm_output = layer.input_layernorm(hidden_states, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
hidden_states_normed, gate = norm_output
|
||||
else:
|
||||
hidden_states_normed, gate = norm_output, None
|
||||
|
||||
input_shape = hidden_states_normed.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
attention_output, _ = modeling_gemma.eager_attention_forward(
|
||||
layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
layer.self_attn.scaling,
|
||||
)
|
||||
|
||||
attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
|
||||
if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
projected_attention = layer.self_attn.o_proj(attention_output)
|
||||
|
||||
if gate is not None:
|
||||
projected_attention = _gated_residual(hidden_states, projected_attention, gate)
|
||||
else:
|
||||
projected_attention = hidden_states + projected_attention
|
||||
|
||||
after_attention_residual = projected_attention.clone()
|
||||
|
||||
norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
mlp_input, gate = norm_output
|
||||
else:
|
||||
mlp_input, gate = norm_output, None
|
||||
|
||||
mlp_output = layer.mlp(mlp_input)
|
||||
|
||||
if gate is not None:
|
||||
hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
|
||||
else:
|
||||
hidden_states = after_attention_residual + mlp_output
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
# Extract [CLS] token (last position in the sequence)
|
||||
cls_hidden_state = hidden_states[:, -1, :] # [batch_size, hidden_dim]
|
||||
|
||||
# Value head: Linear(hidden_dim, num_bins) -> logits
|
||||
value_logits = self.value_head(cls_hidden_state) # [batch_size, num_value_bins]
|
||||
value_probs = F.softmax(value_logits, dim=-1)
|
||||
predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
|
||||
|
||||
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
|
||||
"""HL-Gauss soft target distribution.
|
||||
|
||||
Places a Gaussian N(target, sigma^2) over the bin support and computes
|
||||
per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.
|
||||
|
||||
Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
|
||||
Functions via Classification for Scalable Deep RL", Section 3.1.
|
||||
arXiv:2403.03950
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
# Bin edges: half a bin-width outside the first/last center
|
||||
bin_width = (self.config.value_support_max - self.config.value_support_min) / (
|
||||
self.num_value_bins - 1
|
||||
)
|
||||
support_edges = torch.linspace(
|
||||
self.config.value_support_min - bin_width / 2,
|
||||
self.config.value_support_max + bin_width / 2,
|
||||
self.num_value_bins + 1,
|
||||
device=target_value.device,
|
||||
dtype=target_value.dtype,
|
||||
)
|
||||
|
||||
# CDF of N(target, sigma^2) evaluated at each edge
|
||||
cdf_at_edges = 0.5 * (
|
||||
1.0
|
||||
+ torch.erf(
|
||||
(support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
|
||||
/ (self.hl_gauss_sigma * math.sqrt(2))
|
||||
)
|
||||
) # [batch_size, num_bins + 1]
|
||||
|
||||
# Normalize: z = cdf(max_edge) - cdf(min_edge)
|
||||
normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)
|
||||
|
||||
# Bin probabilities = differences of consecutive CDF values, normalized
|
||||
bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant
|
||||
|
||||
return bin_probabilities
|
||||
|
||||
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
|
||||
"""Dirac delta (C51) projection: split probability between two nearest bins.
|
||||
|
||||
Standard distributional RL projection from Bellemare et al. 2017.
|
||||
"A Distributional Perspective on Reinforcement Learning"
|
||||
arXiv:1707.06887
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
bin_width = self.bin_centers[1] - self.bin_centers[0]
|
||||
normalized_position = (target_value - self.config.value_support_min) / bin_width
|
||||
lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
|
||||
upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)
|
||||
|
||||
weight_upper = normalized_position - lower_bin_idx.float()
|
||||
weight_lower = upper_bin_idx.float() - normalized_position
|
||||
|
||||
same_bin = lower_bin_idx == upper_bin_idx
|
||||
weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
|
||||
weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)
|
||||
|
||||
batch_size = target_value.shape[0]
|
||||
target_distribution = torch.zeros(batch_size, self.num_value_bins, device=target_value.device)
|
||||
batch_indices = torch.arange(batch_size, device=target_value.device)
|
||||
target_distribution[batch_indices, lower_bin_idx] += weight_lower
|
||||
target_distribution[batch_indices, upper_bin_idx] += weight_upper
|
||||
|
||||
return target_distribution
|
||||
|
||||
def one_hot_target(self, target_value: Tensor) -> Tensor:
|
||||
"""One-hot target for terminal states (exact return, no smoothing).
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] one-hot distribution at the nearest bin.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
nearest_bin_idx = torch.argmin(
|
||||
torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
|
||||
)
|
||||
return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)
|
||||
|
||||
def compute_target_distribution(
|
||||
self,
|
||||
target_value: Tensor,
|
||||
is_terminal: Tensor,
|
||||
method: str = "hl_gauss",
|
||||
use_one_hot_terminal: bool = True,
|
||||
) -> Tensor:
|
||||
"""Compute target distribution using configured method.
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] scalar return targets
|
||||
is_terminal: [batch_size] boolean terminal flags
|
||||
method: "hl_gauss" or "dirac_delta"
|
||||
use_one_hot_terminal: if True, terminal states get one-hot targets
|
||||
(exact return, no smoothing). If False, all states use the same method.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution
|
||||
"""
|
||||
if method == "hl_gauss":
|
||||
base_distribution = self.hl_gauss_target(target_value)
|
||||
elif method == "dirac_delta":
|
||||
base_distribution = self.dirac_delta_target(target_value)
|
||||
else:
|
||||
raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")
|
||||
|
||||
if not use_one_hot_terminal:
|
||||
return base_distribution
|
||||
|
||||
terminal_distribution = self.one_hot_target(target_value)
|
||||
|
||||
return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Training forward pass — computes cross-entropy loss against MC return targets.
|
||||
|
||||
The batch is expected to be preprocessed by the processor pipeline.
|
||||
Keys expected in batch:
|
||||
- observation.images.*: [B, C, H, W] preprocessed images
|
||||
- observation.language_tokens: [B, seq_len] tokenized task prompt
|
||||
- observation.language_attention_mask: [B, seq_len] padding mask
|
||||
- mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
|
||||
- is_terminal: [B] boolean terminal flags
|
||||
|
||||
Returns:
|
||||
(loss, output_dict) where loss is scalar cross-entropy
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
# Get first image key from batch
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
mc_return = batch["mc_return"]
|
||||
is_terminal = batch["is_terminal"]
|
||||
|
||||
# Embed observations
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
# Forward through value function transformer
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
value_logits = vf_output["logits"]
|
||||
predicted_value = vf_output["value"]
|
||||
|
||||
# Compute target distribution
|
||||
target_distribution = self.compute_target_distribution(
|
||||
mc_return,
|
||||
is_terminal,
|
||||
method=self.config.target_method,
|
||||
use_one_hot_terminal=self.config.use_one_hot_terminal,
|
||||
)
|
||||
|
||||
# Cross-entropy loss (Eq. 1 in pi*0.6 paper)
|
||||
log_probs = F.log_softmax(value_logits, dim=-1)
|
||||
loss = -(target_distribution * log_probs).sum(dim=-1).mean()
|
||||
|
||||
output_dict = {
|
||||
"loss": loss.item(),
|
||||
"predicted_value_mean": predicted_value.mean().item(),
|
||||
"mc_return_mean": mc_return.mean().item(),
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute V(s) for a batch of observations. Used for advantage scoring.
|
||||
|
||||
Args:
|
||||
batch: preprocessed batch with images and tokenized text
|
||||
|
||||
Returns:
|
||||
[batch_size] tensor of predicted values V(s)
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
return vf_output["value"].squeeze(-1) # [batch_size]
|
||||
+235
@@ -0,0 +1,235 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Processor for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
|
||||
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
|
||||
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
|
||||
|
||||
Training targets (mc_return, is_terminal) are NOT routed through the processor.
|
||||
They are dataset columns read directly from the batch in the model's forward().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
|
||||
@dataclass
|
||||
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
|
||||
"""Format the task string for the distributional value function.
|
||||
|
||||
The value function receives only visual observations and task text.
|
||||
Builds prompt: "Task: {task}."
|
||||
"""
|
||||
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
tasks = complementary_data.get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
full_prompts = []
|
||||
for task in tasks:
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
full_prompts.append(f"Task: {cleaned_text}.")
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[self.task_key] = full_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"task_key": self.task_key}
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
|
||||
@dataclass
|
||||
class DistributionalVFImagePreprocessorStep(ProcessorStep):
|
||||
"""Resize and normalize images for the value function's SigLIP vision tower.
|
||||
|
||||
Expects float images in [0, 1].
|
||||
- Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
|
||||
- Scale to [-1, 1] for SigLIP
|
||||
"""
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
image_keys: tuple[str, ...] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")
|
||||
|
||||
image_keys = self.image_keys or tuple(
|
||||
key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
|
||||
)
|
||||
if not image_keys:
|
||||
raise KeyError(
|
||||
f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
|
||||
)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for image_key in image_keys:
|
||||
image = new_observation[image_key]
|
||||
if not isinstance(image, Tensor):
|
||||
image = to_tensor(image)
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
is_channels_first = image.ndim == 4 and image.shape[1] == 3
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != self.image_resolution:
|
||||
image = resize_with_pad_torch(image, *self.image_resolution)
|
||||
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2)
|
||||
|
||||
new_observation[image_key] = image
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"image_resolution": self.image_resolution,
|
||||
"image_keys": list(self.image_keys) if self.image_keys is not None else None,
|
||||
}
|
||||
|
||||
|
||||
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
|
||||
return tuple(
|
||||
feature_name
|
||||
for feature_name, feature in config.input_features.items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
)
|
||||
|
||||
|
||||
def make_distributional_vf_pre_post_processors(
|
||||
config: DistributionalVFConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create pre/post processors for the distributional value function.
|
||||
|
||||
Preprocessor steps:
|
||||
1. Rename observations (no-op by default)
|
||||
2. Add a batch dimension
|
||||
3. Normalize features (images use identity, so they stay in [0, 1])
|
||||
4. Format task prompt: "Task: {task}."
|
||||
5. Tokenize with the PaliGemma tokenizer
|
||||
6. Resize-with-pad and scale images to [-1, 1] for SigLIP
|
||||
7. Move tensors to the configured device
|
||||
|
||||
Training targets (mc_return, is_terminal) are not processed here.
|
||||
The model reads them directly from the batch in forward().
|
||||
|
||||
The postprocessor is a no-op because the value function does not need
|
||||
action postprocessing.
|
||||
"""
|
||||
image_keys = _visual_image_keys(config)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DistributionalVFPrepareTaskPromptStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DistributionalVFImagePreprocessorStep(
|
||||
image_resolution=config.image_resolution,
|
||||
image_keys=image_keys or None,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -24,6 +24,7 @@ from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
@@ -63,6 +64,12 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
elif name == "distributional_value_function":
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -96,6 +103,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RobometerConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
elif reward_type == "distributional_value_function":
|
||||
return DistributionalVFConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -191,6 +200,16 @@ def make_reward_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, DistributionalVFConfig):
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
|
||||
return make_distributional_vf_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
|
||||
@@ -106,6 +106,8 @@ class DAggerKeyboardConfig:
|
||||
pause_resume: str = "space"
|
||||
correction: str = "tab"
|
||||
upload: str = "enter"
|
||||
success: str = "s"
|
||||
failure: str = "f"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -119,6 +121,8 @@ class DAggerPedalConfig:
|
||||
pause_resume: str = "KEY_A"
|
||||
correction: str = "KEY_B"
|
||||
upload: str = "KEY_C"
|
||||
success: str = "KEY_D"
|
||||
failure: str = "KEY_E"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@@ -165,6 +169,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
2. **correction** — toggle human correction recording.
|
||||
3. **upload** — push dataset to hub on demand (corrections-only mode).
|
||||
|
||||
Episode success labeling:
|
||||
4. **success** — mark current episode as successful.
|
||||
5. **failure** — mark current episode as failed.
|
||||
|
||||
When ``record_autonomous=False`` (default) only human-correction windows
|
||||
are recorded — each correction becomes its own episode. Set to ``True``
|
||||
to record both autonomous and correction frames with size-based episode
|
||||
|
||||
@@ -347,6 +347,11 @@ def build_rollout_context(
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
dataset_features["next.success"] = {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
|
||||
if not repo_name.startswith("rollout_"):
|
||||
|
||||
@@ -129,6 +129,9 @@ class DAggerEvents:
|
||||
self.stop_recording = Event()
|
||||
self.upload_requested = Event()
|
||||
|
||||
# Episode success labeling
|
||||
self._episode_success: bool | None = None
|
||||
|
||||
# -- Thread-safe phase access ------------------------------------------
|
||||
|
||||
@property
|
||||
@@ -171,8 +174,26 @@ class DAggerEvents:
|
||||
with self._lock:
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition = None
|
||||
self._episode_success = None
|
||||
self.upload_requested.clear()
|
||||
|
||||
def mark_success(self) -> None:
|
||||
"""Mark the current episode as successful (called from input threads)."""
|
||||
with self._lock:
|
||||
self._episode_success = True
|
||||
|
||||
def mark_failure(self) -> None:
|
||||
"""Mark the current episode as failed (called from input threads)."""
|
||||
with self._lock:
|
||||
self._episode_success = False
|
||||
|
||||
def consume_episode_success(self) -> bool | None:
|
||||
"""Consume and reset the episode success label. Returns None if unlabeled."""
|
||||
with self._lock:
|
||||
result = self._episode_success
|
||||
self._episode_success = None
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
@@ -226,16 +247,25 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
events.request_transition(key_to_event[resolved])
|
||||
if resolved == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
if resolved == cfg.success:
|
||||
events.mark_success()
|
||||
logger.info("Episode marked as SUCCESS")
|
||||
if resolved == cfg.failure:
|
||||
events.mark_failure()
|
||||
logger.info("Episode marked as FAILURE")
|
||||
except Exception as e:
|
||||
logger.debug("Key error: %s", e)
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
logger.info(
|
||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', "
|
||||
"upload='%s', success='%s', failure='%s', ESC=stop)",
|
||||
cfg.pause_resume,
|
||||
cfg.correction,
|
||||
cfg.upload,
|
||||
cfg.success,
|
||||
cfg.failure,
|
||||
)
|
||||
return listener
|
||||
|
||||
@@ -255,6 +285,12 @@ def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
|
||||
events.request_transition(code_to_event[code])
|
||||
if code == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
if code == cfg.success:
|
||||
events.mark_success()
|
||||
logger.info("Episode marked as SUCCESS (pedal)")
|
||||
if code == cfg.failure:
|
||||
events.mark_failure()
|
||||
logger.info("Episode marked as FAILURE (pedal)")
|
||||
|
||||
logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
|
||||
return start_pedal_listener(on_press, device_path=cfg.device_path)
|
||||
@@ -357,6 +393,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
)
|
||||
logger.info("DAgger strategy teardown complete")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode success labeling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _stamp_episode_success(self, dataset) -> None:
|
||||
"""Set next.success on the terminal frame based on operator label.
|
||||
|
||||
Called just before save_episode(). If the operator pressed the success
|
||||
key during this episode, the last frame's next.success is set to True.
|
||||
Otherwise all frames remain False (unlabeled = assumed failure).
|
||||
"""
|
||||
buf = dataset.writer.episode_buffer
|
||||
if buf is None:
|
||||
return
|
||||
|
||||
success_buf = buf.get("next.success")
|
||||
if not success_buf:
|
||||
return
|
||||
|
||||
label = self._events.consume_episode_success()
|
||||
|
||||
if label:
|
||||
success_buf[-1] = np.array([True], dtype=bool)
|
||||
logger.info("Terminal frame stamped next.success=True")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Continuous recording mode (record_autonomous=True)
|
||||
# ------------------------------------------------------------------
|
||||
@@ -443,6 +504,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
@@ -471,6 +533,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([False], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
@@ -481,6 +544,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push.set()
|
||||
@@ -510,6 +574,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.pause()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
logger.info("Final in-progress episode saved")
|
||||
@@ -584,6 +649,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
# Correction ended -> save episode (blocking if not streaming)
|
||||
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
self._needs_push.set()
|
||||
@@ -625,6 +691,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
)
|
||||
record_tick += 1
|
||||
@@ -659,6 +726,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.pause()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
logger.info("Final in-progress episode saved")
|
||||
|
||||
@@ -34,6 +34,7 @@ from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConf
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
AdvantageModule,
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
@@ -86,6 +87,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
|
||||
)
|
||||
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
|
||||
advantage = AdvantageModule(config=cfg.advantage)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||
@@ -96,6 +98,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
plan=plan,
|
||||
interjections=interjections,
|
||||
vqa=vqa,
|
||||
advantage=advantage,
|
||||
writer=writer,
|
||||
validator=validator,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,382 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compute per-frame ``is_terminal`` and ``mc_return`` for a LeRobot dataset.
|
||||
|
||||
Implements the sparse reward function from pi*0.6 / RECAP (Eq. 5):
|
||||
|
||||
r_t = -1 for non-terminal steps
|
||||
r_T = 0 for terminal success
|
||||
r_T = -C_fail for terminal failure
|
||||
|
||||
Monte Carlo returns are the cumulative sum from each step to the end of
|
||||
the episode, normalized by ``max_episode_length`` so that values are bounded
|
||||
to approximately (-1, 0).
|
||||
|
||||
The columns are written directly into the dataset's parquet data shards as
|
||||
flat per-frame scalars. These serve as training targets for the distributional
|
||||
value function.
|
||||
|
||||
Usage:
|
||||
# Compute returns using the default "next.success" column (from lerobot-eval/rollout)
|
||||
lerobot-compute-returns \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human_image
|
||||
|
||||
# Override: treat all episodes as successful (demo-only datasets)
|
||||
lerobot-compute-returns \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human_image \\
|
||||
--default-success true
|
||||
|
||||
# Custom success key, failure penalty, and discount
|
||||
lerobot-compute-returns \\
|
||||
--dataset-repo-id my_org/my_dataset \\
|
||||
--success-key episode_success \\
|
||||
--c-fail 100 \\
|
||||
--gamma 0.99
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IS_TERMINAL_COL = "is_terminal"
|
||||
MC_RETURN_COL = "mc_return"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeReturnsConfig:
|
||||
"""Configuration for the returns computation script."""
|
||||
|
||||
dataset_repo_id: str = ""
|
||||
root: str | None = None
|
||||
success_key: str = "next.success"
|
||||
default_success: bool | None = None
|
||||
max_episode_length: int | None = None
|
||||
c_fail: float = 50.0
|
||||
gamma: float = 1.0
|
||||
episodes: list[int] = field(default_factory=list)
|
||||
force: bool = False
|
||||
|
||||
|
||||
def _get_episode_success(
|
||||
episode_table: pa.Table,
|
||||
success_key: str,
|
||||
default_success: bool | None,
|
||||
) -> bool:
|
||||
"""Determine whether an episode was successful.
|
||||
|
||||
Priority:
|
||||
1. If ``default_success`` is set, use it unconditionally.
|
||||
2. Look for ``success_key`` in the parquet columns and reduce with any().
|
||||
3. Fall back to True (assume success for demo datasets).
|
||||
"""
|
||||
if default_success is not None:
|
||||
return default_success
|
||||
|
||||
if success_key in episode_table.column_names:
|
||||
col = episode_table.column(success_key)
|
||||
for val in col:
|
||||
py_val = val.as_py()
|
||||
if isinstance(py_val, bool) and py_val:
|
||||
return True
|
||||
if isinstance(py_val, (int, float)) and py_val:
|
||||
return True
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def compute_episode_returns(
|
||||
num_frames: int,
|
||||
success: bool,
|
||||
c_fail: float,
|
||||
gamma: float,
|
||||
max_episode_length: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Compute is_terminal and mc_return arrays for a single episode.
|
||||
|
||||
Args:
|
||||
num_frames: Number of frames in the episode.
|
||||
success: Whether the episode ended successfully.
|
||||
c_fail: Failure penalty constant.
|
||||
gamma: Discount factor (1.0 = undiscounted).
|
||||
max_episode_length: Normalization horizon H.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_terminal, mc_return) arrays, each of length num_frames.
|
||||
"""
|
||||
horizon = max_episode_length
|
||||
|
||||
rewards = np.full(num_frames, -1.0 / horizon, dtype=np.float64)
|
||||
|
||||
if success:
|
||||
rewards[-1] = 0.0
|
||||
else:
|
||||
rewards[-1] = -c_fail / horizon
|
||||
|
||||
is_terminal = np.zeros(num_frames, dtype=bool)
|
||||
is_terminal[-1] = True
|
||||
|
||||
if gamma == 1.0:
|
||||
# Reverse cumulative sum
|
||||
mc_return = np.cumsum(rewards[::-1])[::-1].astype(np.float32)
|
||||
else:
|
||||
mc_return = np.zeros(num_frames, dtype=np.float64)
|
||||
mc_return[-1] = rewards[-1]
|
||||
for t in range(num_frames - 2, -1, -1):
|
||||
mc_return[t] = rewards[t] + gamma * mc_return[t + 1]
|
||||
mc_return = mc_return.astype(np.float32)
|
||||
|
||||
return is_terminal, mc_return
|
||||
|
||||
|
||||
def compute_returns(config: ComputeReturnsConfig) -> Path:
|
||||
"""Compute returns and write them into parquet shards."""
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
logger.info(f"Loading dataset: {config.dataset_repo_id}")
|
||||
kwargs = {"repo_id": config.dataset_repo_id, "download_videos": False}
|
||||
if config.root:
|
||||
kwargs["root"] = config.root
|
||||
dataset = LeRobotDataset(**kwargs)
|
||||
|
||||
meta = dataset.meta
|
||||
root = Path(meta.root)
|
||||
logger.info(f"Dataset root: {root}")
|
||||
logger.info(f"Episodes: {meta.total_episodes}, Frames: {meta.total_frames}")
|
||||
|
||||
episode_indices = config.episodes if config.episodes else list(range(meta.total_episodes))
|
||||
|
||||
if config.max_episode_length is not None:
|
||||
max_ep_len = config.max_episode_length
|
||||
else:
|
||||
max_ep_len = max(int(meta.episodes[i]["length"]) for i in episode_indices)
|
||||
logger.info(f"Normalization horizon (max_episode_length): {max_ep_len}")
|
||||
|
||||
parquet_files_to_rewrite: dict[Path, list[int]] = {}
|
||||
for ep_idx in episode_indices:
|
||||
rel_path = meta.get_data_file_path(ep_idx)
|
||||
abs_path = root / rel_path
|
||||
parquet_files_to_rewrite.setdefault(abs_path, []).append(ep_idx)
|
||||
|
||||
logger.info(f"Parquet shards to rewrite: {len(parquet_files_to_rewrite)}")
|
||||
|
||||
for parquet_path, ep_indices_in_file in tqdm(parquet_files_to_rewrite.items(), desc="Processing shards"):
|
||||
table = pq.read_table(parquet_path)
|
||||
|
||||
if not config.force and IS_TERMINAL_COL in table.column_names:
|
||||
logger.info(f"Skipping {parquet_path.name} (already has {IS_TERMINAL_COL})")
|
||||
continue
|
||||
|
||||
all_is_terminal = np.zeros(len(table), dtype=bool)
|
||||
all_mc_return = np.zeros(len(table), dtype=np.float32)
|
||||
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
|
||||
for ep_idx in ep_indices_in_file:
|
||||
ep_info = meta.episodes[ep_idx]
|
||||
ep_from = int(ep_info["dataset_from_index"])
|
||||
ep_to = int(ep_info["dataset_to_index"])
|
||||
ep_len = ep_to - ep_from
|
||||
|
||||
mask = np.array([v == ep_idx for v in episode_col], dtype=bool)
|
||||
local_indices = np.where(mask)[0]
|
||||
|
||||
if len(local_indices) != ep_len:
|
||||
logger.warning(
|
||||
f"Episode {ep_idx}: expected {ep_len} frames in shard, "
|
||||
f"found {len(local_indices)}. Using found count."
|
||||
)
|
||||
ep_len = len(local_indices)
|
||||
|
||||
if ep_len == 0:
|
||||
continue
|
||||
|
||||
ep_subtable = table.filter(mask)
|
||||
success = _get_episode_success(ep_subtable, config.success_key, config.default_success)
|
||||
|
||||
is_terminal, mc_return = compute_episode_returns(
|
||||
num_frames=ep_len,
|
||||
success=success,
|
||||
c_fail=config.c_fail,
|
||||
gamma=config.gamma,
|
||||
max_episode_length=max_ep_len,
|
||||
)
|
||||
|
||||
all_is_terminal[local_indices] = is_terminal
|
||||
all_mc_return[local_indices] = mc_return
|
||||
|
||||
if IS_TERMINAL_COL in table.column_names:
|
||||
table = table.drop(IS_TERMINAL_COL)
|
||||
if MC_RETURN_COL in table.column_names:
|
||||
table = table.drop(MC_RETURN_COL)
|
||||
|
||||
table = table.append_column(IS_TERMINAL_COL, pa.array(all_is_terminal))
|
||||
table = table.append_column(MC_RETURN_COL, pa.array(all_mc_return))
|
||||
|
||||
pq.write_table(table, parquet_path)
|
||||
|
||||
_update_info_json(root, meta)
|
||||
|
||||
logger.info("Done. Columns written: is_terminal, mc_return")
|
||||
return root
|
||||
|
||||
|
||||
def _update_info_json(root: Path, meta) -> None:
|
||||
"""Add is_terminal and mc_return to the dataset's info.json features."""
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
logger.warning(f"info.json not found at {info_path}, skipping metadata update.")
|
||||
return
|
||||
|
||||
info = json.loads(info_path.read_text())
|
||||
features = info.get("features", {})
|
||||
changed = False
|
||||
|
||||
if IS_TERMINAL_COL not in features:
|
||||
features[IS_TERMINAL_COL] = {
|
||||
"dtype": "bool",
|
||||
"shape": [1],
|
||||
"names": None,
|
||||
}
|
||||
changed = True
|
||||
|
||||
if MC_RETURN_COL not in features:
|
||||
features[MC_RETURN_COL] = {
|
||||
"dtype": "float32",
|
||||
"shape": [1],
|
||||
"names": None,
|
||||
}
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
info["features"] = features
|
||||
info_path.write_text(json.dumps(info, indent=2) + "\n")
|
||||
logger.info("Updated meta/info.json with is_terminal and mc_return features.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute per-frame is_terminal and mc_return for a LeRobot dataset.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Use the 'success' column from the dataset
|
||||
lerobot-compute-returns --dataset-repo-id lerobot/aloha_sim_insertion_human_image
|
||||
|
||||
# Override all episodes as successful (demo-only data)
|
||||
lerobot-compute-returns --dataset-repo-id my_org/my_dataset --default-success true
|
||||
|
||||
# Custom failure penalty
|
||||
lerobot-compute-returns --dataset-repo-id my_org/my_dataset --c-fail 100
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace dataset repo id or local path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local root directory override for the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--success-key",
|
||||
type=str,
|
||||
default="next.success",
|
||||
help="Column name in parquet that indicates episode success (default: 'next.success').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--default-success",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["true", "false"],
|
||||
help="Override success for all episodes ('true' or 'false').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-episode-length",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Normalization horizon H. If not set, uses max episode length in dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--c-fail",
|
||||
type=float,
|
||||
default=50.0,
|
||||
help="Failure penalty constant (default: 50.0).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gamma",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Discount factor (default: 1.0, undiscounted).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Process only these episode indices (default: all).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Overwrite existing is_terminal/mc_return columns.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
default_success = None
|
||||
if args.default_success is not None:
|
||||
default_success = args.default_success.lower() == "true"
|
||||
|
||||
config = ComputeReturnsConfig(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
root=args.root,
|
||||
success_key=args.success_key,
|
||||
default_success=default_success,
|
||||
max_episode_length=args.max_episode_length,
|
||||
c_fail=args.c_fail,
|
||||
gamma=args.gamma,
|
||||
episodes=args.episodes or [],
|
||||
force=args.force,
|
||||
)
|
||||
|
||||
root = compute_returns(config)
|
||||
logger.info(f"Returns computed and written to: {root}")
|
||||
logger.info(f" Columns added: {IS_TERMINAL_COL}, {MC_RETURN_COL}")
|
||||
logger.info("To train the distributional value function, these columns")
|
||||
logger.info("will be read as flat batch keys during training.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
OBS_LANGUAGE_UNCOND = OBS_STR + ".language_uncond"
|
||||
OBS_LANGUAGE_UNCOND_TOKENS = OBS_LANGUAGE_UNCOND + ".tokens"
|
||||
OBS_LANGUAGE_UNCOND_ATTENTION_MASK = OBS_LANGUAGE_UNCOND + ".attention_mask"
|
||||
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask"
|
||||
|
||||
@@ -28,9 +28,10 @@ import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.config import AdvantageConfig, AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
AdvantageModule,
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
@@ -85,6 +86,7 @@ def main() -> int:
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
|
||||
advantage=AdvantageModule(config=AdvantageConfig(enabled=False)),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the advantage scoring annotation module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AdvantageConfig
|
||||
from lerobot.annotations.steerable_pipeline.modules.advantage import AdvantageModule
|
||||
from lerobot.annotations.steerable_pipeline.reader import EpisodeRecord
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
|
||||
|
||||
def _make_record(
|
||||
episode_index: int = 0,
|
||||
num_frames: int = 20,
|
||||
task: str = "pick up the cup",
|
||||
mc_returns: np.ndarray | None = None,
|
||||
intervention_mask: np.ndarray | None = None,
|
||||
fps: float = 10.0,
|
||||
) -> EpisodeRecord:
|
||||
"""Build a minimal EpisodeRecord with a mocked frames_df."""
|
||||
import pandas as pd
|
||||
|
||||
timestamps = tuple(round(i / fps, 6) for i in range(num_frames))
|
||||
frame_indices = tuple(range(num_frames))
|
||||
|
||||
if mc_returns is None:
|
||||
mc_returns = np.linspace(-0.9, -0.1, num_frames).astype(np.float32)
|
||||
|
||||
data = {
|
||||
"episode_index": [episode_index] * num_frames,
|
||||
"frame_index": list(range(num_frames)),
|
||||
"timestamp": list(timestamps),
|
||||
"mc_return": mc_returns,
|
||||
}
|
||||
|
||||
if intervention_mask is not None:
|
||||
data["intervention"] = intervention_mask.astype(bool)
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
record = EpisodeRecord(
|
||||
episode_index=episode_index,
|
||||
episode_task=task,
|
||||
frame_timestamps=timestamps,
|
||||
frame_indices=frame_indices,
|
||||
data_path=Path("/fake/data.parquet"),
|
||||
row_offset=0,
|
||||
row_count=num_frames,
|
||||
)
|
||||
record._frames_df_cache = df
|
||||
return record
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def staging(tmp_path: Path) -> EpisodeStaging:
|
||||
return EpisodeStaging(tmp_path, episode_index=0)
|
||||
|
||||
|
||||
def test_advantage_module_disabled():
|
||||
"""Disabled module has enabled=False."""
|
||||
cfg = AdvantageConfig(enabled=False)
|
||||
module = AdvantageModule(config=cfg)
|
||||
assert not module.enabled
|
||||
|
||||
|
||||
def test_advantage_module_enabled_by_default():
|
||||
"""Module is enabled by default."""
|
||||
cfg = AdvantageConfig()
|
||||
module = AdvantageModule(config=cfg)
|
||||
assert module.enabled
|
||||
|
||||
|
||||
def test_run_episode_skips_without_value_function_path(staging: EpisodeStaging):
|
||||
"""Module gracefully returns when no value_function_path is configured."""
|
||||
cfg = AdvantageConfig(value_function_path="")
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record()
|
||||
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
assert rows == []
|
||||
|
||||
|
||||
def test_binarization_with_mock_values(staging: EpisodeStaging):
|
||||
"""Advantage binarization produces positive/negative labels based on threshold."""
|
||||
num_frames = 10
|
||||
mc_returns = np.array([-0.5, -0.4, -0.3, -0.2, -0.1, -0.5, -0.6, -0.7, -0.8, -0.9], dtype=np.float32)
|
||||
mock_values = np.array([-0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4], dtype=np.float32)
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
dropout_rate=0.0,
|
||||
threshold_percentile=0.5,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
|
||||
|
||||
with (
|
||||
patch.object(module, "_ensure_model_loaded"),
|
||||
patch.object(module, "_compute_values", return_value=mock_values),
|
||||
):
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
assert len(rows) == num_frames
|
||||
|
||||
# A_t = mc_returns - values
|
||||
# advantages = [-0.1, 0.0, 0.1, 0.2, 0.3, -0.1, -0.2, -0.3, -0.4, -0.5]
|
||||
# Median (50th pctile) = -0.1
|
||||
# positive: advantage > -0.1 → indices 1,2,3,4
|
||||
# negative: advantage <= -0.1 → indices 0,5,6,7,8,9
|
||||
positives = [r for r in rows if r["content"] == "positive"]
|
||||
negatives = [r for r in rows if r["content"] == "negative"]
|
||||
assert len(positives) == 4
|
||||
assert len(negatives) == 6
|
||||
|
||||
|
||||
def test_intervention_frames_forced_positive(staging: EpisodeStaging):
|
||||
"""Intervention frames are always scored as positive regardless of advantage value."""
|
||||
num_frames = 5
|
||||
mc_returns = np.array([-0.9, -0.9, -0.9, -0.9, -0.9], dtype=np.float32)
|
||||
mock_values = np.array([-0.1, -0.1, -0.1, -0.1, -0.1], dtype=np.float32)
|
||||
intervention = np.array([False, False, True, False, False])
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
dropout_rate=0.0,
|
||||
force_positive_on_intervention=True,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns, intervention_mask=intervention)
|
||||
|
||||
with (
|
||||
patch.object(module, "_ensure_model_loaded"),
|
||||
patch.object(module, "_compute_values", return_value=mock_values),
|
||||
):
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
# Frame 2 (intervention) should be positive despite negative advantage
|
||||
assert rows[2]["content"] == "positive"
|
||||
|
||||
|
||||
def test_dropout_reduces_output_rows(staging: EpisodeStaging):
|
||||
"""Non-zero dropout rate omits some frames."""
|
||||
num_frames = 100
|
||||
mc_returns = np.linspace(-0.9, -0.1, num_frames).astype(np.float32)
|
||||
mock_values = np.full(num_frames, -0.5, dtype=np.float32)
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
dropout_rate=0.3,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
|
||||
|
||||
with (
|
||||
patch.object(module, "_ensure_model_loaded"),
|
||||
patch.object(module, "_compute_values", return_value=mock_values),
|
||||
):
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
# With 30% dropout on 100 frames, expect ~70 rows (with some variance)
|
||||
assert 50 < len(rows) < 90
|
||||
|
||||
|
||||
def test_staged_row_format(staging: EpisodeStaging):
|
||||
"""Staged rows have the correct schema for language_persistent."""
|
||||
num_frames = 5
|
||||
mc_returns = np.array([-0.5, -0.4, -0.3, -0.2, -0.1], dtype=np.float32)
|
||||
mock_values = np.full(5, -0.3, dtype=np.float32)
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
dropout_rate=0.0,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
|
||||
|
||||
with (
|
||||
patch.object(module, "_ensure_model_loaded"),
|
||||
patch.object(module, "_compute_values", return_value=mock_values),
|
||||
):
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
for row in rows:
|
||||
assert row["role"] == "user"
|
||||
assert row["content"] in ("positive", "negative")
|
||||
assert row["style"] == "advantage"
|
||||
assert isinstance(row["timestamp"], float)
|
||||
assert row["camera"] is None
|
||||
assert row["tool_calls"] is None
|
||||
|
||||
|
||||
def test_n_step_advantage():
|
||||
"""N-step advantage uses partial returns + bootstrapped value."""
|
||||
num_frames = 10
|
||||
mc_returns = np.linspace(-0.9, 0.0, num_frames).astype(np.float32)
|
||||
mock_values = np.full(num_frames, -0.45, dtype=np.float32)
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
n_step=3,
|
||||
dropout_rate=0.0,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
|
||||
|
||||
with patch.object(module, "_ensure_model_loaded"):
|
||||
advantages, _ = (
|
||||
module.compute_advantages_for_episode.__wrapped__(module, record)
|
||||
if hasattr(module.compute_advantages_for_episode, "__wrapped__")
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
# Just verify computation works - use the internal method directly
|
||||
module._model = MagicMock()
|
||||
module._preprocessor = MagicMock()
|
||||
with patch.object(module, "_compute_values", return_value=mock_values):
|
||||
advantages, _ = module.compute_advantages_for_episode(record)
|
||||
|
||||
# For t where t+n < num_frames: A = mc_return[t] - mc_return[t+n] + values[t+n] - values[t]
|
||||
# Since values are constant: A = mc_return[t] - mc_return[t+n]
|
||||
# For t where t+n >= num_frames: A = mc_return[t] - values[t]
|
||||
for t in range(num_frames):
|
||||
if t + 3 < num_frames:
|
||||
expected = mc_returns[t] - mc_returns[t + 3] + mock_values[t + 3] - mock_values[t]
|
||||
else:
|
||||
expected = mc_returns[t] - mock_values[t]
|
||||
np.testing.assert_almost_equal(advantages[t], expected, decimal=5)
|
||||
|
||||
|
||||
def test_compute_threshold():
|
||||
"""Threshold is computed as configured percentile of non-intervention advantages."""
|
||||
cfg = AdvantageConfig(threshold_percentile=0.3)
|
||||
module = AdvantageModule(config=cfg)
|
||||
|
||||
advantages = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
|
||||
intervention_mask = np.array([False, False, False, False, False])
|
||||
|
||||
threshold = module._compute_threshold(advantages, intervention_mask)
|
||||
expected = float(np.percentile(advantages, 30))
|
||||
assert abs(threshold - expected) < 1e-6
|
||||
|
||||
|
||||
def test_compute_threshold_excludes_intervention():
|
||||
"""Threshold computation excludes intervention frames."""
|
||||
cfg = AdvantageConfig(threshold_percentile=0.5)
|
||||
module = AdvantageModule(config=cfg)
|
||||
|
||||
advantages = np.array([100.0, -1.0, 0.0, 1.0, 100.0], dtype=np.float32)
|
||||
intervention_mask = np.array([True, False, False, False, True])
|
||||
|
||||
threshold = module._compute_threshold(advantages, intervention_mask)
|
||||
# Only non-intervention: [-1.0, 0.0, 1.0], median = 0.0
|
||||
expected = float(np.percentile([-1.0, 0.0, 1.0], 50))
|
||||
assert abs(threshold - expected) < 1e-6
|
||||
|
||||
|
||||
def test_missing_mc_return_raises():
|
||||
"""Module raises if mc_return column is missing from dataset."""
|
||||
import pandas as pd
|
||||
|
||||
cfg = AdvantageConfig(value_function_path="/fake/vf")
|
||||
module = AdvantageModule(config=cfg)
|
||||
module._model = MagicMock()
|
||||
module._preprocessor = MagicMock()
|
||||
|
||||
record = EpisodeRecord(
|
||||
episode_index=0,
|
||||
episode_task="test",
|
||||
frame_timestamps=(0.0, 0.1),
|
||||
frame_indices=(0, 1),
|
||||
data_path=Path("/fake/data.parquet"),
|
||||
row_offset=0,
|
||||
row_count=2,
|
||||
)
|
||||
record._frames_df_cache = pd.DataFrame({"episode_index": [0, 0], "frame_index": [0, 1]})
|
||||
|
||||
with pytest.raises(KeyError, match="mc_return"):
|
||||
module.compute_advantages_for_episode(record)
|
||||
@@ -30,6 +30,7 @@ pytest.importorskip("pandas", reason="pandas is required (install lerobot[datase
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
|
||||
AdvantageConfig,
|
||||
AnnotationPipelineConfig,
|
||||
InterjectionsConfig,
|
||||
PlanConfig,
|
||||
@@ -37,6 +38,7 @@ from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.modules import ( # noqa: E402
|
||||
AdvantageModule,
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
@@ -132,6 +134,7 @@ def _build_executor() -> Executor:
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
|
||||
advantage=AdvantageModule(config=AdvantageConfig(enabled=False)),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RECAP advantage conditioning recipes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.configs.recipe import load_recipe
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
|
||||
RECIPES_DIR = Path(__file__).resolve().parents[2] / "src" / "lerobot" / "configs" / "recipes"
|
||||
|
||||
|
||||
def _persistent_rows(advantage: str | None = None):
|
||||
"""Build minimal persistent rows with optional advantage."""
|
||||
rows = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "pick up the cup",
|
||||
"style": "task_aug",
|
||||
"timestamp": 0.0,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "reaching for the cup",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
]
|
||||
if advantage is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": advantage,
|
||||
"style": "advantage",
|
||||
"timestamp": 0.0,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def test_recap_advantage_recipe_loads():
|
||||
"""The recap_advantage.yaml recipe loads without errors."""
|
||||
recipe = load_recipe(RECIPES_DIR / "recap_advantage.yaml")
|
||||
assert recipe.messages is not None
|
||||
assert len(recipe.messages) == 3
|
||||
assert recipe.bindings == {"advantage": "active_at(t, style=advantage)"}
|
||||
|
||||
|
||||
def test_advantage_present_renders_indicator():
|
||||
"""When advantage annotation exists, the prompt includes 'Advantage: positive'."""
|
||||
recipe = load_recipe(RECIPES_DIR / "recap_advantage.yaml")
|
||||
result = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=_persistent_rows(advantage="positive"),
|
||||
events=[],
|
||||
t=0.5,
|
||||
sample_idx=0,
|
||||
task="pick up the cup",
|
||||
)
|
||||
assert result is not None
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 3
|
||||
assert messages[1]["content"] == "Advantage: positive"
|
||||
|
||||
|
||||
def test_advantage_negative_renders_indicator():
|
||||
"""Negative advantage also appears in the prompt."""
|
||||
recipe = load_recipe(RECIPES_DIR / "recap_advantage.yaml")
|
||||
result = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=_persistent_rows(advantage="negative"),
|
||||
events=[],
|
||||
t=0.5,
|
||||
sample_idx=0,
|
||||
task="pick up the cup",
|
||||
)
|
||||
assert result is not None
|
||||
messages = result["messages"]
|
||||
assert messages[1]["content"] == "Advantage: negative"
|
||||
|
||||
|
||||
def test_advantage_absent_skips_turn():
|
||||
"""When no advantage annotation exists (dropout), the advantage turn is skipped."""
|
||||
recipe = load_recipe(RECIPES_DIR / "recap_advantage.yaml")
|
||||
result = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=_persistent_rows(advantage=None),
|
||||
events=[],
|
||||
t=0.5,
|
||||
sample_idx=0,
|
||||
task="pick up the cup",
|
||||
)
|
||||
assert result is not None
|
||||
messages = result["messages"]
|
||||
# Only task + subtask, no advantage turn
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["content"] == "pick up the cup"
|
||||
assert messages[1]["content"] == "reaching for the cup"
|
||||
|
||||
|
||||
def test_advantage_absent_still_has_target():
|
||||
"""Even without advantage, the target message (subtask) is preserved."""
|
||||
recipe = load_recipe(RECIPES_DIR / "recap_advantage.yaml")
|
||||
result = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=_persistent_rows(advantage=None),
|
||||
events=[],
|
||||
t=0.5,
|
||||
sample_idx=0,
|
||||
task="pick up the cup",
|
||||
)
|
||||
assert result is not None
|
||||
assert result["target_message_indices"] == [1]
|
||||
|
||||
|
||||
def test_blend_recipe_loads():
|
||||
"""The blend recipe has two components with correct weights."""
|
||||
recipe = load_recipe(RECIPES_DIR / "recap_advantage_blend.yaml")
|
||||
assert recipe.blend is not None
|
||||
assert "advantage_conditioned" in recipe.blend
|
||||
assert "unconditional" in recipe.blend
|
||||
assert recipe.blend["advantage_conditioned"].weight == 0.7
|
||||
assert recipe.blend["unconditional"].weight == 0.3
|
||||
@@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Tests for PI05 Classifier-Free Guidance (CFG) inference."""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("transformers", reason="transformers is required for PI05")
|
||||
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Config, make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor.converters import create_transition # noqa: E402
|
||||
from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep # noqa: E402
|
||||
from lerobot.types import TransitionKey # noqa: E402
|
||||
from lerobot.utils.constants import ( # noqa: E402
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_UNCOND_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_UNCOND_TOKENS,
|
||||
)
|
||||
|
||||
|
||||
class TestRenderedMessagesToTaskBaseTaskPreservation:
|
||||
"""Tests that RenderedMessagesToTaskStep preserves base_task for CFG."""
|
||||
|
||||
def test_preserves_string_base_task(self):
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "pick up the cup",
|
||||
"messages": [
|
||||
{"role": "user", "content": "pick up the cup, Advantage: positive"},
|
||||
],
|
||||
}
|
||||
)
|
||||
step = RenderedMessagesToTaskStep()
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["base_task"] == "pick up the cup"
|
||||
assert data["task"] == "pick up the cup, Advantage: positive"
|
||||
|
||||
def test_preserves_list_base_task(self):
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": ["task1", "task2"],
|
||||
"messages": [
|
||||
{"role": "user", "content": "rendered with advantage"},
|
||||
],
|
||||
}
|
||||
)
|
||||
step = RenderedMessagesToTaskStep()
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["base_task"] == ["task1", "task2"]
|
||||
|
||||
def test_no_base_task_when_messages_absent(self):
|
||||
transition = create_transition(complementary_data={"task": "pick up the cup"})
|
||||
step = RenderedMessagesToTaskStep()
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert "base_task" not in data
|
||||
|
||||
|
||||
class TestPi05PrepareStateTokenizerCfg:
|
||||
"""Tests for Pi05PrepareStateTokenizerProcessorStep with cfg_enabled."""
|
||||
|
||||
def _make_transition(self, task, base_task=None):
|
||||
complementary_data = {"task": task}
|
||||
if base_task is not None:
|
||||
complementary_data["base_task"] = base_task
|
||||
return create_transition(
|
||||
observation={"observation.state": torch.zeros(1, 14)},
|
||||
complementary_data=complementary_data,
|
||||
)
|
||||
|
||||
def test_cfg_disabled_no_uncond_task(self):
|
||||
from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep
|
||||
|
||||
step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=False)
|
||||
transition = self._make_transition(task=["pick up the cup, Advantage: positive"])
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert "uncond_task" not in data
|
||||
|
||||
def test_cfg_enabled_produces_uncond_task_from_base(self):
|
||||
from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep
|
||||
|
||||
step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=True)
|
||||
transition = self._make_transition(
|
||||
task=["pick up the cup, Advantage: positive"],
|
||||
base_task=["pick up the cup"],
|
||||
)
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert "uncond_task" in data
|
||||
assert len(data["uncond_task"]) == 1
|
||||
# Unconditional prompt uses base_task (no advantage)
|
||||
assert "Advantage" not in data["uncond_task"][0]
|
||||
assert "pick up the cup" in data["uncond_task"][0]
|
||||
assert "State:" in data["uncond_task"][0]
|
||||
|
||||
def test_cfg_enabled_falls_back_to_task_when_no_base(self):
|
||||
from lerobot.policies.pi05.processor_pi05 import Pi05PrepareStateTokenizerProcessorStep
|
||||
|
||||
step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=14, cfg_enabled=True)
|
||||
transition = self._make_transition(task=["pick up the cup"])
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
# Falls back to using task itself as unconditional
|
||||
assert "uncond_task" in data
|
||||
assert "pick up the cup" in data["uncond_task"][0]
|
||||
|
||||
|
||||
class TestCfgPipelineConstruction:
|
||||
"""Tests that the processor pipeline is constructed correctly for CFG."""
|
||||
|
||||
def _make_config(self, cfg_beta=1.0, recipe_path=None):
|
||||
config = PI05Config(
|
||||
max_action_dim=7,
|
||||
max_state_dim=14,
|
||||
cfg_beta=cfg_beta,
|
||||
recipe_path=recipe_path,
|
||||
device="cpu",
|
||||
)
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
return config
|
||||
|
||||
def _make_dataset_stats(self):
|
||||
return {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
"min": torch.zeros(14),
|
||||
"max": torch.ones(14),
|
||||
"q01": torch.zeros(14),
|
||||
"q99": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
"min": torch.zeros(7),
|
||||
"max": torch.ones(7),
|
||||
"q01": torch.zeros(7),
|
||||
"q99": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
}
|
||||
|
||||
def test_no_uncond_tokenizer_when_cfg_disabled(self):
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
config = self._make_config(cfg_beta=1.0)
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats())
|
||||
|
||||
tokenizer_steps = [s for s in preprocessor.steps if isinstance(s, TokenizerProcessorStep)]
|
||||
assert len(tokenizer_steps) == 1
|
||||
|
||||
def test_uncond_tokenizer_added_when_cfg_enabled(self):
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
config = self._make_config(cfg_beta=2.0)
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats())
|
||||
|
||||
tokenizer_steps = [s for s in preprocessor.steps if isinstance(s, TokenizerProcessorStep)]
|
||||
assert len(tokenizer_steps) == 2
|
||||
|
||||
uncond_tokenizer = tokenizer_steps[1]
|
||||
assert uncond_tokenizer.task_key == "uncond_task"
|
||||
assert uncond_tokenizer.output_tokens_key == OBS_LANGUAGE_UNCOND_TOKENS
|
||||
assert uncond_tokenizer.output_mask_key == OBS_LANGUAGE_UNCOND_ATTENTION_MASK
|
||||
|
||||
def test_cfg_pipeline_produces_both_token_sets(self):
|
||||
config = self._make_config(cfg_beta=2.0)
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats())
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(14),
|
||||
"observation.images.base_0_rgb": torch.rand(3, 224, 224),
|
||||
"task": "pick up the cup",
|
||||
}
|
||||
processed = preprocessor(batch)
|
||||
|
||||
assert OBS_LANGUAGE_TOKENS in processed
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed
|
||||
assert OBS_LANGUAGE_UNCOND_TOKENS in processed
|
||||
assert OBS_LANGUAGE_UNCOND_ATTENTION_MASK in processed
|
||||
|
||||
# Both should be tensors with the same shape
|
||||
assert processed[OBS_LANGUAGE_TOKENS].shape == processed[OBS_LANGUAGE_UNCOND_TOKENS].shape
|
||||
assert (
|
||||
processed[OBS_LANGUAGE_ATTENTION_MASK].shape
|
||||
== processed[OBS_LANGUAGE_UNCOND_ATTENTION_MASK].shape
|
||||
)
|
||||
|
||||
def test_cfg_beta_1_no_uncond_tokens_in_output(self):
|
||||
config = self._make_config(cfg_beta=1.0)
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config, self._make_dataset_stats())
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(14),
|
||||
"observation.images.base_0_rgb": torch.rand(3, 224, 224),
|
||||
"task": "pick up the cup",
|
||||
}
|
||||
processed = preprocessor(batch)
|
||||
|
||||
assert OBS_LANGUAGE_TOKENS in processed
|
||||
assert OBS_LANGUAGE_UNCOND_TOKENS not in processed
|
||||
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Tests for RenderedMessagesToTaskStep and PI05 pipeline integration with advantage."""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
|
||||
from lerobot.processor.converters import create_transition # noqa: E402
|
||||
from lerobot.processor.render_messages_processor import RenderMessagesStep # noqa: E402
|
||||
from lerobot.processor.rendered_messages_to_task import RenderedMessagesToTaskStep # noqa: E402
|
||||
from lerobot.types import TransitionKey # noqa: E402
|
||||
|
||||
|
||||
def test_rendered_messages_to_task_noops_without_messages():
|
||||
"""Without messages key, the step is a no-op."""
|
||||
transition = create_transition(complementary_data={"task": "pick up the cup"})
|
||||
step = RenderedMessagesToTaskStep()
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert data["task"] == "pick up the cup"
|
||||
|
||||
|
||||
def test_rendered_messages_to_task_extracts_user_content():
|
||||
"""Extracts user-role message content and joins with newline."""
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "original task",
|
||||
"messages": [
|
||||
{"role": "user", "content": "pick up the cup"},
|
||||
{"role": "user", "content": "Advantage: positive"},
|
||||
{"role": "assistant", "content": "reach for cup"},
|
||||
],
|
||||
"message_streams": ["high_level", "high_level", "low_level"],
|
||||
"target_message_indices": [2],
|
||||
}
|
||||
)
|
||||
step = RenderedMessagesToTaskStep()
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["task"] == "pick up the cup\nAdvantage: positive"
|
||||
assert "messages" not in data
|
||||
assert "message_streams" not in data
|
||||
assert "target_message_indices" not in data
|
||||
|
||||
|
||||
def test_rendered_messages_to_task_handles_multimodal_blocks():
|
||||
"""Extracts text from HF multimodal content blocks."""
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "original",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": "placeholder"},
|
||||
{"type": "text", "text": "describe this"},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": "a cup on a table"},
|
||||
],
|
||||
"message_streams": ["high_level", "low_level"],
|
||||
"target_message_indices": [1],
|
||||
}
|
||||
)
|
||||
step = RenderedMessagesToTaskStep()
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["task"] == "describe this"
|
||||
|
||||
|
||||
def test_rendered_messages_to_task_preserves_list_task_format():
|
||||
"""When original task is a list (batched), output is also a list."""
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": ["task1", "task2"],
|
||||
"messages": [
|
||||
{"role": "user", "content": "rendered task"},
|
||||
{"role": "assistant", "content": "do it", "target": True},
|
||||
],
|
||||
"message_streams": ["high_level", "low_level"],
|
||||
"target_message_indices": [1],
|
||||
}
|
||||
)
|
||||
step = RenderedMessagesToTaskStep()
|
||||
out = step(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["task"] == ["rendered task", "rendered task"]
|
||||
|
||||
|
||||
def test_full_render_then_flatten_pipeline():
|
||||
"""RenderMessagesStep + RenderedMessagesToTaskStep produces correct task string."""
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="Advantage: ${advantage}",
|
||||
stream="high_level",
|
||||
if_present="advantage",
|
||||
),
|
||||
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "pick up the cup",
|
||||
"timestamp": torch.tensor(0.5),
|
||||
"index": torch.tensor(0),
|
||||
"language_persistent": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "reach for the cup",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "positive",
|
||||
"style": "advantage",
|
||||
"timestamp": 0.1,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
"language_events": [],
|
||||
}
|
||||
)
|
||||
|
||||
# Step 1: Render recipe
|
||||
rendered = RenderMessagesStep(recipe=recipe)(transition)
|
||||
# Step 2: Flatten to task string
|
||||
out = RenderedMessagesToTaskStep()(rendered)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert "pick up the cup" in data["task"]
|
||||
assert "Advantage: positive" in data["task"]
|
||||
|
||||
|
||||
def test_full_render_advantage_absent_skips_turn():
|
||||
"""When advantage row is absent, the advantage turn is skipped via if_present."""
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="Advantage: ${advantage}",
|
||||
stream="high_level",
|
||||
if_present="advantage",
|
||||
),
|
||||
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "pick up the cup",
|
||||
"timestamp": torch.tensor(0.5),
|
||||
"index": torch.tensor(0),
|
||||
"language_persistent": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "reach for the cup",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
"language_events": [],
|
||||
}
|
||||
)
|
||||
|
||||
rendered = RenderMessagesStep(recipe=recipe)(transition)
|
||||
out = RenderedMessagesToTaskStep()(rendered)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["task"] == "pick up the cup"
|
||||
assert "Advantage" not in data["task"]
|
||||
@@ -0,0 +1,518 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RECAP's distributional value function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
BATCH_SIZE = 4
|
||||
NUM_BINS = 201
|
||||
IMAGE_KEY = f"{OBS_IMAGES}.top"
|
||||
|
||||
|
||||
def _make_config(**overrides) -> DistributionalVFConfig:
|
||||
defaults = {
|
||||
"init_from_actor_path": "",
|
||||
"device": "cpu",
|
||||
"image_resolution": (224, 224),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
config = DistributionalVFConfig(**defaults)
|
||||
config.input_features = {
|
||||
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def _make_model():
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel(_make_config())
|
||||
|
||||
|
||||
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
return {
|
||||
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
|
||||
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
|
||||
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
|
||||
"mc_return": torch.rand(batch_size, device=device) * -1.0,
|
||||
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
|
||||
def test_config_registered_in_reward_model_registry():
|
||||
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
|
||||
known = RewardModelConfig.get_known_choices()
|
||||
assert "distributional_value_function" in known
|
||||
|
||||
|
||||
def test_factory_returns_correct_class():
|
||||
"""get_reward_model_class returns DistributionalVFRewardModel."""
|
||||
from lerobot.rewards.factory import get_reward_model_class
|
||||
|
||||
cls = get_reward_model_class("distributional_value_function")
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
assert cls is DistributionalVFRewardModel
|
||||
|
||||
|
||||
def test_make_reward_model_config_factory():
|
||||
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
|
||||
from lerobot.rewards.factory import make_reward_model_config
|
||||
|
||||
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
|
||||
assert isinstance(config, DistributionalVFConfig)
|
||||
assert config.num_value_bins == 101
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_sums_to_one():
|
||||
"""HL-Gauss target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_non_negative():
|
||||
"""HL-Gauss target probabilities are all non-negative."""
|
||||
model = _make_model()
|
||||
targets = torch.linspace(-1.0, 0.0, 10)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert (dist >= 0).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_expected_value_matches():
|
||||
"""E[V] under HL-Gauss distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_handles_2d_input():
|
||||
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (2, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_sums_to_one():
|
||||
"""Dirac delta target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
assert dist.shape == (5, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_at_most_two_nonzero():
|
||||
"""Dirac delta places probability on at most two adjacent bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.7523, -0.0013])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
for i in range(2):
|
||||
assert (dist[i] > 0).sum() <= 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_expected_value_matches():
|
||||
"""E[V] under Dirac delta distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_boundary_values_clamped():
|
||||
"""Values outside support are clamped to boundary bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-1.5, 0.5])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
|
||||
assert dist[0, 0] == 1.0
|
||||
assert dist[1, -1] == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_single_nonzero():
|
||||
"""One-hot target has exactly one non-zero bin per sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
for i in range(4):
|
||||
assert (dist[i] > 0).sum() == 1
|
||||
assert dist[i].sum() == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_nearest_bin():
|
||||
"""One-hot target activates the bin closest to the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
hot_idx = dist[0].argmax()
|
||||
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_terminal_gets_one_hot():
|
||||
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
|
||||
is_terminal = torch.tensor([False, True, False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
|
||||
)
|
||||
|
||||
for i in range(4):
|
||||
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
|
||||
assert (dist[1] > 0).sum() == 1
|
||||
assert (dist[3] > 0).sum() == 1
|
||||
assert (dist[0] > 0).sum() > 2
|
||||
assert (dist[2] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_no_terminal_override_when_disabled():
|
||||
"""When use_one_hot_terminal=False, terminal states use the base method."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3])
|
||||
is_terminal = torch.tensor([False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
|
||||
)
|
||||
|
||||
assert (dist[1] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_has_expected_components():
|
||||
"""Model scaffold contains all architectural components."""
|
||||
model = _make_model()
|
||||
|
||||
assert hasattr(model, "vision_tower")
|
||||
assert hasattr(model, "multi_modal_projector")
|
||||
assert hasattr(model, "token_embedding")
|
||||
assert hasattr(model, "layers")
|
||||
assert hasattr(model, "value_head")
|
||||
assert hasattr(model, "cls_embedding")
|
||||
assert hasattr(model, "norm")
|
||||
assert hasattr(model, "rotary_emb")
|
||||
assert hasattr(model, "bin_centers")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_bin_centers_shape():
|
||||
"""Bin centers buffer has shape (num_value_bins,)."""
|
||||
model = _make_model()
|
||||
assert model.bin_centers.shape == (NUM_BINS,)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_layer_count():
|
||||
"""Transformer has num_hidden_layers (6) layers."""
|
||||
model = _make_model()
|
||||
assert len(model.layers) == 6
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_value_head_output_dim():
|
||||
"""Value head outputs num_value_bins logits."""
|
||||
model = _make_model()
|
||||
assert model.value_head.out_features == NUM_BINS
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_returns_loss_and_dict():
|
||||
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, output_dict = model.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert "loss" in output_dict
|
||||
assert "predicted_value_mean" in output_dict
|
||||
assert "mc_return_mean" in output_dict
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_loss_is_positive():
|
||||
"""Cross-entropy loss is strictly positive for random weights."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
|
||||
assert loss.item() > 0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_returns_correct_shape():
|
||||
"""compute_reward returns [batch_size] tensor of finite float32 values."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=3)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert values.shape == (3,)
|
||||
assert values.dtype == torch.float32
|
||||
assert torch.isfinite(values).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_values_in_support_range():
|
||||
"""Predicted values lie within [value_support_min, value_support_max]."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=8)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert (values >= -1.0 - 0.01).all()
|
||||
assert (values <= 0.0 + 0.01).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_processor_pipeline_produces_expected_keys():
|
||||
"""Full preprocessor pipeline produces tokenized text and processed images."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
config = _make_config()
|
||||
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
|
||||
|
||||
raw_batch = {
|
||||
IMAGE_KEY: torch.rand(3, 224, 224),
|
||||
"task": "pick up the cup",
|
||||
}
|
||||
|
||||
processed = preprocessor(raw_batch)
|
||||
|
||||
assert OBS_LANGUAGE_TOKENS in processed
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed
|
||||
assert IMAGE_KEY in processed
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_value_head():
|
||||
"""Backprop produces non-zero gradients on the value head."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.value_head.weight.grad is not None
|
||||
assert not torch.all(model.value_head.weight.grad == 0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_cls_embedding():
|
||||
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.cls_embedding.grad is not None
|
||||
assert not torch.all(model.cls_embedding.grad == 0)
|
||||
|
||||
|
||||
def test_config_requires_visual_feature():
|
||||
"""validate_features raises if no VISUAL feature is present."""
|
||||
config = DistributionalVFConfig(init_from_actor_path="")
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="VISUAL"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_config_passes_with_visual_feature():
|
||||
"""validate_features succeeds when a VISUAL feature is present."""
|
||||
config = _make_config()
|
||||
config.validate_features()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_save_load_pretrained_roundtrip(tmp_path):
|
||||
"""Saved model can be loaded back with identical weights."""
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
model = _make_model()
|
||||
model._save_pretrained(tmp_path)
|
||||
|
||||
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
orig_sd = model.state_dict()
|
||||
loaded_sd = loaded.state_dict()
|
||||
|
||||
assert set(orig_sd.keys()) == set(loaded_sd.keys())
|
||||
for key in orig_sd:
|
||||
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_normalizes_to_minus_one_one():
|
||||
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 224, 224, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.min() >= -1.0 - 1e-5
|
||||
assert image.max() <= 1.0 + 1e-5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_resizes_with_pad():
|
||||
"""Image preprocessor resizes non-square images to target resolution."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 480, 640, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.shape[1:3] == (224, 224)
|
||||
|
||||
|
||||
def test_task_prompt_formats_correctly():
|
||||
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: pick up the cup."
|
||||
|
||||
|
||||
def test_task_prompt_handles_string_input():
|
||||
"""Task prompt step accepts a plain string (not just a list)."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: open drawer."
|
||||
|
||||
|
||||
def test_task_prompt_raises_on_missing_task():
|
||||
"""Task prompt step raises ValueError when task key is absent."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="No task found"):
|
||||
step(transition)
|
||||
@@ -0,0 +1,514 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for lerobot-compute-returns script."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.lerobot_compute_returns import (
|
||||
IS_TERMINAL_COL,
|
||||
MC_RETURN_COL,
|
||||
ComputeReturnsConfig,
|
||||
_get_episode_success,
|
||||
compute_episode_returns,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parquet_dataset(tmp_path):
|
||||
"""Build a minimal parquet shard + info.json for testing I/O logic.
|
||||
|
||||
Mirrors the lerobot-rollout DAgger convention: ``next.success`` is False
|
||||
on all frames except the terminal frame of successful episodes.
|
||||
Even episodes are successful, odd episodes are failures.
|
||||
"""
|
||||
num_episodes = 3
|
||||
frames_per_ep = 10
|
||||
|
||||
root = tmp_path / "test_dataset"
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
meta_dir = root / "meta"
|
||||
data_dir.mkdir(parents=True)
|
||||
meta_dir.mkdir(parents=True)
|
||||
|
||||
all_rows = []
|
||||
episodes_meta = []
|
||||
global_idx = 0
|
||||
for ep in range(num_episodes):
|
||||
ep_from = global_idx
|
||||
is_successful = ep % 2 == 0
|
||||
for frame in range(frames_per_ep):
|
||||
is_last_frame = frame == frames_per_ep - 1
|
||||
all_rows.append(
|
||||
{
|
||||
"episode_index": ep,
|
||||
"frame_index": frame,
|
||||
"index": global_idx,
|
||||
"next.success": is_successful and is_last_frame,
|
||||
}
|
||||
)
|
||||
global_idx += 1
|
||||
ep_to = global_idx
|
||||
episodes_meta.append(
|
||||
{
|
||||
"episode_index": ep,
|
||||
"length": frames_per_ep,
|
||||
"dataset_from_index": ep_from,
|
||||
"dataset_to_index": ep_to,
|
||||
}
|
||||
)
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"episode_index": [r["episode_index"] for r in all_rows],
|
||||
"frame_index": [r["frame_index"] for r in all_rows],
|
||||
"index": [r["index"] for r in all_rows],
|
||||
"next.success": [r["next.success"] for r in all_rows],
|
||||
}
|
||||
)
|
||||
|
||||
parquet_path = data_dir / "episode_000000.parquet"
|
||||
pq.write_table(table, parquet_path)
|
||||
|
||||
info = {
|
||||
"codebase_version": "v3.0",
|
||||
"total_episodes": num_episodes,
|
||||
"total_frames": global_idx,
|
||||
"fps": 30,
|
||||
"features": {
|
||||
"episode_index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"next.success": {"dtype": "bool", "shape": [1], "names": None},
|
||||
},
|
||||
}
|
||||
(meta_dir / "info.json").write_text(json.dumps(info, indent=2))
|
||||
|
||||
return root, parquet_path, episodes_meta
|
||||
|
||||
|
||||
def _rewrite_shard(parquet_path: Path, episodes_meta: list[dict], config: ComputeReturnsConfig):
|
||||
"""Rewrite a single parquet shard using the core logic from compute_returns."""
|
||||
table = pq.read_table(parquet_path)
|
||||
|
||||
if not config.force and IS_TERMINAL_COL in table.column_names:
|
||||
return
|
||||
|
||||
all_is_terminal = np.zeros(len(table), dtype=bool)
|
||||
all_mc_return = np.zeros(len(table), dtype=np.float32)
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
|
||||
for ep_info in episodes_meta:
|
||||
ep_idx = ep_info["episode_index"]
|
||||
ep_len = ep_info["length"]
|
||||
|
||||
mask = np.array([v == ep_idx for v in episode_col], dtype=bool)
|
||||
local_indices = np.where(mask)[0]
|
||||
|
||||
ep_subtable = table.filter(mask)
|
||||
success = _get_episode_success(ep_subtable, config.success_key, config.default_success)
|
||||
|
||||
is_terminal, mc_return = compute_episode_returns(
|
||||
num_frames=ep_len,
|
||||
success=success,
|
||||
c_fail=config.c_fail,
|
||||
gamma=config.gamma,
|
||||
max_episode_length=config.max_episode_length or ep_len,
|
||||
)
|
||||
|
||||
all_is_terminal[local_indices] = is_terminal
|
||||
all_mc_return[local_indices] = mc_return
|
||||
|
||||
if IS_TERMINAL_COL in table.column_names:
|
||||
table = table.drop(IS_TERMINAL_COL)
|
||||
if MC_RETURN_COL in table.column_names:
|
||||
table = table.drop(MC_RETURN_COL)
|
||||
|
||||
table = table.append_column(IS_TERMINAL_COL, pa.array(all_is_terminal))
|
||||
table = table.append_column(MC_RETURN_COL, pa.array(all_mc_return))
|
||||
pq.write_table(table, parquet_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: compute_episode_returns (pure math, no I/O)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_successful_episode_terminal_reward_is_zero():
|
||||
"""Terminal MC return for a successful episode should be 0."""
|
||||
_, mc_return = compute_episode_returns(
|
||||
num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=10
|
||||
)
|
||||
assert mc_return[-1] == pytest.approx(0.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_failed_episode_terminal_reward_reflects_cfail():
|
||||
"""Terminal MC return for a failed episode should be -C_fail / H."""
|
||||
horizon = 100
|
||||
c_fail = 50.0
|
||||
_, mc_return = compute_episode_returns(
|
||||
num_frames=10, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
|
||||
)
|
||||
assert mc_return[-1] == pytest.approx(-c_fail / horizon, abs=1e-5)
|
||||
|
||||
|
||||
def test_is_terminal_only_last_frame():
|
||||
"""Only the last frame of an episode should be marked terminal."""
|
||||
is_terminal, _ = compute_episode_returns(
|
||||
num_frames=20, success=True, c_fail=50.0, gamma=1.0, max_episode_length=20
|
||||
)
|
||||
assert is_terminal[-1] == True # noqa: E712
|
||||
assert not any(is_terminal[:-1])
|
||||
|
||||
|
||||
def test_mc_return_monotonically_increases_for_success():
|
||||
"""For a successful undiscounted episode, returns should increase toward 0."""
|
||||
_, mc_return = compute_episode_returns(
|
||||
num_frames=50, success=True, c_fail=50.0, gamma=1.0, max_episode_length=50
|
||||
)
|
||||
for i in range(len(mc_return) - 1):
|
||||
assert mc_return[i] <= mc_return[i + 1]
|
||||
|
||||
|
||||
def test_mc_return_bounded_negative_to_zero():
|
||||
"""MC returns for successful episodes should be in (-1, 0]."""
|
||||
_, mc_return = compute_episode_returns(
|
||||
num_frames=100, success=True, c_fail=50.0, gamma=1.0, max_episode_length=100
|
||||
)
|
||||
assert mc_return[-1] == pytest.approx(0.0, abs=1e-6)
|
||||
assert all(v <= 0.0 for v in mc_return)
|
||||
assert all(v >= -1.0 - 1e-6 for v in mc_return)
|
||||
|
||||
|
||||
def test_first_frame_return_success():
|
||||
"""First frame return for successful episode equals -(N-1)/H."""
|
||||
num_frames = 10
|
||||
horizon = 10
|
||||
_, mc_return = compute_episode_returns(
|
||||
num_frames=num_frames, success=True, c_fail=50.0, gamma=1.0, max_episode_length=horizon
|
||||
)
|
||||
expected = -(num_frames - 1) / horizon
|
||||
assert mc_return[0] == pytest.approx(expected, abs=1e-5)
|
||||
|
||||
|
||||
def test_first_frame_return_failure():
|
||||
"""First frame return for failed episode includes the failure penalty."""
|
||||
num_frames = 10
|
||||
horizon = 100
|
||||
c_fail = 50.0
|
||||
_, mc_return = compute_episode_returns(
|
||||
num_frames=num_frames, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
|
||||
)
|
||||
expected = (-(num_frames - 1) / horizon) + (-c_fail / horizon)
|
||||
assert mc_return[0] == pytest.approx(expected, abs=1e-5)
|
||||
|
||||
|
||||
def test_discount_factor_less_than_one():
|
||||
"""Discount factor < 1 should make earlier frames have smaller magnitude."""
|
||||
_, mc_undiscounted = compute_episode_returns(
|
||||
num_frames=20, success=True, c_fail=50.0, gamma=1.0, max_episode_length=20
|
||||
)
|
||||
_, mc_discounted = compute_episode_returns(
|
||||
num_frames=20, success=True, c_fail=50.0, gamma=0.99, max_episode_length=20
|
||||
)
|
||||
assert abs(mc_discounted[0]) < abs(mc_undiscounted[0])
|
||||
|
||||
|
||||
def test_single_frame_episode_success():
|
||||
"""Single-frame successful episode: return should be 0."""
|
||||
is_terminal, mc_return = compute_episode_returns(
|
||||
num_frames=1, success=True, c_fail=50.0, gamma=1.0, max_episode_length=1
|
||||
)
|
||||
assert mc_return[0] == pytest.approx(0.0, abs=1e-6)
|
||||
assert is_terminal[0] == True # noqa: E712
|
||||
|
||||
|
||||
def test_single_frame_episode_failure():
|
||||
"""Single-frame failed episode: return should be -C_fail/H."""
|
||||
horizon = 100
|
||||
c_fail = 50.0
|
||||
is_terminal, mc_return = compute_episode_returns(
|
||||
num_frames=1, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
|
||||
)
|
||||
assert mc_return[0] == pytest.approx(-c_fail / horizon, abs=1e-5)
|
||||
assert is_terminal[0] == True # noqa: E712
|
||||
|
||||
|
||||
def test_horizon_normalization_scales_returns():
|
||||
"""Larger horizon should scale down the per-step penalty."""
|
||||
_, mc_small_h = compute_episode_returns(
|
||||
num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=10
|
||||
)
|
||||
_, mc_large_h = compute_episode_returns(
|
||||
num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=100
|
||||
)
|
||||
assert abs(mc_large_h[0]) < abs(mc_small_h[0])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _get_episode_success (in-memory PyArrow tables)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_default_success_overrides_column():
|
||||
"""default_success should override any column value."""
|
||||
table = pa.table({"next.success": [True, True, True]})
|
||||
assert _get_episode_success(table, "next.success", default_success=False) is False
|
||||
|
||||
|
||||
def test_reads_bool_column():
|
||||
"""Should detect success via any() reduction over the column."""
|
||||
table_success = pa.table({"next.success": [False, False, True]})
|
||||
table_fail = pa.table({"next.success": [False, False, False]})
|
||||
assert _get_episode_success(table_success, "next.success", None) is True
|
||||
assert _get_episode_success(table_fail, "next.success", None) is False
|
||||
|
||||
|
||||
def test_reads_int_column():
|
||||
"""Should interpret integer success column (0/1) as bool via any()."""
|
||||
table = pa.table({"task_success": [0, 0, 1]})
|
||||
assert _get_episode_success(table, "task_success", None) is True
|
||||
|
||||
|
||||
def test_all_zeros_means_failure():
|
||||
"""An episode with all-zero success values is a failure."""
|
||||
table = pa.table({"next.success": [0, 0, 0]})
|
||||
assert _get_episode_success(table, "next.success", None) is False
|
||||
|
||||
|
||||
def test_missing_column_defaults_to_true():
|
||||
"""When success column is missing, assume success (demo data)."""
|
||||
table = pa.table({"frame_index": [0, 1, 2]})
|
||||
assert _get_episode_success(table, "next.success", None) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: parquet rewriting (integration, writes to disk)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_writes_columns_to_parquet(parquet_dataset):
|
||||
"""The rewrite logic should add is_terminal and mc_return columns."""
|
||||
root, parquet_path, episodes_meta = parquet_dataset
|
||||
|
||||
table_before = pq.read_table(parquet_path)
|
||||
assert IS_TERMINAL_COL not in table_before.column_names
|
||||
assert MC_RETURN_COL not in table_before.column_names
|
||||
|
||||
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
|
||||
_rewrite_shard(parquet_path, episodes_meta, config)
|
||||
|
||||
table_after = pq.read_table(parquet_path)
|
||||
assert IS_TERMINAL_COL in table_after.column_names
|
||||
assert MC_RETURN_COL in table_after.column_names
|
||||
|
||||
|
||||
def test_terminal_frames_correct(parquet_dataset):
|
||||
"""Only the last frame of each episode should be terminal."""
|
||||
root, parquet_path, episodes_meta = parquet_dataset
|
||||
|
||||
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
|
||||
_rewrite_shard(parquet_path, episodes_meta, config)
|
||||
|
||||
table = pq.read_table(parquet_path)
|
||||
is_terminal = table.column(IS_TERMINAL_COL).to_pylist()
|
||||
terminal_indices = [i for i, v in enumerate(is_terminal) if v]
|
||||
assert terminal_indices == [9, 19, 29]
|
||||
|
||||
|
||||
def test_success_episodes_return_zero_at_terminal(tmp_path):
|
||||
"""Successful episodes (ep 0) should have mc_return=0 at terminal."""
|
||||
num_episodes = 2
|
||||
frames_per_ep = 5
|
||||
|
||||
root = tmp_path / "test_dataset"
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
meta_dir = root / "meta"
|
||||
data_dir.mkdir(parents=True)
|
||||
meta_dir.mkdir(parents=True)
|
||||
|
||||
all_rows = []
|
||||
episodes_meta = []
|
||||
global_idx = 0
|
||||
for ep in range(num_episodes):
|
||||
ep_from = global_idx
|
||||
is_successful = ep % 2 == 0
|
||||
for frame in range(frames_per_ep):
|
||||
is_last_frame = frame == frames_per_ep - 1
|
||||
all_rows.append(
|
||||
{
|
||||
"episode_index": ep,
|
||||
"frame_index": frame,
|
||||
"index": global_idx,
|
||||
"next.success": is_successful and is_last_frame,
|
||||
}
|
||||
)
|
||||
global_idx += 1
|
||||
episodes_meta.append(
|
||||
{
|
||||
"episode_index": ep,
|
||||
"length": frames_per_ep,
|
||||
"dataset_from_index": ep_from,
|
||||
"dataset_to_index": global_idx,
|
||||
}
|
||||
)
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"episode_index": [r["episode_index"] for r in all_rows],
|
||||
"frame_index": [r["frame_index"] for r in all_rows],
|
||||
"index": [r["index"] for r in all_rows],
|
||||
"next.success": [r["next.success"] for r in all_rows],
|
||||
}
|
||||
)
|
||||
parquet_path = data_dir / "episode_000000.parquet"
|
||||
pq.write_table(table, parquet_path)
|
||||
|
||||
info = {
|
||||
"codebase_version": "v3.0",
|
||||
"total_episodes": num_episodes,
|
||||
"total_frames": global_idx,
|
||||
"fps": 30,
|
||||
"features": {
|
||||
"episode_index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"next.success": {"dtype": "bool", "shape": [1], "names": None},
|
||||
},
|
||||
}
|
||||
(meta_dir / "info.json").write_text(json.dumps(info, indent=2))
|
||||
|
||||
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=5, force=True)
|
||||
_rewrite_shard(parquet_path, episodes_meta, config)
|
||||
|
||||
table = pq.read_table(parquet_path)
|
||||
mc_return = table.column(MC_RETURN_COL).to_pylist()
|
||||
assert mc_return[4] == pytest.approx(0.0, abs=1e-5)
|
||||
|
||||
|
||||
def test_failed_episodes_have_negative_terminal(tmp_path):
|
||||
"""Failed episodes (ep 1) should have mc_return < 0 at terminal."""
|
||||
num_episodes = 2
|
||||
frames_per_ep = 5
|
||||
|
||||
root = tmp_path / "test_dataset"
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
meta_dir = root / "meta"
|
||||
data_dir.mkdir(parents=True)
|
||||
meta_dir.mkdir(parents=True)
|
||||
|
||||
all_rows = []
|
||||
episodes_meta = []
|
||||
global_idx = 0
|
||||
for ep in range(num_episodes):
|
||||
ep_from = global_idx
|
||||
is_successful = ep % 2 == 0
|
||||
for frame in range(frames_per_ep):
|
||||
is_last_frame = frame == frames_per_ep - 1
|
||||
all_rows.append(
|
||||
{
|
||||
"episode_index": ep,
|
||||
"frame_index": frame,
|
||||
"index": global_idx,
|
||||
"next.success": is_successful and is_last_frame,
|
||||
}
|
||||
)
|
||||
global_idx += 1
|
||||
episodes_meta.append(
|
||||
{
|
||||
"episode_index": ep,
|
||||
"length": frames_per_ep,
|
||||
"dataset_from_index": ep_from,
|
||||
"dataset_to_index": global_idx,
|
||||
}
|
||||
)
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"episode_index": [r["episode_index"] for r in all_rows],
|
||||
"frame_index": [r["frame_index"] for r in all_rows],
|
||||
"index": [r["index"] for r in all_rows],
|
||||
"next.success": [r["next.success"] for r in all_rows],
|
||||
}
|
||||
)
|
||||
parquet_path = data_dir / "episode_000000.parquet"
|
||||
pq.write_table(table, parquet_path)
|
||||
|
||||
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=5, c_fail=50.0, force=True)
|
||||
_rewrite_shard(parquet_path, episodes_meta, config)
|
||||
|
||||
table = pq.read_table(parquet_path)
|
||||
mc_return = table.column(MC_RETURN_COL).to_pylist()
|
||||
assert mc_return[9] < 0.0
|
||||
|
||||
|
||||
def test_idempotent_with_force_flag(parquet_dataset):
|
||||
"""Running twice with force should produce identical results."""
|
||||
root, parquet_path, episodes_meta = parquet_dataset
|
||||
|
||||
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
|
||||
_rewrite_shard(parquet_path, episodes_meta, config)
|
||||
table1 = pq.read_table(parquet_path)
|
||||
mc1 = table1.column(MC_RETURN_COL).to_pylist()
|
||||
|
||||
_rewrite_shard(parquet_path, episodes_meta, config)
|
||||
table2 = pq.read_table(parquet_path)
|
||||
mc2 = table2.column(MC_RETURN_COL).to_pylist()
|
||||
|
||||
assert mc1 == mc2
|
||||
|
||||
|
||||
def test_skips_if_columns_exist_without_force(parquet_dataset):
|
||||
"""Without force, existing columns should not be overwritten."""
|
||||
root, parquet_path, episodes_meta = parquet_dataset
|
||||
|
||||
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
|
||||
_rewrite_shard(parquet_path, episodes_meta, config)
|
||||
|
||||
table = pq.read_table(parquet_path)
|
||||
original_mc = table.column(MC_RETURN_COL).to_pylist()
|
||||
|
||||
config_no_force = ComputeReturnsConfig(success_key="next.success", max_episode_length=20, force=False)
|
||||
_rewrite_shard(parquet_path, episodes_meta, config_no_force)
|
||||
|
||||
table2 = pq.read_table(parquet_path)
|
||||
assert table2.column(MC_RETURN_COL).to_pylist() == original_mc
|
||||
|
||||
|
||||
def test_updates_info_json(parquet_dataset):
|
||||
"""info.json should be updated with is_terminal and mc_return features."""
|
||||
from lerobot.scripts.lerobot_compute_returns import _update_info_json
|
||||
|
||||
root, parquet_path, episodes_meta = parquet_dataset
|
||||
|
||||
_update_info_json(root, None)
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
info = json.loads(info_path.read_text())
|
||||
assert IS_TERMINAL_COL in info["features"]
|
||||
assert MC_RETURN_COL in info["features"]
|
||||
assert info["features"][IS_TERMINAL_COL]["dtype"] == "bool"
|
||||
assert info["features"][MC_RETURN_COL]["dtype"] == "float32"
|
||||
@@ -338,6 +338,103 @@ def test_dagger_events_reset():
|
||||
assert not events.upload_requested.is_set()
|
||||
|
||||
|
||||
def test_dagger_mark_success():
|
||||
"""mark_success sets the episode label to True."""
|
||||
from lerobot.rollout.strategies import DAggerEvents
|
||||
|
||||
events = DAggerEvents()
|
||||
assert events.consume_episode_success() is None
|
||||
|
||||
events.mark_success()
|
||||
assert events.consume_episode_success() is True
|
||||
# Consuming clears the label
|
||||
assert events.consume_episode_success() is None
|
||||
|
||||
|
||||
def test_dagger_mark_failure():
|
||||
"""mark_failure sets the episode label to False."""
|
||||
from lerobot.rollout.strategies import DAggerEvents
|
||||
|
||||
events = DAggerEvents()
|
||||
events.mark_failure()
|
||||
assert events.consume_episode_success() is False
|
||||
|
||||
|
||||
def test_dagger_success_overrides_failure():
|
||||
"""Last label wins — success after failure overrides."""
|
||||
from lerobot.rollout.strategies import DAggerEvents
|
||||
|
||||
events = DAggerEvents()
|
||||
events.mark_failure()
|
||||
events.mark_success()
|
||||
assert events.consume_episode_success() is True
|
||||
|
||||
|
||||
def test_dagger_reset_clears_success_label():
|
||||
"""reset() clears any pending episode success label."""
|
||||
from lerobot.rollout.strategies import DAggerEvents
|
||||
|
||||
events = DAggerEvents()
|
||||
events.mark_success()
|
||||
events.reset()
|
||||
assert events.consume_episode_success() is None
|
||||
|
||||
|
||||
def test_stamp_episode_success_labels_terminal_frame():
|
||||
"""_stamp_episode_success sets last frame's next.success to True."""
|
||||
import numpy as np
|
||||
|
||||
from lerobot.rollout.strategies.dagger import DAggerStrategy
|
||||
|
||||
strategy = DAggerStrategy.__new__(DAggerStrategy)
|
||||
strategy.config = MagicMock()
|
||||
|
||||
from lerobot.rollout.strategies import DAggerEvents
|
||||
|
||||
strategy._events = DAggerEvents()
|
||||
strategy._events.mark_success()
|
||||
|
||||
dataset = MagicMock()
|
||||
dataset.writer.episode_buffer = {
|
||||
"next.success": [
|
||||
np.array([False], dtype=bool),
|
||||
np.array([False], dtype=bool),
|
||||
np.array([False], dtype=bool),
|
||||
],
|
||||
}
|
||||
|
||||
strategy._stamp_episode_success(dataset)
|
||||
|
||||
assert dataset.writer.episode_buffer["next.success"][-1].item() is True
|
||||
assert dataset.writer.episode_buffer["next.success"][0].item() is False
|
||||
|
||||
|
||||
def test_stamp_episode_success_no_label_stays_false():
|
||||
"""Without a label, all frames remain False."""
|
||||
import numpy as np
|
||||
|
||||
from lerobot.rollout.strategies.dagger import DAggerStrategy
|
||||
|
||||
strategy = DAggerStrategy.__new__(DAggerStrategy)
|
||||
strategy.config = MagicMock()
|
||||
|
||||
from lerobot.rollout.strategies import DAggerEvents
|
||||
|
||||
strategy._events = DAggerEvents()
|
||||
|
||||
dataset = MagicMock()
|
||||
dataset.writer.episode_buffer = {
|
||||
"next.success": [
|
||||
np.array([False], dtype=bool),
|
||||
np.array([False], dtype=bool),
|
||||
],
|
||||
}
|
||||
|
||||
strategy._stamp_episode_success(dataset)
|
||||
|
||||
assert all(v.item() is False for v in dataset.writer.episode_buffer["next.success"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user