feat(recap): add advantage scoring annotation module

Implement the RECAP advantage scoring module as a new phase in
lerobot-annotate. It uses a frozen distributional value function (VF) to
compute per-frame advantages, binarizes them into positive/negative
indicators with a per-task threshold, and writes style=advantage persistent
rows for policy conditioning. VF inference is skipped on intervention frames
as an optimization.
This commit is contained in:
Khalil Meftah
2026-06-22 14:01:58 +02:00
parent e5c94c732f
commit ea908c0672
11 changed files with 632 additions and 7 deletions
@@ -169,6 +169,43 @@ class ExecutorConfig:
episode_parallelism: int = 16
@dataclass
class AdvantageConfig:
"""``advantage`` module: RECAP advantage scoring via frozen value function."""
enabled: bool = True
# Path or Hub repo ID of the trained distributional value function checkpoint.
value_function_path: str = ""
# Device to run the value function on.
device: str = "cuda"
# N-step lookahead for advantage estimation.
# None = MC (N=T): A_t = R_t - V(s_t), using mc_return from dataset.
# 50 = fine-tuning mode: A_t = Σ r_{t:t+N} + V(s_{t+N}) - V(s_t).
n_step: int | None = None
# Per-task percentile for binarization threshold ε_.
# Actions with advantage > ε_ get I_t = True (positive).
threshold_percentile: float = 0.3
# Fraction of frames to randomly omit advantage labels (enables CFG).
dropout_rate: float = 0.3
# Force I_t = True for frames marked as human interventions.
force_positive_on_intervention: bool = True
# Column name in dataset for intervention flag.
intervention_key: str = "intervention"
# Column name for pre-computed MC returns (from lerobot-compute-returns).
mc_return_key: str = "mc_return"
# Batch size for value function inference.
batch_size: int = 32
@dataclass
class AnnotationPipelineConfig:
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
@@ -190,6 +227,7 @@ class AnnotationPipelineConfig:
plan: PlanConfig = field(default_factory=PlanConfig)
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
vqa: VqaConfig = field(default_factory=VqaConfig)
advantage: AdvantageConfig = field(default_factory=AdvantageConfig)
vlm: VlmConfig = field(default_factory=VlmConfig)
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
@@ -15,20 +15,24 @@
# limitations under the License.
"""In-process executor that runs the annotation phases.
The executor runs **six phases** in dependency order:
The executor runs **seven phases** in dependency order:
phase 1: ``plan`` module (plan + subtasks + memory)
phase 2: ``interjections`` module (interjections + speech)
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
interjection timestamp produced by phase 2
phase 4: ``vqa`` module (VQA)
phase 5: validator
phase 6: writer
phase 5: ``advantage`` module (advantage scoring via frozen VF)
phase 6: validator
phase 7: writer
Phase 3 is why the ``plan`` module must be re-entered after the
``interjections`` module — to refresh ``plan`` rows at interjection
timestamps.
Phase 5 (advantage) does not depend on the VLM modules; it uses a frozen
distributional value function to compute per-frame advantage indicators.
Distributed execution is provided by Hugging Face Jobs (see
``examples/annotations/run_hf_job.py``); the runner inside the job
invokes ``lerobot-annotate`` which uses this in-process executor.
@@ -74,7 +78,7 @@ class PipelineRunSummary:
@dataclass
class Executor:
"""Run all six phases over a dataset root in-process.
"""Run all seven phases over a dataset root in-process.
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
(a thread pool); cluster-level concurrency comes from running this
@@ -86,6 +90,7 @@ class Executor:
plan: Any # PlanSubtasksMemoryModule
interjections: Any # InterjectionsAndSpeechModule
vqa: Any # GeneralVqaModule
advantage: Any # AdvantageModule
writer: LanguageColumnsWriter
validator: StagingValidator
@@ -112,6 +117,8 @@ class Executor:
phases.append(self._run_plan_update_phase(records, staging_dir))
# Phase 4: ``vqa`` module (VQA)
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
# Phase 5: ``advantage`` module (advantage scoring via frozen VF)
phases.append(self._run_module_phase("advantage", records, staging_dir, self.advantage))
print("[annotate] running validator...", flush=True)
report = self.validator.validate(records, staging_dir)
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .advantage import AdvantageModule
from .general_vqa import GeneralVqaModule
from .interjections_and_speech import InterjectionsAndSpeechModule
from .plan_subtasks_memory import PlanSubtasksMemoryModule
__all__ = [
"AdvantageModule",
"GeneralVqaModule",
"InterjectionsAndSpeechModule",
"PlanSubtasksMemoryModule",
@@ -0,0 +1,263 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Advantage scoring module for RECAP.
Computes per-frame advantage values using a frozen distributional value function,
binarizes them into improvement indicators (I_t), and emits ``style="advantage"``
persistent rows for policy conditioning.
Paper reference: pi*0.6, Section IV-B and Appendix F.
"""
from __future__ import annotations

