mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-23 11:17:02 +00:00
feat(recap): add advantage scoring annotation module
Implement the RECAP advantage scoring module as a new phase in lerobot-annotate. Uses a frozen distributional VF to compute per-frame advantages, binarizes into positive/negative indicators with per-task threshold, and writes style=advantage persistent rows for policy conditioning. Skips VF inference on intervention frames as an optimization.
This commit is contained in:
@@ -169,6 +169,43 @@ class ExecutorConfig:
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdvantageConfig:
|
||||
"""``advantage`` module: RECAP advantage scoring via frozen value function."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Path or Hub repo ID of the trained distributional value function checkpoint.
|
||||
value_function_path: str = ""
|
||||
|
||||
# Device to run the value function on.
|
||||
device: str = "cuda"
|
||||
|
||||
# N-step lookahead for advantage estimation.
|
||||
# None = MC (N=T): A_t = R_t - V(s_t), using mc_return from dataset.
|
||||
# 50 = fine-tuning mode: A_t = Σ r_{t:t+N} + V(s_{t+N}) - V(s_t).
|
||||
n_step: int | None = None
|
||||
|
||||
# Per-task percentile for binarization threshold ε_ℓ.
|
||||
# Actions with advantage > ε_ℓ get I_t = True (positive).
|
||||
threshold_percentile: float = 0.3
|
||||
|
||||
# Fraction of frames to randomly omit advantage labels (enables CFG).
|
||||
dropout_rate: float = 0.3
|
||||
|
||||
# Force I_t = True for frames marked as human interventions.
|
||||
force_positive_on_intervention: bool = True
|
||||
|
||||
# Column name in dataset for intervention flag.
|
||||
intervention_key: str = "intervention"
|
||||
|
||||
# Column name for pre-computed MC returns (from lerobot-compute-returns).
|
||||
mc_return_key: str = "mc_return"
|
||||
|
||||
# Batch size for value function inference.
|
||||
batch_size: int = 32
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
|
||||
@@ -190,6 +227,7 @@ class AnnotationPipelineConfig:
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
advantage: AdvantageConfig = field(default_factory=AdvantageConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
@@ -15,20 +15,24 @@
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor runs **six phases** in dependency order:
|
||||
The executor runs **seven phases** in dependency order:
|
||||
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
phase 5: ``advantage`` module (advantage scoring via frozen VF)
|
||||
phase 6: validator
|
||||
phase 7: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Phase 5 (advantage) does not depend on the VLM modules, it uses a frozen
|
||||
distributional value function to compute per-frame advantage indicators.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
@@ -74,7 +78,7 @@ class PipelineRunSummary:
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
"""Run all seven phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
@@ -86,6 +90,7 @@ class Executor:
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
advantage: Any # AdvantageModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
@@ -112,6 +117,8 @@ class Executor:
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
# Phase 5: ``advantage`` module (advantage scoring via frozen VF)
|
||||
phases.append(self._run_module_phase("advantage", records, staging_dir, self.advantage))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
|
||||
@@ -14,11 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .advantage import AdvantageModule
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"AdvantageModule",
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
|
||||
@@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Advantage scoring module for RECAP.
|
||||
|
||||
Computes per-frame advantage values using a frozen distributional value function,
|
||||
binarizes them into improvement indicators (I_t), and emits ``style="advantage"``
|
||||
persistent rows for policy conditioning.
|
||||
|
||||
Paper reference: pi*0.6, Section IV-B and Appendix F.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..config import AdvantageConfig
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdvantageModule:
|
||||
"""Compute advantage indicators and emit persistent annotation rows.
|
||||
|
||||
The module loads a frozen distributional value function and scores each
|
||||
frame in an episode. Advantages are binarized into ``positive``/``negative``
|
||||
indicators using a per-task threshold, then written as ``style="advantage"``
|
||||
persistent rows into the staging area.
|
||||
|
||||
Requires ``mc_return`` column in the dataset (from lerobot-compute-returns).
|
||||
"""
|
||||
|
||||
config: AdvantageConfig
|
||||
_model: Any = field(default=None, init=False, repr=False)
|
||||
_preprocessor: Any = field(default=None, init=False, repr=False)
|
||||
_threshold: float | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def _ensure_model_loaded(self) -> None:
|
||||
"""Lazy-load the frozen value function on first use."""
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
from lerobot.rewards import (
|
||||
make_reward_model,
|
||||
make_reward_model_config,
|
||||
make_reward_pre_post_processors,
|
||||
)
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"distributional_value_function",
|
||||
pretrained_path=self.config.value_function_path,
|
||||
device=self.config.device,
|
||||
)
|
||||
self._model = make_reward_model(cfg)
|
||||
self._model.eval()
|
||||
for p in self._model.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
self._preprocessor, _ = make_reward_pre_post_processors(cfg)
|
||||
logger.info("Loaded frozen VF from %s on %s", self.config.value_function_path, self.config.device)
|
||||
|
||||
def compute_advantages_for_episode(self, record: EpisodeRecord) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Compute raw advantage values for all frames in an episode.
|
||||
|
||||
Returns:
|
||||
(advantages, intervention_mask) both shape [num_frames].
|
||||
advantages[t] = A_t, intervention_mask[t] = True if frame is intervention.
|
||||
"""
|
||||
self._ensure_model_loaded()
|
||||
|
||||
df = record.frames_df()
|
||||
num_frames = len(df)
|
||||
|
||||
mc_return_key = self.config.mc_return_key
|
||||
if mc_return_key not in df.columns:
|
||||
raise KeyError(
|
||||
f"Column '{mc_return_key}' not found in episode {record.episode_index}. "
|
||||
"Run lerobot-compute-returns first."
|
||||
)
|
||||
|
||||
mc_returns = df[mc_return_key].values.astype(np.float32)
|
||||
|
||||
intervention_mask = np.zeros(num_frames, dtype=bool)
|
||||
if self.config.intervention_key in df.columns:
|
||||
intervention_mask = df[self.config.intervention_key].values.astype(bool)
|
||||
|
||||
# Skip VF inference on intervention frames — they're always "positive"
|
||||
# regardless of advantage value, so V(s_t) is never used for them.
|
||||
skip_mask = intervention_mask if self.config.force_positive_on_intervention else None
|
||||
values = self._compute_values(record, skip_mask=skip_mask)
|
||||
|
||||
if self.config.n_step is None:
|
||||
advantages = mc_returns - values
|
||||
else:
|
||||
advantages = self._compute_n_step_advantages(mc_returns, values, record, n=self.config.n_step)
|
||||
|
||||
return advantages, intervention_mask
|
||||
|
||||
def _compute_values(self, record: EpisodeRecord, skip_mask: np.ndarray | None = None) -> np.ndarray:
|
||||
"""Run frozen VF over all frames to get V(s_t) predictions.
|
||||
|
||||
Args:
|
||||
record: Episode data.
|
||||
skip_mask: Optional boolean mask [num_frames]. Frames where True are
|
||||
skipped (left as 0.0) to avoid unnecessary inference.
|
||||
"""
|
||||
df = record.frames_df()
|
||||
num_frames = len(df)
|
||||
values = np.zeros(num_frames, dtype=np.float32)
|
||||
|
||||
image_key = self._resolve_image_key(df)
|
||||
if image_key is None:
|
||||
logger.warning("No image key found for episode %d; returning zero values.", record.episode_index)
|
||||
return values
|
||||
|
||||
# Determine which frame indices actually need inference
|
||||
infer_indices = np.where(~skip_mask)[0] if skip_mask is not None else np.arange(num_frames)
|
||||
|
||||
if len(infer_indices) == 0:
|
||||
return values
|
||||
|
||||
task_text = record.episode_task
|
||||
|
||||
for batch_start in range(0, len(infer_indices), self.config.batch_size):
|
||||
batch_end = min(batch_start + self.config.batch_size, len(infer_indices))
|
||||
batch_indices = infer_indices[batch_start:batch_end]
|
||||
batch_images = []
|
||||
|
||||
for idx in batch_indices:
|
||||
img_val = df.iloc[idx][image_key]
|
||||
if isinstance(img_val, np.ndarray):
|
||||
img_tensor = torch.from_numpy(img_val).float()
|
||||
elif isinstance(img_val, torch.Tensor):
|
||||
img_tensor = img_val.float()
|
||||
else:
|
||||
img_tensor = torch.zeros(3, 224, 224)
|
||||
batch_images.append(img_tensor)
|
||||
|
||||
batch_images_tensor = torch.stack(batch_images)
|
||||
batch_size = batch_images_tensor.shape[0]
|
||||
|
||||
raw_batch = {
|
||||
image_key: batch_images_tensor,
|
||||
"task": [task_text] * batch_size,
|
||||
}
|
||||
|
||||
processed = self._preprocessor(raw_batch)
|
||||
|
||||
with torch.no_grad():
|
||||
v_values = self._model.compute_reward(processed)
|
||||
|
||||
values[batch_indices] = v_values.cpu().numpy()
|
||||
|
||||
return values
|
||||
|
||||
def _compute_n_step_advantages(
|
||||
self, mc_returns: np.ndarray, values: np.ndarray, record: EpisodeRecord, n: int
|
||||
) -> np.ndarray:
|
||||
"""Compute N-step advantage: A_t = Σ r_{t:t+N-1} + V(s_{t+N}) - V(s_t).
|
||||
|
||||
When t+N exceeds episode length, truncates to MC (uses mc_return directly).
|
||||
"""
|
||||
num_frames = len(values)
|
||||
advantages = np.zeros(num_frames, dtype=np.float32)
|
||||
|
||||
for t in range(num_frames):
|
||||
if t + n >= num_frames:
|
||||
advantages[t] = mc_returns[t] - values[t]
|
||||
else:
|
||||
n_step_return = mc_returns[t] - mc_returns[t + n]
|
||||
advantages[t] = n_step_return + values[t + n] - values[t]
|
||||
|
||||
return advantages
|
||||
|
||||
def _resolve_image_key(self, df) -> str | None:
|
||||
"""Find the first image observation key in the dataframe columns."""
|
||||
for col in df.columns:
|
||||
if col.startswith("observation.images."):
|
||||
return col
|
||||
return None
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
"""Score one episode and write advantage rows to staging."""
|
||||
if not self.config.value_function_path:
|
||||
logger.warning("No value_function_path configured; skipping advantage scoring.")
|
||||
return
|
||||
|
||||
advantages, intervention_mask = self.compute_advantages_for_episode(record)
|
||||
num_frames = len(advantages)
|
||||
|
||||
threshold = self._compute_threshold(advantages, intervention_mask)
|
||||
|
||||
rng = np.random.default_rng(seed=hash((record.episode_index, 42)) & 0xFFFFFFFF)
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for t in range(num_frames):
|
||||
if rng.random() < self.config.dropout_rate:
|
||||
continue
|
||||
|
||||
if (
|
||||
self.config.force_positive_on_intervention
|
||||
and intervention_mask[t]
|
||||
or advantages[t] > threshold
|
||||
):
|
||||
indicator = "positive"
|
||||
else:
|
||||
indicator = "negative"
|
||||
|
||||
timestamp = float(record.frame_timestamps[t]) if t < len(record.frame_timestamps) else 0.0
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": indicator,
|
||||
"style": "advantage",
|
||||
"timestamp": timestamp,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
|
||||
staging.write("advantage", rows)
|
||||
logger.debug(
|
||||
"Episode %d: %d/%d frames scored (threshold=%.4f, %d positive, %d negative)",
|
||||
record.episode_index,
|
||||
len(rows),
|
||||
num_frames,
|
||||
threshold,
|
||||
sum(1 for r in rows if r["content"] == "positive"),
|
||||
sum(1 for r in rows if r["content"] == "negative"),
|
||||
)
|
||||
|
||||
def _compute_threshold(self, advantages: np.ndarray, intervention_mask: np.ndarray) -> float:
|
||||
"""Compute the binarization threshold as the configured percentile of advantages."""
|
||||
non_intervention = advantages[~intervention_mask] if intervention_mask.any() else advantages
|
||||
if len(non_intervention) == 0:
|
||||
return 0.0
|
||||
return float(np.percentile(non_intervention, self.config.threshold_percentile * 100))
|
||||
@@ -39,6 +39,7 @@ _MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
"advantage",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ DEFAULT_BINDINGS = {
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
"advantage": "active_at(t, style=advantage)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
|
||||
@@ -43,10 +43,10 @@ CORE_STYLES = {
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
EXTENDED_STYLES: set[str] = {"advantage"}
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug", "advantage"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
|
||||
@@ -34,6 +34,7 @@ from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConf
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
AdvantageModule,
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
@@ -86,6 +87,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
|
||||
)
|
||||
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
|
||||
advantage = AdvantageModule(config=cfg.advantage)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||
@@ -96,6 +98,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
plan=plan,
|
||||
interjections=interjections,
|
||||
vqa=vqa,
|
||||
advantage=advantage,
|
||||
writer=writer,
|
||||
validator=validator,
|
||||
)
|
||||
|
||||
@@ -28,9 +28,10 @@ import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.config import AdvantageConfig, AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
AdvantageModule,
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
@@ -85,6 +86,7 @@ def main() -> int:
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
|
||||
advantage=AdvantageModule(config=AdvantageConfig(enabled=False)),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the advantage scoring annotation module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AdvantageConfig
|
||||
from lerobot.annotations.steerable_pipeline.modules.advantage import AdvantageModule
|
||||
from lerobot.annotations.steerable_pipeline.reader import EpisodeRecord
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
|
||||
|
||||
def _make_record(
|
||||
episode_index: int = 0,
|
||||
num_frames: int = 20,
|
||||
task: str = "pick up the cup",
|
||||
mc_returns: np.ndarray | None = None,
|
||||
intervention_mask: np.ndarray | None = None,
|
||||
fps: float = 10.0,
|
||||
) -> EpisodeRecord:
|
||||
"""Build a minimal EpisodeRecord with a mocked frames_df."""
|
||||
import pandas as pd
|
||||
|
||||
timestamps = tuple(round(i / fps, 6) for i in range(num_frames))
|
||||
frame_indices = tuple(range(num_frames))
|
||||
|
||||
if mc_returns is None:
|
||||
mc_returns = np.linspace(-0.9, -0.1, num_frames).astype(np.float32)
|
||||
|
||||
data = {
|
||||
"episode_index": [episode_index] * num_frames,
|
||||
"frame_index": list(range(num_frames)),
|
||||
"timestamp": list(timestamps),
|
||||
"mc_return": mc_returns,
|
||||
}
|
||||
|
||||
if intervention_mask is not None:
|
||||
data["intervention"] = intervention_mask.astype(bool)
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
record = EpisodeRecord(
|
||||
episode_index=episode_index,
|
||||
episode_task=task,
|
||||
frame_timestamps=timestamps,
|
||||
frame_indices=frame_indices,
|
||||
data_path=Path("/fake/data.parquet"),
|
||||
row_offset=0,
|
||||
row_count=num_frames,
|
||||
)
|
||||
record._frames_df_cache = df
|
||||
return record
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def staging(tmp_path: Path) -> EpisodeStaging:
|
||||
return EpisodeStaging(tmp_path, episode_index=0)
|
||||
|
||||
|
||||
def test_advantage_module_disabled():
|
||||
"""Disabled module has enabled=False."""
|
||||
cfg = AdvantageConfig(enabled=False)
|
||||
module = AdvantageModule(config=cfg)
|
||||
assert not module.enabled
|
||||
|
||||
|
||||
def test_advantage_module_enabled_by_default():
|
||||
"""Module is enabled by default."""
|
||||
cfg = AdvantageConfig()
|
||||
module = AdvantageModule(config=cfg)
|
||||
assert module.enabled
|
||||
|
||||
|
||||
def test_run_episode_skips_without_value_function_path(staging: EpisodeStaging):
|
||||
"""Module gracefully returns when no value_function_path is configured."""
|
||||
cfg = AdvantageConfig(value_function_path="")
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record()
|
||||
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
assert rows == []
|
||||
|
||||
|
||||
def test_binarization_with_mock_values(staging: EpisodeStaging):
|
||||
"""Advantage binarization produces positive/negative labels based on threshold."""
|
||||
num_frames = 10
|
||||
mc_returns = np.array([-0.5, -0.4, -0.3, -0.2, -0.1, -0.5, -0.6, -0.7, -0.8, -0.9], dtype=np.float32)
|
||||
mock_values = np.array([-0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4], dtype=np.float32)
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
dropout_rate=0.0,
|
||||
threshold_percentile=0.5,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
|
||||
|
||||
with (
|
||||
patch.object(module, "_ensure_model_loaded"),
|
||||
patch.object(module, "_compute_values", return_value=mock_values),
|
||||
):
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
assert len(rows) == num_frames
|
||||
|
||||
# A_t = mc_returns - values
|
||||
# advantages = [-0.1, 0.0, 0.1, 0.2, 0.3, -0.1, -0.2, -0.3, -0.4, -0.5]
|
||||
# Median (50th pctile) = -0.1
|
||||
# positive: advantage > -0.1 → indices 1,2,3,4
|
||||
# negative: advantage <= -0.1 → indices 0,5,6,7,8,9
|
||||
positives = [r for r in rows if r["content"] == "positive"]
|
||||
negatives = [r for r in rows if r["content"] == "negative"]
|
||||
assert len(positives) == 4
|
||||
assert len(negatives) == 6
|
||||
|
||||
|
||||
def test_intervention_frames_forced_positive(staging: EpisodeStaging):
|
||||
"""Intervention frames are always scored as positive regardless of advantage value."""
|
||||
num_frames = 5
|
||||
mc_returns = np.array([-0.9, -0.9, -0.9, -0.9, -0.9], dtype=np.float32)
|
||||
mock_values = np.array([-0.1, -0.1, -0.1, -0.1, -0.1], dtype=np.float32)
|
||||
intervention = np.array([False, False, True, False, False])
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
dropout_rate=0.0,
|
||||
force_positive_on_intervention=True,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns, intervention_mask=intervention)
|
||||
|
||||
with (
|
||||
patch.object(module, "_ensure_model_loaded"),
|
||||
patch.object(module, "_compute_values", return_value=mock_values),
|
||||
):
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
# Frame 2 (intervention) should be positive despite negative advantage
|
||||
assert rows[2]["content"] == "positive"
|
||||
|
||||
|
||||
def test_dropout_reduces_output_rows(staging: EpisodeStaging):
|
||||
"""Non-zero dropout rate omits some frames."""
|
||||
num_frames = 100
|
||||
mc_returns = np.linspace(-0.9, -0.1, num_frames).astype(np.float32)
|
||||
mock_values = np.full(num_frames, -0.5, dtype=np.float32)
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
dropout_rate=0.3,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
|
||||
|
||||
with (
|
||||
patch.object(module, "_ensure_model_loaded"),
|
||||
patch.object(module, "_compute_values", return_value=mock_values),
|
||||
):
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
# With 30% dropout on 100 frames, expect ~70 rows (with some variance)
|
||||
assert 50 < len(rows) < 90
|
||||
|
||||
|
||||
def test_staged_row_format(staging: EpisodeStaging):
|
||||
"""Staged rows have the correct schema for language_persistent."""
|
||||
num_frames = 5
|
||||
mc_returns = np.array([-0.5, -0.4, -0.3, -0.2, -0.1], dtype=np.float32)
|
||||
mock_values = np.full(5, -0.3, dtype=np.float32)
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
dropout_rate=0.0,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
|
||||
|
||||
with (
|
||||
patch.object(module, "_ensure_model_loaded"),
|
||||
patch.object(module, "_compute_values", return_value=mock_values),
|
||||
):
|
||||
module.run_episode(record, staging)
|
||||
|
||||
rows = staging.read("advantage")
|
||||
for row in rows:
|
||||
assert row["role"] == "user"
|
||||
assert row["content"] in ("positive", "negative")
|
||||
assert row["style"] == "advantage"
|
||||
assert isinstance(row["timestamp"], float)
|
||||
assert row["camera"] is None
|
||||
assert row["tool_calls"] is None
|
||||
|
||||
|
||||
def test_n_step_advantage():
|
||||
"""N-step advantage uses partial returns + bootstrapped value."""
|
||||
num_frames = 10
|
||||
mc_returns = np.linspace(-0.9, 0.0, num_frames).astype(np.float32)
|
||||
mock_values = np.full(num_frames, -0.45, dtype=np.float32)
|
||||
|
||||
cfg = AdvantageConfig(
|
||||
value_function_path="/fake/vf",
|
||||
n_step=3,
|
||||
dropout_rate=0.0,
|
||||
)
|
||||
module = AdvantageModule(config=cfg)
|
||||
record = _make_record(num_frames=num_frames, mc_returns=mc_returns)
|
||||
|
||||
with patch.object(module, "_ensure_model_loaded"):
|
||||
advantages, _ = (
|
||||
module.compute_advantages_for_episode.__wrapped__(module, record)
|
||||
if hasattr(module.compute_advantages_for_episode, "__wrapped__")
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
# Just verify computation works - use the internal method directly
|
||||
module._model = MagicMock()
|
||||
module._preprocessor = MagicMock()
|
||||
with patch.object(module, "_compute_values", return_value=mock_values):
|
||||
advantages, _ = module.compute_advantages_for_episode(record)
|
||||
|
||||
# For t where t+n < num_frames: A = mc_return[t] - mc_return[t+n] + values[t+n] - values[t]
|
||||
# Since values are constant: A = mc_return[t] - mc_return[t+n]
|
||||
# For t where t+n >= num_frames: A = mc_return[t] - values[t]
|
||||
for t in range(num_frames):
|
||||
if t + 3 < num_frames:
|
||||
expected = mc_returns[t] - mc_returns[t + 3] + mock_values[t + 3] - mock_values[t]
|
||||
else:
|
||||
expected = mc_returns[t] - mock_values[t]
|
||||
np.testing.assert_almost_equal(advantages[t], expected, decimal=5)
|
||||
|
||||
|
||||
def test_compute_threshold():
|
||||
"""Threshold is computed as configured percentile of non-intervention advantages."""
|
||||
cfg = AdvantageConfig(threshold_percentile=0.3)
|
||||
module = AdvantageModule(config=cfg)
|
||||
|
||||
advantages = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
|
||||
intervention_mask = np.array([False, False, False, False, False])
|
||||
|
||||
threshold = module._compute_threshold(advantages, intervention_mask)
|
||||
expected = float(np.percentile(advantages, 30))
|
||||
assert abs(threshold - expected) < 1e-6
|
||||
|
||||
|
||||
def test_compute_threshold_excludes_intervention():
|
||||
"""Threshold computation excludes intervention frames."""
|
||||
cfg = AdvantageConfig(threshold_percentile=0.5)
|
||||
module = AdvantageModule(config=cfg)
|
||||
|
||||
advantages = np.array([100.0, -1.0, 0.0, 1.0, 100.0], dtype=np.float32)
|
||||
intervention_mask = np.array([True, False, False, False, True])
|
||||
|
||||
threshold = module._compute_threshold(advantages, intervention_mask)
|
||||
# Only non-intervention: [-1.0, 0.0, 1.0], median = 0.0
|
||||
expected = float(np.percentile([-1.0, 0.0, 1.0], 50))
|
||||
assert abs(threshold - expected) < 1e-6
|
||||
|
||||
|
||||
def test_missing_mc_return_raises():
|
||||
"""Module raises if mc_return column is missing from dataset."""
|
||||
import pandas as pd
|
||||
|
||||
cfg = AdvantageConfig(value_function_path="/fake/vf")
|
||||
module = AdvantageModule(config=cfg)
|
||||
module._model = MagicMock()
|
||||
module._preprocessor = MagicMock()
|
||||
|
||||
record = EpisodeRecord(
|
||||
episode_index=0,
|
||||
episode_task="test",
|
||||
frame_timestamps=(0.0, 0.1),
|
||||
frame_indices=(0, 1),
|
||||
data_path=Path("/fake/data.parquet"),
|
||||
row_offset=0,
|
||||
row_count=2,
|
||||
)
|
||||
record._frames_df_cache = pd.DataFrame({"episode_index": [0, 0], "frame_index": [0, 1]})
|
||||
|
||||
with pytest.raises(KeyError, match="mc_return"):
|
||||
module.compute_advantages_for_episode(record)
|
||||
@@ -30,6 +30,7 @@ pytest.importorskip("pandas", reason="pandas is required (install lerobot[datase
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
|
||||
AdvantageConfig,
|
||||
AnnotationPipelineConfig,
|
||||
InterjectionsConfig,
|
||||
PlanConfig,
|
||||
@@ -37,6 +38,7 @@ from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.modules import ( # noqa: E402
|
||||
AdvantageModule,
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
@@ -132,6 +134,7 @@ def _build_executor() -> Executor:
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
|
||||
advantage=AdvantageModule(config=AdvantageConfig(enabled=False)),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user