From e5c94c732fcfb0408abc5948c63eb4a1bd1c3f99 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 22 Jun 2026 12:17:37 +0200 Subject: [PATCH] feat(recap): add lerobot-compute-returns script to compute MC returns --- pyproject.toml | 1 + .../scripts/lerobot_compute_returns.py | 382 +++++++++++++ tests/scripts/test_compute_returns.py | 514 ++++++++++++++++++ 3 files changed, 897 insertions(+) create mode 100644 src/lerobot/scripts/lerobot_compute_returns.py create mode 100644 tests/scripts/test_compute_returns.py diff --git a/pyproject.toml b/pyproject.toml index 7608ad4a4..6fb2e4173 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -342,6 +342,7 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" lerobot-annotate="lerobot.scripts.lerobot_annotate:main" lerobot-rollout="lerobot.scripts.lerobot_rollout:main" +lerobot-compute-returns="lerobot.scripts.lerobot_compute_returns:main" # ---------------- Tool Configurations ---------------- diff --git a/src/lerobot/scripts/lerobot_compute_returns.py b/src/lerobot/scripts/lerobot_compute_returns.py new file mode 100644 index 000000000..d7d7ced7c --- /dev/null +++ b/src/lerobot/scripts/lerobot_compute_returns.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compute per-frame ``is_terminal`` and ``mc_return`` for a LeRobot dataset. 
+ +Implements the sparse reward function from pi*0.6 / RECAP (Eq. 5): + + r_t = -1 for non-terminal steps + r_T = 0 for terminal success + r_T = -C_fail for terminal failure + +Monte Carlo returns are the cumulative sum from each step to the end of +the episode, normalized by ``max_episode_length`` so that values are bounded +to approximately (-1, 0). + +The columns are written directly into the dataset's parquet data shards as +flat per-frame scalars. These serve as training targets for the distributional +value function. + +Usage: + # Compute returns using the default "next.success" column (from lerobot-eval/rollout) + lerobot-compute-returns \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human_image + + # Override: treat all episodes as successful (demo-only datasets) + lerobot-compute-returns \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human_image \\ + --default-success true + + # Custom success key, failure penalty, and discount + lerobot-compute-returns \\ + --dataset-repo-id my_org/my_dataset \\ + --success-key episode_success \\ + --c-fail 100 \\ + --gamma 0.99 +""" + +from __future__ import annotations + +import argparse +import json +import logging +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +IS_TERMINAL_COL = "is_terminal" +MC_RETURN_COL = "mc_return" + + +@dataclass +class ComputeReturnsConfig: + """Configuration for the returns computation script.""" + + dataset_repo_id: str = "" + root: str | None = None + success_key: str = "next.success" + default_success: bool | None = None + max_episode_length: int | None = None + c_fail: float = 50.0 + gamma: float = 1.0 + episodes: list[int] = field(default_factory=list) + force: bool = False + + +def _get_episode_success( + episode_table: pa.Table, + success_key: str, + default_success: bool | None, +) -> bool: + """Determine whether an 
episode was successful. + + Priority: + 1. If ``default_success`` is set, use it unconditionally. + 2. Look for ``success_key`` in the parquet columns and reduce with any(). + 3. Fall back to True (assume success for demo datasets). + """ + if default_success is not None: + return default_success + + if success_key in episode_table.column_names: + col = episode_table.column(success_key) + for val in col: + py_val = val.as_py() + if isinstance(py_val, bool) and py_val: + return True + if isinstance(py_val, (int, float)) and py_val: + return True + return False + + return True + + +def compute_episode_returns( + num_frames: int, + success: bool, + c_fail: float, + gamma: float, + max_episode_length: int, +) -> tuple[np.ndarray, np.ndarray]: + """Compute is_terminal and mc_return arrays for a single episode. + + Args: + num_frames: Number of frames in the episode. + success: Whether the episode ended successfully. + c_fail: Failure penalty constant. + gamma: Discount factor (1.0 = undiscounted). + max_episode_length: Normalization horizon H. + + Returns: + Tuple of (is_terminal, mc_return) arrays, each of length num_frames. 
+ """ + horizon = max_episode_length + + rewards = np.full(num_frames, -1.0 / horizon, dtype=np.float64) + + if success: + rewards[-1] = 0.0 + else: + rewards[-1] = -c_fail / horizon + + is_terminal = np.zeros(num_frames, dtype=bool) + is_terminal[-1] = True + + if gamma == 1.0: + # Reverse cumulative sum + mc_return = np.cumsum(rewards[::-1])[::-1].astype(np.float32) + else: + mc_return = np.zeros(num_frames, dtype=np.float64) + mc_return[-1] = rewards[-1] + for t in range(num_frames - 2, -1, -1): + mc_return[t] = rewards[t] + gamma * mc_return[t + 1] + mc_return = mc_return.astype(np.float32) + + return is_terminal, mc_return + + +def compute_returns(config: ComputeReturnsConfig) -> Path: + """Compute returns and write them into parquet shards.""" + from lerobot.datasets import LeRobotDataset + + logger.info(f"Loading dataset: {config.dataset_repo_id}") + kwargs = {"repo_id": config.dataset_repo_id, "download_videos": False} + if config.root: + kwargs["root"] = config.root + dataset = LeRobotDataset(**kwargs) + + meta = dataset.meta + root = Path(meta.root) + logger.info(f"Dataset root: {root}") + logger.info(f"Episodes: {meta.total_episodes}, Frames: {meta.total_frames}") + + episode_indices = config.episodes if config.episodes else list(range(meta.total_episodes)) + + if config.max_episode_length is not None: + max_ep_len = config.max_episode_length + else: + max_ep_len = max(int(meta.episodes[i]["length"]) for i in episode_indices) + logger.info(f"Normalization horizon (max_episode_length): {max_ep_len}") + + parquet_files_to_rewrite: dict[Path, list[int]] = {} + for ep_idx in episode_indices: + rel_path = meta.get_data_file_path(ep_idx) + abs_path = root / rel_path + parquet_files_to_rewrite.setdefault(abs_path, []).append(ep_idx) + + logger.info(f"Parquet shards to rewrite: {len(parquet_files_to_rewrite)}") + + for parquet_path, ep_indices_in_file in tqdm(parquet_files_to_rewrite.items(), desc="Processing shards"): + table = pq.read_table(parquet_path) + + 
if not config.force and IS_TERMINAL_COL in table.column_names: + logger.info(f"Skipping {parquet_path.name} (already has {IS_TERMINAL_COL})") + continue + + all_is_terminal = np.zeros(len(table), dtype=bool) + all_mc_return = np.zeros(len(table), dtype=np.float32) + + episode_col = table.column("episode_index").to_pylist() + + for ep_idx in ep_indices_in_file: + ep_info = meta.episodes[ep_idx] + ep_from = int(ep_info["dataset_from_index"]) + ep_to = int(ep_info["dataset_to_index"]) + ep_len = ep_to - ep_from + + mask = np.array([v == ep_idx for v in episode_col], dtype=bool) + local_indices = np.where(mask)[0] + + if len(local_indices) != ep_len: + logger.warning( + f"Episode {ep_idx}: expected {ep_len} frames in shard, " + f"found {len(local_indices)}. Using found count." + ) + ep_len = len(local_indices) + + if ep_len == 0: + continue + + ep_subtable = table.filter(mask) + success = _get_episode_success(ep_subtable, config.success_key, config.default_success) + + is_terminal, mc_return = compute_episode_returns( + num_frames=ep_len, + success=success, + c_fail=config.c_fail, + gamma=config.gamma, + max_episode_length=max_ep_len, + ) + + all_is_terminal[local_indices] = is_terminal + all_mc_return[local_indices] = mc_return + + if IS_TERMINAL_COL in table.column_names: + table = table.drop(IS_TERMINAL_COL) + if MC_RETURN_COL in table.column_names: + table = table.drop(MC_RETURN_COL) + + table = table.append_column(IS_TERMINAL_COL, pa.array(all_is_terminal)) + table = table.append_column(MC_RETURN_COL, pa.array(all_mc_return)) + + pq.write_table(table, parquet_path) + + _update_info_json(root, meta) + + logger.info("Done. 
Columns written: is_terminal, mc_return") + return root + + +def _update_info_json(root: Path, meta) -> None: + """Add is_terminal and mc_return to the dataset's info.json features.""" + info_path = root / "meta" / "info.json" + if not info_path.exists(): + logger.warning(f"info.json not found at {info_path}, skipping metadata update.") + return + + info = json.loads(info_path.read_text()) + features = info.get("features", {}) + changed = False + + if IS_TERMINAL_COL not in features: + features[IS_TERMINAL_COL] = { + "dtype": "bool", + "shape": [1], + "names": None, + } + changed = True + + if MC_RETURN_COL not in features: + features[MC_RETURN_COL] = { + "dtype": "float32", + "shape": [1], + "names": None, + } + changed = True + + if changed: + info["features"] = features + info_path.write_text(json.dumps(info, indent=2) + "\n") + logger.info("Updated meta/info.json with is_terminal and mc_return features.") + + +def main(): + parser = argparse.ArgumentParser( + description="Compute per-frame is_terminal and mc_return for a LeRobot dataset.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use the 'success' column from the dataset + lerobot-compute-returns --dataset-repo-id lerobot/aloha_sim_insertion_human_image + + # Override all episodes as successful (demo-only data) + lerobot-compute-returns --dataset-repo-id my_org/my_dataset --default-success true + + # Custom failure penalty + lerobot-compute-returns --dataset-repo-id my_org/my_dataset --c-fail 100 + """, + ) + parser.add_argument( + "--dataset-repo-id", + type=str, + required=True, + help="HuggingFace dataset repo id or local path.", + ) + parser.add_argument( + "--root", + type=str, + default=None, + help="Local root directory override for the dataset.", + ) + parser.add_argument( + "--success-key", + type=str, + default="next.success", + help="Column name in parquet that indicates episode success (default: 'next.success').", + ) + parser.add_argument( + 
"--default-success", + type=str, + default=None, + choices=["true", "false"], + help="Override success for all episodes ('true' or 'false').", + ) + parser.add_argument( + "--max-episode-length", + type=int, + default=None, + help="Normalization horizon H. If not set, uses max episode length in dataset.", + ) + parser.add_argument( + "--c-fail", + type=float, + default=50.0, + help="Failure penalty constant (default: 50.0).", + ) + parser.add_argument( + "--gamma", + type=float, + default=1.0, + help="Discount factor (default: 1.0, undiscounted).", + ) + parser.add_argument( + "--episodes", + type=int, + nargs="+", + default=None, + help="Process only these episode indices (default: all).", + ) + parser.add_argument( + "--force", + action="store_true", + help="Overwrite existing is_terminal/mc_return columns.", + ) + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + default_success = None + if args.default_success is not None: + default_success = args.default_success.lower() == "true" + + config = ComputeReturnsConfig( + dataset_repo_id=args.dataset_repo_id, + root=args.root, + success_key=args.success_key, + default_success=default_success, + max_episode_length=args.max_episode_length, + c_fail=args.c_fail, + gamma=args.gamma, + episodes=args.episodes or [], + force=args.force, + ) + + root = compute_returns(config) + logger.info(f"Returns computed and written to: {root}") + logger.info(f" Columns added: {IS_TERMINAL_COL}, {MC_RETURN_COL}") + logger.info("To train the distributional value function, these columns") + logger.info("will be read as flat batch keys during training.") + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/test_compute_returns.py b/tests/scripts/test_compute_returns.py new file mode 100644 index 000000000..88af16f06 --- /dev/null +++ b/tests/scripts/test_compute_returns.py @@ -0,0 +1,514 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace 
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