import logging
import threading
from dataclasses import dataclass, field
from typing import Any

import numpy as np
import torch

from ..config import AdvantageConfig
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
logger = logging.getLogger(__name__)
@dataclass
class AdvantageModule:
    """Compute advantage indicators and emit persistent annotation rows.

    The module lazily loads a frozen distributional value function and scores
    the frames of an episode. Advantages are binarized into
    ``positive``/``negative`` indicators and written as ``style="advantage"``
    rows into the staging area.

    Requires the ``mc_return`` column (name configurable through
    ``config.mc_return_key``) in the dataset (from lerobot-compute-returns).

    NOTE(review): the binarization threshold is computed from the advantages of
    the episode being scored (see ``_compute_threshold``), i.e. it is a
    per-episode percentile. ``_threshold`` is never read or written by this
    class.
    """

    config: AdvantageConfig
    _model: Any = field(default=None, init=False, repr=False)
    _preprocessor: Any = field(default=None, init=False, repr=False)
    _threshold: float | None = field(default=None, init=False, repr=False)
    # Serializes the lazy model load: episodes may be scored from several
    # worker threads that all hit ``_ensure_model_loaded`` at once.
    _load_lock: Any = field(default_factory=threading.Lock, init=False, repr=False, compare=False)

    @property
    def enabled(self) -> bool:
        """Whether the module is switched on in its config."""
        return self.config.enabled

    def _ensure_model_loaded(self) -> None:
        """Lazy-load the frozen value function on first use.

        The load is guarded by a lock and ``_model`` is assigned last, so a
        concurrent caller that observes a non-``None`` ``_model`` also sees a
        ready ``_preprocessor``, and the checkpoint is loaded only once.
        """
        if self._model is not None:
            return

        with self._load_lock:
            if self._model is not None:
                # Another thread finished loading while we waited.
                return

            from lerobot.rewards import (
                make_reward_model,
                make_reward_model_config,
                make_reward_pre_post_processors,
            )

            cfg = make_reward_model_config(
                "distributional_value_function",
                pretrained_path=self.config.value_function_path,
                device=self.config.device,
            )
            model = make_reward_model(cfg)
            model.eval()
            for p in model.parameters():
                p.requires_grad_(False)
            preprocessor, _ = make_reward_pre_post_processors(cfg)

            self._preprocessor = preprocessor
            # Publish the model last: it is the "loaded" flag checked above.
            self._model = model

        logger.info("Loaded frozen VF from %s on %s", self.config.value_function_path, self.config.device)

    def compute_advantages_for_episode(self, record: EpisodeRecord) -> tuple[np.ndarray, np.ndarray]:
        """Compute raw advantage values for all frames in an episode.

        Returns:
            (advantages, intervention_mask) both shape [num_frames].
            advantages[t] = A_t, intervention_mask[t] = True if frame is intervention.
            In MC mode with ``force_positive_on_intervention`` the value function
            is not evaluated on intervention frames, so ``advantages[t]`` for
            those frames is computed with V(s_t) = 0 and should be ignored.

        Raises:
            KeyError: if the MC-return column is missing from the episode frames.
        """
        self._ensure_model_loaded()

        df = record.frames_df()

        mc_return_key = self.config.mc_return_key
        if mc_return_key not in df.columns:
            raise KeyError(
                f"Column '{mc_return_key}' not found in episode {record.episode_index}. "
                "Run lerobot-compute-returns first."
            )
        mc_returns = df[mc_return_key].values.astype(np.float32)

        if self.config.intervention_key in df.columns:
            intervention_mask = df[self.config.intervention_key].values.astype(bool)
        else:
            intervention_mask = np.zeros(len(df), dtype=bool)

        # Skipping VF inference on intervention frames is only sound in MC mode,
        # where V(s_t) feeds A_t alone and A_t is overridden for those frames.
        # In n-step mode V(s_{t+N}) at an intervention frame is still needed as
        # the bootstrap value of the earlier frame t, so nothing may be skipped.
        can_skip = self.config.force_positive_on_intervention and self.config.n_step is None
        skip_mask = intervention_mask if can_skip else None

        values = self._compute_values(record, skip_mask=skip_mask)

        if self.config.n_step is None:
            advantages = mc_returns - values
        else:
            advantages = self._compute_n_step_advantages(mc_returns, values, record, n=self.config.n_step)

        return advantages, intervention_mask

    def _compute_values(self, record: EpisodeRecord, skip_mask: np.ndarray | None = None) -> np.ndarray:
        """Run frozen VF over all frames to get V(s_t) predictions.

        Args:
            record: Episode data.
            skip_mask: Optional boolean mask [num_frames]. Frames where True are
                skipped (left as 0.0) to avoid unnecessary inference.

        Returns:
            float32 array [num_frames]. All zeros when the frames table has no
            ``observation.images.*`` column (a warning is logged).
        """
        df = record.frames_df()
        num_frames = len(df)
        values = np.zeros(num_frames, dtype=np.float32)

        image_key = self._resolve_image_key(df)
        if image_key is None:
            logger.warning("No image key found for episode %d; returning zero values.", record.episode_index)
            return values

        # Determine which frame indices actually need inference
        infer_indices = np.where(~skip_mask)[0] if skip_mask is not None else np.arange(num_frames)
        if len(infer_indices) == 0:
            return values

        task_text = record.episode_task
        # Look the column up once instead of materializing a full row per frame.
        image_column = df[image_key]
        warned_unsupported = False
        step = self.config.batch_size

        for batch_start in range(0, len(infer_indices), step):
            batch_indices = infer_indices[batch_start : batch_start + step]

            batch_images = []
            for idx in batch_indices:
                img_val = image_column.iloc[idx]
                if isinstance(img_val, np.ndarray):
                    img_tensor = torch.from_numpy(img_val).float()
                elif isinstance(img_val, torch.Tensor):
                    img_tensor = img_val.float()
                else:
                    # Best-effort placeholder; make it visible because the
                    # resulting values are not meaningful for these frames.
                    if not warned_unsupported:
                        logger.warning(
                            "Unsupported image value type %s in column '%s' (episode %d); "
                            "substituting zero images.",
                            type(img_val).__name__,
                            image_key,
                            record.episode_index,
                        )
                        warned_unsupported = True
                    img_tensor = torch.zeros(3, 224, 224)
                batch_images.append(img_tensor)

            batch_images_tensor = torch.stack(batch_images)
            batch_size = batch_images_tensor.shape[0]

            raw_batch = {
                image_key: batch_images_tensor,
                "task": [task_text] * batch_size,
            }

            processed = self._preprocessor(raw_batch)
            with torch.no_grad():
                v_values = self._model.compute_reward(processed)

            # Assumes one scalar value per sample — TODO confirm against the
            # reward model. Flatten so [B] and [B, 1] outputs both fit, and cast
            # to float32 so reduced-precision outputs convert to numpy.
            values[batch_indices] = v_values.detach().float().cpu().numpy().reshape(-1)

        return values

    def _compute_n_step_advantages(
        self, mc_returns: np.ndarray, values: np.ndarray, record: EpisodeRecord, n: int
    ) -> np.ndarray:
        """Compute N-step advantage: A_t = Σ r_{t:t+N-1} + V(s_{t+N}) - V(s_t).

        The partial reward sum is taken as ``mc_returns[t] - mc_returns[t + n]``.
        When t+N reaches or exceeds the episode length, falls back to MC
        (``mc_returns[t] - values[t]``).

        Args:
            mc_returns: Per-frame MC returns, shape [num_frames].
            values: Per-frame V(s_t), shape [num_frames].
            record: Unused; kept for signature compatibility.
            n: Lookahead horizon in frames, must be >= 1.

        Raises:
            ValueError: if ``n`` is smaller than 1.
        """
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n!r}")

        num_frames = len(values)
        # Start from the MC fallback, then overwrite frames with a full lookahead.
        advantages = (mc_returns - values).astype(np.float32)
        if n < num_frames:
            head = num_frames - n
            n_step_return = mc_returns[:head] - mc_returns[n:]
            advantages[:head] = n_step_return + values[n:] - values[:head]
        return advantages

    def _resolve_image_key(self, df) -> str | None:
        """Find the first image observation key in the dataframe columns."""
        return next((col for col in df.columns if col.startswith("observation.images.")), None)

    def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
        """Score one episode and write advantage rows to staging.

        Does nothing (beyond a warning) when no ``value_function_path`` is
        configured. Each frame is dropped with probability
        ``config.dropout_rate``; the remaining frames get one row labelled
        ``positive`` or ``negative``.
        """
        if not self.config.value_function_path:
            logger.warning("No value_function_path configured; skipping advantage scoring.")
            return

        advantages, intervention_mask = self.compute_advantages_for_episode(record)
        num_frames = len(advantages)

        threshold = self._compute_threshold(advantages, intervention_mask)

        # The dropout stream is seeded from the episode index only.
        rng = np.random.default_rng(seed=hash((record.episode_index, 42)) & 0xFFFFFFFF)
        force_positive = self.config.force_positive_on_intervention

        rows: list[dict[str, Any]] = []
        num_positive = 0
        for t in range(num_frames):
            if rng.random() < self.config.dropout_rate:
                continue

            if (force_positive and intervention_mask[t]) or advantages[t] > threshold:
                indicator = "positive"
                num_positive += 1
            else:
                indicator = "negative"

            timestamp = float(record.frame_timestamps[t]) if t < len(record.frame_timestamps) else 0.0
            rows.append(
                {
                    "role": "user",
                    "content": indicator,
                    "style": "advantage",
                    "timestamp": timestamp,
                    "camera": None,
                    "tool_calls": None,
                }
            )

        staging.write("advantage", rows)
        logger.debug(
            "Episode %d: %d/%d frames scored (threshold=%.4f, %d positive, %d negative)",
            record.episode_index,
            len(rows),
            num_frames,
            threshold,
            num_positive,
            len(rows) - num_positive,
        )

    def _compute_threshold(self, advantages: np.ndarray, intervention_mask: np.ndarray) -> float:
        """Compute the binarization threshold as the configured percentile of advantages.

        Intervention frames are excluded from the percentile. Returns 0.0 when
        no non-intervention frame is left.
        """
        non_intervention = advantages[~intervention_mask]
        if len(non_intervention) == 0:
            return 0.0
        return float(np.percentile(non_intervention, self.config.threshold_percentile * 100))
@@ -39,6 +39,7 @@ _MODULES: tuple[ModuleName, ...] = (
"plan",
"interjections",
"vqa",
"advantage",
)
+1
View File
@@ -32,6 +32,7 @@ DEFAULT_BINDINGS = {
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
"advantage": "active_at(t, style=advantage)",
}
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
+2 -2
View File
@@ -43,10 +43,10 @@ CORE_STYLES = {
# validation. Empty by default — populate from a downstream module that
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
# the new style's column.
EXTENDED_STYLES: set[str] = set()
EXTENDED_STYLES: set[str] = {"advantage"}
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug", "advantage"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
+3
View File
@@ -34,6 +34,7 @@ from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConf
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
from lerobot.annotations.steerable_pipeline.modules import (
AdvantageModule,
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
@@ -86,6 +87,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
)
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
advantage = AdvantageModule(config=cfg.advantage)
writer = LanguageColumnsWriter()
validator = StagingValidator(
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
@@ -96,6 +98,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
plan=plan,
interjections=interjections,
vqa=vqa,
advantage=advantage,
writer=writer,
validator=validator,
)
+3 -1
View File
@@ -28,9 +28,10 @@ import sys
import tempfile
from pathlib import Path
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.config import AdvantageConfig, AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.modules import (
AdvantageModule,
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
@@ -85,6 +86,7 @@ def main() -> int:
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
advantage=AdvantageModule(config=AdvantageConfig(enabled=False)),
writer=LanguageColumnsWriter(),
validator=StagingValidator(),
)
+305
View File
@@ -0,0 +1,305 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the advantage scoring annotation module."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.annotations.steerable_pipeline.config import AdvantageConfig
from lerobot.annotations.steerable_pipeline.modules.advantage import AdvantageModule
from lerobot.annotations.steerable_pipeline.reader import EpisodeRecord
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
def _make_record(
    episode_index: int = 0,
    num_frames: int = 20,
    task: str = "pick up the cup",
    mc_returns: np.ndarray | None = None,
    intervention_mask: np.ndarray | None = None,
    fps: float = 10.0,
) -> EpisodeRecord:
    """Create an ``EpisodeRecord`` whose frames table is an in-memory DataFrame.

    The frames table always carries ``episode_index``, ``frame_index``,
    ``timestamp`` and ``mc_return``; an ``intervention`` column is added only
    when ``intervention_mask`` is given. The DataFrame is stored on the
    record's ``_frames_df_cache`` attribute so no parquet file has to exist.
    """
    import pandas as pd

    indices = list(range(num_frames))
    timestamps = tuple(round(i / fps, 6) for i in indices)

    # Default returns ramp linearly from -0.9 up to -0.1 across the episode.
    returns = (
        np.linspace(-0.9, -0.1, num_frames).astype(np.float32) if mc_returns is None else mc_returns
    )

    columns = {
        "episode_index": [episode_index] * num_frames,
        "frame_index": indices,
        "timestamp": list(timestamps),
        "mc_return": returns,
    }
    if intervention_mask is not None:
        columns["intervention"] = intervention_mask.astype(bool)

    record = EpisodeRecord(
        episode_index=episode_index,
        episode_task=task,
        frame_timestamps=timestamps,
        frame_indices=tuple(indices),
        data_path=Path("/fake/data.parquet"),
        row_offset=0,
        row_count=num_frames,
    )
    record._frames_df_cache = pd.DataFrame(columns)
    return record
@pytest.fixture
def staging(tmp_path: Path) -> EpisodeStaging:
    """Provide an ``EpisodeStaging`` for episode 0 rooted in pytest's temp dir."""
    episode_staging = EpisodeStaging(tmp_path, episode_index=0)
    return episode_staging
def test_advantage_module_disabled():
    """A config with ``enabled=False`` is reflected by the module's ``enabled`` property."""
    module = AdvantageModule(config=AdvantageConfig(enabled=False))
    assert not module.enabled
def test_advantage_module_enabled_by_default():
    """A module built from a default config reports itself as enabled."""
    module = AdvantageModule(config=AdvantageConfig())
    assert module.enabled
def test_run_episode_skips_without_value_function_path(staging: EpisodeStaging):
    """With an empty ``value_function_path`` the module returns without staging rows."""
    module = AdvantageModule(config=AdvantageConfig(value_function_path=""))

    module.run_episode(_make_record(), staging)

    assert staging.read("advantage") == []
def test_binarization_with_mock_values(staging: EpisodeStaging):
    """Frames strictly above the percentile threshold are positive, the rest negative."""
    mc_returns = np.array([-0.5, -0.4, -0.3, -0.2, -0.1, -0.5, -0.6, -0.7, -0.8, -0.9], dtype=np.float32)
    num_frames = 10
    # Constant value prediction of -0.4 for every frame.
    mock_values = np.array([-0.4] * num_frames, dtype=np.float32)

    module = AdvantageModule(
        config=AdvantageConfig(
            value_function_path="/fake/vf",
            dropout_rate=0.0,
            threshold_percentile=0.5,
        )
    )
    record = _make_record(num_frames=num_frames, mc_returns=mc_returns)

    with (
        patch.object(module, "_ensure_model_loaded"),
        patch.object(module, "_compute_values", return_value=mock_values),
    ):
        module.run_episode(record, staging)

    rows = staging.read("advantage")
    assert len(rows) == num_frames

    # A_t = mc_returns - values
    #     = [-0.1, 0.0, 0.1, 0.2, 0.3, -0.1, -0.2, -0.3, -0.4, -0.5]
    # The 50th percentile is -0.1, so:
    #   positive (A_t >  -0.1): indices 1, 2, 3, 4
    #   negative (A_t <= -0.1): indices 0, 5, 6, 7, 8, 9
    labels = [row["content"] for row in rows]
    assert labels.count("positive") == 4
    assert labels.count("negative") == 6
def test_intervention_frames_forced_positive(staging: EpisodeStaging):
    """An intervention frame is labelled positive even when its advantage is low."""
    num_frames = 5
    # Every frame has the same advantage: -0.9 - (-0.1) = -0.8.
    mc_returns = np.full(num_frames, -0.9, dtype=np.float32)
    mock_values = np.full(num_frames, -0.1, dtype=np.float32)
    intervention = np.array([False, False, True, False, False])

    module = AdvantageModule(
        config=AdvantageConfig(
            value_function_path="/fake/vf",
            dropout_rate=0.0,
            force_positive_on_intervention=True,
        )
    )
    record = _make_record(num_frames=num_frames, mc_returns=mc_returns, intervention_mask=intervention)

    with (
        patch.object(module, "_ensure_model_loaded"),
        patch.object(module, "_compute_values", return_value=mock_values),
    ):
        module.run_episode(record, staging)

    # Dropout is disabled, so row index 2 is the intervention frame.
    intervention_row = staging.read("advantage")[2]
    assert intervention_row["content"] == "positive"
def test_dropout_reduces_output_rows(staging: EpisodeStaging):
    """A non-zero ``dropout_rate`` leaves some frames without an advantage row."""
    num_frames = 100
    module = AdvantageModule(
        config=AdvantageConfig(
            value_function_path="/fake/vf",
            dropout_rate=0.3,
        )
    )
    record = _make_record(
        num_frames=num_frames,
        mc_returns=np.linspace(-0.9, -0.1, num_frames).astype(np.float32),
    )
    mock_values = np.full(num_frames, -0.5, dtype=np.float32)

    with (
        patch.object(module, "_ensure_model_loaded"),
        patch.object(module, "_compute_values", return_value=mock_values),
    ):
        module.run_episode(record, staging)

    # 30% dropout over 100 frames should keep roughly 70 rows; allow slack.
    kept = len(staging.read("advantage"))
    assert 50 < kept < 90
def test_staged_row_format(staging: EpisodeStaging):
    """Every staged row carries the fields and fixed values of an advantage row."""
    num_frames = 5
    module = AdvantageModule(
        config=AdvantageConfig(
            value_function_path="/fake/vf",
            dropout_rate=0.0,
        )
    )
    record = _make_record(
        num_frames=num_frames,
        mc_returns=np.array([-0.5, -0.4, -0.3, -0.2, -0.1], dtype=np.float32),
    )
    mock_values = np.full(5, -0.3, dtype=np.float32)

    with (
        patch.object(module, "_ensure_model_loaded"),
        patch.object(module, "_compute_values", return_value=mock_values),
    ):
        module.run_episode(record, staging)

    for staged in staging.read("advantage"):
        assert staged["role"] == "user"
        assert staged["content"] in ("positive", "negative")
        assert staged["style"] == "advantage"
        assert isinstance(staged["timestamp"], float)
        assert staged["camera"] is None
        assert staged["tool_calls"] is None
def test_n_step_advantage():
    """N-step advantage uses partial returns + bootstrapped value.

    For frames with a full lookahead (t + n < num_frames):
        A_t = mc_return[t] - mc_return[t + n] + V[t + n] - V[t]
    For the trailing frames the module falls back to MC:
        A_t = mc_return[t] - V[t]
    """
    n_step = 3
    num_frames = 10
    mc_returns = np.linspace(-0.9, 0.0, num_frames).astype(np.float32)
    mock_values = np.full(num_frames, -0.45, dtype=np.float32)

    cfg = AdvantageConfig(
        value_function_path="/fake/vf",
        n_step=n_step,
        dropout_rate=0.0,
    )
    module = AdvantageModule(config=cfg)
    record = _make_record(num_frames=num_frames, mc_returns=mc_returns)

    # Stub out model loading and inference so only the advantage arithmetic runs.
    with (
        patch.object(module, "_ensure_model_loaded"),
        patch.object(module, "_compute_values", return_value=mock_values),
    ):
        advantages, _ = module.compute_advantages_for_episode(record)

    assert advantages.shape == (num_frames,)
    for t in range(num_frames):
        if t + n_step < num_frames:
            expected = mc_returns[t] - mc_returns[t + n_step] + mock_values[t + n_step] - mock_values[t]
        else:
            expected = mc_returns[t] - mock_values[t]
        np.testing.assert_almost_equal(advantages[t], expected, decimal=5)
def test_compute_threshold():
    """Without interventions the threshold equals the configured percentile of all advantages."""
    module = AdvantageModule(config=AdvantageConfig(threshold_percentile=0.3))
    advantages = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
    no_interventions = np.array([False, False, False, False, False])

    threshold = module._compute_threshold(advantages, no_interventions)

    # threshold_percentile=0.3 corresponds to the 30th percentile.
    expected = float(np.percentile(advantages, 30))
    assert abs(threshold - expected) < 1e-6
def test_compute_threshold_excludes_intervention():
    """Advantages of intervention frames do not influence the threshold."""
    module = AdvantageModule(config=AdvantageConfig(threshold_percentile=0.5))
    # The two outliers (100.0) sit on intervention frames and must be ignored.
    advantages = np.array([100.0, -1.0, 0.0, 1.0, 100.0], dtype=np.float32)
    intervention_mask = np.array([True, False, False, False, True])

    threshold = module._compute_threshold(advantages, intervention_mask)

    # Remaining values are [-1.0, 0.0, 1.0]; their median is 0.0.
    expected = float(np.percentile([-1.0, 0.0, 1.0], 50))
    assert abs(threshold - expected) < 1e-6
def test_missing_mc_return_raises():
    """Scoring an episode whose frames lack the ``mc_return`` column raises ``KeyError``."""
    import pandas as pd

    module = AdvantageModule(config=AdvantageConfig(value_function_path="/fake/vf"))
    # Pre-populate the lazily loaded members so no real checkpoint is needed.
    module._model = MagicMock()
    module._preprocessor = MagicMock()

    frames_without_returns = pd.DataFrame({"episode_index": [0, 0], "frame_index": [0, 1]})
    record = EpisodeRecord(
        episode_index=0,
        episode_task="test",
        frame_timestamps=(0.0, 0.1),
        frame_indices=(0, 1),
        data_path=Path("/fake/data.parquet"),
        row_offset=0,
        row_count=2,
    )
    record._frames_df_cache = frames_without_returns

    with pytest.raises(KeyError, match="mc_return"):
        module.compute_advantages_for_episode(record)
@@ -30,6 +30,7 @@ pytest.importorskip("pandas", reason="pandas is required (install lerobot[datase
import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
AdvantageConfig,
AnnotationPipelineConfig,
InterjectionsConfig,
PlanConfig,
@@ -37,6 +38,7 @@ from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
)
from lerobot.annotations.steerable_pipeline.executor import Executor # noqa: E402
from lerobot.annotations.steerable_pipeline.modules import ( # noqa: E402
AdvantageModule,
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
@@ -132,6 +134,7 @@ def _build_executor() -> Executor:
plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
advantage=AdvantageModule(config=AdvantageConfig(enabled=False)),
writer=LanguageColumnsWriter(),
validator=StagingValidator(),
)