def _build_dataset(root: Path, num_episodes: int, frames_per_ep: int) -> tuple[Path, Path, list[dict]]:
    """Write a single-shard dataset skeleton under ``root``.

    Mirrors the lerobot-rollout DAgger convention: ``next.success`` is False
    on all frames except the terminal frame of successful episodes.
    Even episodes are successful, odd episodes are failures.

    Returns:
        ``(root, parquet_path, episodes_meta)`` where ``episodes_meta`` holds
        one dict per episode with its index, length and global frame range.
    """
    data_dir = root / "data" / "chunk-000"
    meta_dir = root / "meta"
    data_dir.mkdir(parents=True)
    meta_dir.mkdir(parents=True)

    columns = {"episode_index": [], "frame_index": [], "index": [], "next.success": []}
    episodes_meta = []
    for ep in range(num_episodes):
        ep_from = ep * frames_per_ep
        is_successful = ep % 2 == 0
        for frame in range(frames_per_ep):
            columns["episode_index"].append(ep)
            columns["frame_index"].append(frame)
            columns["index"].append(ep_from + frame)
            columns["next.success"].append(is_successful and frame == frames_per_ep - 1)
        episodes_meta.append(
            {
                "episode_index": ep,
                "length": frames_per_ep,
                "dataset_from_index": ep_from,
                "dataset_to_index": ep_from + frames_per_ep,
            }
        )

    parquet_path = data_dir / "episode_000000.parquet"
    pq.write_table(pa.table(columns), parquet_path)

    info = {
        "codebase_version": "v3.0",
        "total_episodes": num_episodes,
        "total_frames": num_episodes * frames_per_ep,
        "fps": 30,
        "features": {
            "episode_index": {"dtype": "int64", "shape": [1], "names": None},
            "frame_index": {"dtype": "int64", "shape": [1], "names": None},
            "index": {"dtype": "int64", "shape": [1], "names": None},
            "next.success": {"dtype": "bool", "shape": [1], "names": None},
        },
    }
    (meta_dir / "info.json").write_text(json.dumps(info, indent=2))

    return root, parquet_path, episodes_meta


@pytest.fixture
def parquet_dataset(tmp_path):
    """Minimal parquet shard + info.json: 3 episodes of 10 frames each."""
    return _build_dataset(tmp_path / "test_dataset", num_episodes=3, frames_per_ep=10)


def _rewrite_shard(parquet_path: Path, episodes_meta: list[dict], config: ComputeReturnsConfig):
    """Rewrite a single parquet shard using the core logic from compute_returns.

    This reproduces the per-shard loop without going through ``LeRobotDataset``,
    so the I/O behaviour can be tested on a hand-built shard.
    """
    table = pq.read_table(parquet_path)

    if not config.force and IS_TERMINAL_COL in table.column_names:
        return

    all_is_terminal = np.zeros(len(table), dtype=bool)
    all_mc_return = np.zeros(len(table), dtype=np.float32)
    episode_col = np.asarray(table.column("episode_index").to_pylist())

    for ep_info in episodes_meta:
        ep_len = ep_info["length"]

        mask = episode_col == ep_info["episode_index"]
        local_indices = np.flatnonzero(mask)

        success = _get_episode_success(table.filter(mask), config.success_key, config.default_success)

        is_terminal, mc_return = compute_episode_returns(
            num_frames=ep_len,
            success=success,
            c_fail=config.c_fail,
            gamma=config.gamma,
            max_episode_length=config.max_episode_length or ep_len,
        )

        all_is_terminal[local_indices] = is_terminal
        all_mc_return[local_indices] = mc_return

    for name in (IS_TERMINAL_COL, MC_RETURN_COL):
        if name in table.column_names:
            table = table.drop([name])

    table = table.append_column(IS_TERMINAL_COL, pa.array(all_is_terminal))
    table = table.append_column(MC_RETURN_COL, pa.array(all_mc_return))
    pq.write_table(table, parquet_path)


# ---------------------------------------------------------------------------
# Tests: compute_episode_returns (pure math, no I/O)
# ---------------------------------------------------------------------------


def test_successful_episode_terminal_reward_is_zero():
    """Terminal MC return for a successful episode should be 0."""
    _, mc_return = compute_episode_returns(
        num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=10
    )
    assert mc_return[-1] == pytest.approx(0.0, abs=1e-6)


def test_failed_episode_terminal_reward_reflects_cfail():
    """Terminal MC return for a failed episode should be -C_fail / H."""
    horizon = 100
    c_fail = 50.0
    _, mc_return = compute_episode_returns(
        num_frames=10, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
    )
    assert mc_return[-1] == pytest.approx(-c_fail / horizon, abs=1e-5)


def test_is_terminal_only_last_frame():
    """Only the last frame of an episode should be marked terminal."""
    is_terminal, _ = compute_episode_returns(
        num_frames=20, success=True, c_fail=50.0, gamma=1.0, max_episode_length=20
    )
    assert is_terminal[-1]
    assert not any(is_terminal[:-1])


def test_mc_return_monotonically_increases_for_success():
    """For a successful undiscounted episode, returns should increase toward 0."""
    _, mc_return = compute_episode_returns(
        num_frames=50, success=True, c_fail=50.0, gamma=1.0, max_episode_length=50
    )
    assert all(earlier <= later for earlier, later in zip(mc_return[:-1], mc_return[1:]))


def test_mc_return_bounded_negative_to_zero():
    """MC returns for successful episodes should be in (-1, 0]."""
    _, mc_return = compute_episode_returns(
        num_frames=100, success=True, c_fail=50.0, gamma=1.0, max_episode_length=100
    )
    assert mc_return[-1] == pytest.approx(0.0, abs=1e-6)
    assert all(v <= 0.0 for v in mc_return)
    assert all(v >= -1.0 - 1e-6 for v in mc_return)


def test_first_frame_return_success():
    """First frame return for successful episode equals -(N-1)/H."""
    num_frames = 10
    horizon = 10
    _, mc_return = compute_episode_returns(
        num_frames=num_frames, success=True, c_fail=50.0, gamma=1.0, max_episode_length=horizon
    )
    expected = -(num_frames - 1) / horizon
    assert mc_return[0] == pytest.approx(expected, abs=1e-5)


def test_first_frame_return_failure():
    """First frame return for failed episode includes the failure penalty."""
    num_frames = 10
    horizon = 100
    c_fail = 50.0
    _, mc_return = compute_episode_returns(
        num_frames=num_frames, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
    )
    expected = (-(num_frames - 1) / horizon) + (-c_fail / horizon)
    assert mc_return[0] == pytest.approx(expected, abs=1e-5)


def test_discount_factor_less_than_one():
    """Discount factor < 1 should make earlier frames have smaller magnitude."""
    _, mc_undiscounted = compute_episode_returns(
        num_frames=20, success=True, c_fail=50.0, gamma=1.0, max_episode_length=20
    )
    _, mc_discounted = compute_episode_returns(
        num_frames=20, success=True, c_fail=50.0, gamma=0.99, max_episode_length=20
    )
    assert abs(mc_discounted[0]) < abs(mc_undiscounted[0])


def test_single_frame_episode_success():
    """Single-frame successful episode: return should be 0."""
    is_terminal, mc_return = compute_episode_returns(
        num_frames=1, success=True, c_fail=50.0, gamma=1.0, max_episode_length=1
    )
    assert mc_return[0] == pytest.approx(0.0, abs=1e-6)
    assert is_terminal[0]


def test_single_frame_episode_failure():
    """Single-frame failed episode: return should be -C_fail/H."""
    horizon = 100
    c_fail = 50.0
    is_terminal, mc_return = compute_episode_returns(
        num_frames=1, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
    )
    assert mc_return[0] == pytest.approx(-c_fail / horizon, abs=1e-5)
    assert is_terminal[0]


def test_horizon_normalization_scales_returns():
    """Larger horizon should scale down the per-step penalty."""
    _, mc_small_h = compute_episode_returns(
        num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=10
    )
    _, mc_large_h = compute_episode_returns(
        num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=100
    )
    assert abs(mc_large_h[0]) < abs(mc_small_h[0])


# ---------------------------------------------------------------------------
# Tests: _get_episode_success (in-memory PyArrow tables)
# ---------------------------------------------------------------------------


def test_default_success_overrides_column():
    """default_success should override any column value."""
    table = pa.table({"next.success": [True, True, True]})
    assert _get_episode_success(table, "next.success", default_success=False) is False


def test_reads_bool_column():
    """Should detect success via any() reduction over the column."""
    table_success = pa.table({"next.success": [False, False, True]})
    table_fail = pa.table({"next.success": [False, False, False]})
    assert _get_episode_success(table_success, "next.success", None) is True
    assert _get_episode_success(table_fail, "next.success", None) is False


def test_reads_int_column():
    """Should interpret integer success column (0/1) as bool via any()."""
    table = pa.table({"task_success": [0, 0, 1]})
    assert _get_episode_success(table, "task_success", None) is True


def test_all_zeros_means_failure():
    """An episode with all-zero success values is a failure."""
    table = pa.table({"next.success": [0, 0, 0]})
    assert _get_episode_success(table, "next.success", None) is False


def test_missing_column_defaults_to_true():
    """When success column is missing, assume success (demo data)."""
    table = pa.table({"frame_index": [0, 1, 2]})
    assert _get_episode_success(table, "next.success", None) is True


# ---------------------------------------------------------------------------
# Tests: parquet rewriting (integration, writes to disk)
# ---------------------------------------------------------------------------


def test_writes_columns_to_parquet(parquet_dataset):
    """The rewrite logic should add is_terminal and mc_return columns."""
    _, parquet_path, episodes_meta = parquet_dataset

    table_before = pq.read_table(parquet_path)
    assert IS_TERMINAL_COL not in table_before.column_names
    assert MC_RETURN_COL not in table_before.column_names

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)

    table_after = pq.read_table(parquet_path)
    assert IS_TERMINAL_COL in table_after.column_names
    assert MC_RETURN_COL in table_after.column_names


def test_terminal_frames_correct(parquet_dataset):
    """Only the last frame of each episode should be terminal."""
    _, parquet_path, episodes_meta = parquet_dataset

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)

    is_terminal = pq.read_table(parquet_path).column(IS_TERMINAL_COL).to_pylist()
    terminal_indices = [i for i, flag in enumerate(is_terminal) if flag]
    assert terminal_indices == [9, 19, 29]


def test_success_episodes_return_zero_at_terminal(tmp_path):
    """Successful episodes (ep 0) should have mc_return=0 at terminal."""
    _, parquet_path, episodes_meta = _build_dataset(
        tmp_path / "test_dataset", num_episodes=2, frames_per_ep=5
    )

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=5, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)

    mc_return = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()
    # Row 4 is the terminal frame of episode 0.
    assert mc_return[4] == pytest.approx(0.0, abs=1e-5)


def test_failed_episodes_have_negative_terminal(tmp_path):
    """Failed episodes (ep 1) should have mc_return < 0 at terminal."""
    _, parquet_path, episodes_meta = _build_dataset(
        tmp_path / "test_dataset", num_episodes=2, frames_per_ep=5
    )

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=5, c_fail=50.0, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)

    mc_return = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()
    # Row 9 is the terminal frame of episode 1.
    assert mc_return[9] < 0.0


def test_idempotent_with_force_flag(parquet_dataset):
    """Running twice with force should produce identical results."""
    _, parquet_path, episodes_meta = parquet_dataset

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)
    first_pass = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()

    _rewrite_shard(parquet_path, episodes_meta, config)
    second_pass = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()

    assert first_pass == second_pass


def test_skips_if_columns_exist_without_force(parquet_dataset):
    """Without force, existing columns should not be overwritten."""
    _, parquet_path, episodes_meta = parquet_dataset

    config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
    _rewrite_shard(parquet_path, episodes_meta, config)
    original_mc = pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist()

    # A different horizon would change the values if the shard were rewritten.
    config_no_force = ComputeReturnsConfig(success_key="next.success", max_episode_length=20, force=False)
    _rewrite_shard(parquet_path, episodes_meta, config_no_force)

    assert pq.read_table(parquet_path).column(MC_RETURN_COL).to_pylist() == original_mc


def test_updates_info_json(parquet_dataset):
    """info.json should be updated with is_terminal and mc_return features."""
    from lerobot.scripts.lerobot_compute_returns import _update_info_json

    root, _, _ = parquet_dataset

    _update_info_json(root, None)

    info = json.loads((root / "meta" / "info.json").read_text())
    assert IS_TERMINAL_COL in info["features"]
    assert MC_RETURN_COL in info["features"]
    assert info["features"][IS_TERMINAL_COL]["dtype"] == "bool"
    assert info["features"][MC_RETURN_COL]["dtype"] == "float32"