feat(recap): add lerobot-compute-returns script to compute MC returns

This commit is contained in:
Khalil Meftah
2026-06-22 12:17:37 +02:00
parent c18b8277f1
commit e5c94c732f
3 changed files with 897 additions and 0 deletions
+1
View File
@@ -342,6 +342,7 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
lerobot-compute-returns="lerobot.scripts.lerobot_compute_returns:main"
# ---------------- Tool Configurations ----------------
@@ -0,0 +1,382 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute per-frame ``is_terminal`` and ``mc_return`` for a LeRobot dataset.
Implements the sparse reward function from pi*0.6 / RECAP (Eq. 5):
r_t = -1 for non-terminal steps
r_T = 0 for terminal success
r_T = -C_fail for terminal failure
Monte Carlo returns are the cumulative sum from each step to the end of
the episode, normalized by ``max_episode_length`` so that values are bounded
to approximately (-1, 0).
The columns are written directly into the dataset's parquet data shards as
flat per-frame scalars. These serve as training targets for the distributional
value function.
Usage:
# Compute returns using the default "next.success" column (from lerobot-eval/rollout)
lerobot-compute-returns \\
--dataset-repo-id lerobot/aloha_sim_insertion_human_image
# Override: treat all episodes as successful (demo-only datasets)
lerobot-compute-returns \\
--dataset-repo-id lerobot/aloha_sim_insertion_human_image \\
--default-success true
# Custom success key, failure penalty, and discount
lerobot-compute-returns \\
--dataset-repo-id my_org/my_dataset \\
--success-key episode_success \\
--c-fail 100 \\
--gamma 0.99
"""
from __future__ import annotations
import argparse
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
logger = logging.getLogger(__name__)
IS_TERMINAL_COL = "is_terminal"
MC_RETURN_COL = "mc_return"
@dataclass
class ComputeReturnsConfig:
"""Configuration for the returns computation script."""
dataset_repo_id: str = ""
root: str | None = None
success_key: str = "next.success"
default_success: bool | None = None
max_episode_length: int | None = None
c_fail: float = 50.0
gamma: float = 1.0
episodes: list[int] = field(default_factory=list)
force: bool = False
def _get_episode_success(
episode_table: pa.Table,
success_key: str,
default_success: bool | None,
) -> bool:
"""Determine whether an episode was successful.
Priority:
1. If ``default_success`` is set, use it unconditionally.
2. Look for ``success_key`` in the parquet columns and reduce with any().
3. Fall back to True (assume success for demo datasets).
"""
if default_success is not None:
return default_success
if success_key in episode_table.column_names:
col = episode_table.column(success_key)
for val in col:
py_val = val.as_py()
if isinstance(py_val, bool) and py_val:
return True
if isinstance(py_val, (int, float)) and py_val:
return True
return False
return True
def compute_episode_returns(
num_frames: int,
success: bool,
c_fail: float,
gamma: float,
max_episode_length: int,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute is_terminal and mc_return arrays for a single episode.
Args:
num_frames: Number of frames in the episode.
success: Whether the episode ended successfully.
c_fail: Failure penalty constant.
gamma: Discount factor (1.0 = undiscounted).
max_episode_length: Normalization horizon H.
Returns:
Tuple of (is_terminal, mc_return) arrays, each of length num_frames.
"""
horizon = max_episode_length
rewards = np.full(num_frames, -1.0 / horizon, dtype=np.float64)
if success:
rewards[-1] = 0.0
else:
rewards[-1] = -c_fail / horizon
is_terminal = np.zeros(num_frames, dtype=bool)
is_terminal[-1] = True
if gamma == 1.0:
# Reverse cumulative sum
mc_return = np.cumsum(rewards[::-1])[::-1].astype(np.float32)
else:
mc_return = np.zeros(num_frames, dtype=np.float64)
mc_return[-1] = rewards[-1]
for t in range(num_frames - 2, -1, -1):
mc_return[t] = rewards[t] + gamma * mc_return[t + 1]
mc_return = mc_return.astype(np.float32)
return is_terminal, mc_return
def compute_returns(config: ComputeReturnsConfig) -> Path:
"""Compute returns and write them into parquet shards."""
from lerobot.datasets import LeRobotDataset
logger.info(f"Loading dataset: {config.dataset_repo_id}")
kwargs = {"repo_id": config.dataset_repo_id, "download_videos": False}
if config.root:
kwargs["root"] = config.root
dataset = LeRobotDataset(**kwargs)
meta = dataset.meta
root = Path(meta.root)
logger.info(f"Dataset root: {root}")
logger.info(f"Episodes: {meta.total_episodes}, Frames: {meta.total_frames}")
episode_indices = config.episodes if config.episodes else list(range(meta.total_episodes))
if config.max_episode_length is not None:
max_ep_len = config.max_episode_length
else:
max_ep_len = max(int(meta.episodes[i]["length"]) for i in episode_indices)
logger.info(f"Normalization horizon (max_episode_length): {max_ep_len}")
parquet_files_to_rewrite: dict[Path, list[int]] = {}
for ep_idx in episode_indices:
rel_path = meta.get_data_file_path(ep_idx)
abs_path = root / rel_path
parquet_files_to_rewrite.setdefault(abs_path, []).append(ep_idx)
logger.info(f"Parquet shards to rewrite: {len(parquet_files_to_rewrite)}")
for parquet_path, ep_indices_in_file in tqdm(parquet_files_to_rewrite.items(), desc="Processing shards"):
table = pq.read_table(parquet_path)
if not config.force and IS_TERMINAL_COL in table.column_names:
logger.info(f"Skipping {parquet_path.name} (already has {IS_TERMINAL_COL})")
continue
all_is_terminal = np.zeros(len(table), dtype=bool)
all_mc_return = np.zeros(len(table), dtype=np.float32)
episode_col = table.column("episode_index").to_pylist()
for ep_idx in ep_indices_in_file:
ep_info = meta.episodes[ep_idx]
ep_from = int(ep_info["dataset_from_index"])
ep_to = int(ep_info["dataset_to_index"])
ep_len = ep_to - ep_from
mask = np.array([v == ep_idx for v in episode_col], dtype=bool)
local_indices = np.where(mask)[0]
if len(local_indices) != ep_len:
logger.warning(
f"Episode {ep_idx}: expected {ep_len} frames in shard, "
f"found {len(local_indices)}. Using found count."
)
ep_len = len(local_indices)
if ep_len == 0:
continue
ep_subtable = table.filter(mask)
success = _get_episode_success(ep_subtable, config.success_key, config.default_success)
is_terminal, mc_return = compute_episode_returns(
num_frames=ep_len,
success=success,
c_fail=config.c_fail,
gamma=config.gamma,
max_episode_length=max_ep_len,
)
all_is_terminal[local_indices] = is_terminal
all_mc_return[local_indices] = mc_return
if IS_TERMINAL_COL in table.column_names:
table = table.drop(IS_TERMINAL_COL)
if MC_RETURN_COL in table.column_names:
table = table.drop(MC_RETURN_COL)
table = table.append_column(IS_TERMINAL_COL, pa.array(all_is_terminal))
table = table.append_column(MC_RETURN_COL, pa.array(all_mc_return))
pq.write_table(table, parquet_path)
_update_info_json(root, meta)
logger.info("Done. Columns written: is_terminal, mc_return")
return root
def _update_info_json(root: Path, meta) -> None:
"""Add is_terminal and mc_return to the dataset's info.json features."""
info_path = root / "meta" / "info.json"
if not info_path.exists():
logger.warning(f"info.json not found at {info_path}, skipping metadata update.")
return
info = json.loads(info_path.read_text())
features = info.get("features", {})
changed = False
if IS_TERMINAL_COL not in features:
features[IS_TERMINAL_COL] = {
"dtype": "bool",
"shape": [1],
"names": None,
}
changed = True
if MC_RETURN_COL not in features:
features[MC_RETURN_COL] = {
"dtype": "float32",
"shape": [1],
"names": None,
}
changed = True
if changed:
info["features"] = features
info_path.write_text(json.dumps(info, indent=2) + "\n")
logger.info("Updated meta/info.json with is_terminal and mc_return features.")
def main():
parser = argparse.ArgumentParser(
description="Compute per-frame is_terminal and mc_return for a LeRobot dataset.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Use the 'success' column from the dataset
lerobot-compute-returns --dataset-repo-id lerobot/aloha_sim_insertion_human_image
# Override all episodes as successful (demo-only data)
lerobot-compute-returns --dataset-repo-id my_org/my_dataset --default-success true
# Custom failure penalty
lerobot-compute-returns --dataset-repo-id my_org/my_dataset --c-fail 100
""",
)
parser.add_argument(
"--dataset-repo-id",
type=str,
required=True,
help="HuggingFace dataset repo id or local path.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="Local root directory override for the dataset.",
)
parser.add_argument(
"--success-key",
type=str,
default="next.success",
help="Column name in parquet that indicates episode success (default: 'next.success').",
)
parser.add_argument(
"--default-success",
type=str,
default=None,
choices=["true", "false"],
help="Override success for all episodes ('true' or 'false').",
)
parser.add_argument(
"--max-episode-length",
type=int,
default=None,
help="Normalization horizon H. If not set, uses max episode length in dataset.",
)
parser.add_argument(
"--c-fail",
type=float,
default=50.0,
help="Failure penalty constant (default: 50.0).",
)
parser.add_argument(
"--gamma",
type=float,
default=1.0,
help="Discount factor (default: 1.0, undiscounted).",
)
parser.add_argument(
"--episodes",
type=int,
nargs="+",
default=None,
help="Process only these episode indices (default: all).",
)
parser.add_argument(
"--force",
action="store_true",
help="Overwrite existing is_terminal/mc_return columns.",
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
default_success = None
if args.default_success is not None:
default_success = args.default_success.lower() == "true"
config = ComputeReturnsConfig(
dataset_repo_id=args.dataset_repo_id,
root=args.root,
success_key=args.success_key,
default_success=default_success,
max_episode_length=args.max_episode_length,
c_fail=args.c_fail,
gamma=args.gamma,
episodes=args.episodes or [],
force=args.force,
)
root = compute_returns(config)
logger.info(f"Returns computed and written to: {root}")
logger.info(f" Columns added: {IS_TERMINAL_COL}, {MC_RETURN_COL}")
logger.info("To train the distributional value function, these columns")
logger.info("will be read as flat batch keys during training.")
if __name__ == "__main__":
main()
+514
View File
@@ -0,0 +1,514 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lerobot-compute-returns script."""
import json
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from lerobot.scripts.lerobot_compute_returns import (
IS_TERMINAL_COL,
MC_RETURN_COL,
ComputeReturnsConfig,
_get_episode_success,
compute_episode_returns,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def parquet_dataset(tmp_path):
"""Build a minimal parquet shard + info.json for testing I/O logic.
Mirrors the lerobot-rollout DAgger convention: ``next.success`` is False
on all frames except the terminal frame of successful episodes.
Even episodes are successful, odd episodes are failures.
"""
num_episodes = 3
frames_per_ep = 10
root = tmp_path / "test_dataset"
data_dir = root / "data" / "chunk-000"
meta_dir = root / "meta"
data_dir.mkdir(parents=True)
meta_dir.mkdir(parents=True)
all_rows = []
episodes_meta = []
global_idx = 0
for ep in range(num_episodes):
ep_from = global_idx
is_successful = ep % 2 == 0
for frame in range(frames_per_ep):
is_last_frame = frame == frames_per_ep - 1
all_rows.append(
{
"episode_index": ep,
"frame_index": frame,
"index": global_idx,
"next.success": is_successful and is_last_frame,
}
)
global_idx += 1
ep_to = global_idx
episodes_meta.append(
{
"episode_index": ep,
"length": frames_per_ep,
"dataset_from_index": ep_from,
"dataset_to_index": ep_to,
}
)
table = pa.table(
{
"episode_index": [r["episode_index"] for r in all_rows],
"frame_index": [r["frame_index"] for r in all_rows],
"index": [r["index"] for r in all_rows],
"next.success": [r["next.success"] for r in all_rows],
}
)
parquet_path = data_dir / "episode_000000.parquet"
pq.write_table(table, parquet_path)
info = {
"codebase_version": "v3.0",
"total_episodes": num_episodes,
"total_frames": global_idx,
"fps": 30,
"features": {
"episode_index": {"dtype": "int64", "shape": [1], "names": None},
"frame_index": {"dtype": "int64", "shape": [1], "names": None},
"index": {"dtype": "int64", "shape": [1], "names": None},
"next.success": {"dtype": "bool", "shape": [1], "names": None},
},
}
(meta_dir / "info.json").write_text(json.dumps(info, indent=2))
return root, parquet_path, episodes_meta
def _rewrite_shard(parquet_path: Path, episodes_meta: list[dict], config: ComputeReturnsConfig):
"""Rewrite a single parquet shard using the core logic from compute_returns."""
table = pq.read_table(parquet_path)
if not config.force and IS_TERMINAL_COL in table.column_names:
return
all_is_terminal = np.zeros(len(table), dtype=bool)
all_mc_return = np.zeros(len(table), dtype=np.float32)
episode_col = table.column("episode_index").to_pylist()
for ep_info in episodes_meta:
ep_idx = ep_info["episode_index"]
ep_len = ep_info["length"]
mask = np.array([v == ep_idx for v in episode_col], dtype=bool)
local_indices = np.where(mask)[0]
ep_subtable = table.filter(mask)
success = _get_episode_success(ep_subtable, config.success_key, config.default_success)
is_terminal, mc_return = compute_episode_returns(
num_frames=ep_len,
success=success,
c_fail=config.c_fail,
gamma=config.gamma,
max_episode_length=config.max_episode_length or ep_len,
)
all_is_terminal[local_indices] = is_terminal
all_mc_return[local_indices] = mc_return
if IS_TERMINAL_COL in table.column_names:
table = table.drop(IS_TERMINAL_COL)
if MC_RETURN_COL in table.column_names:
table = table.drop(MC_RETURN_COL)
table = table.append_column(IS_TERMINAL_COL, pa.array(all_is_terminal))
table = table.append_column(MC_RETURN_COL, pa.array(all_mc_return))
pq.write_table(table, parquet_path)
# ---------------------------------------------------------------------------
# Tests: compute_episode_returns (pure math, no I/O)
# ---------------------------------------------------------------------------
def test_successful_episode_terminal_reward_is_zero():
"""Terminal MC return for a successful episode should be 0."""
_, mc_return = compute_episode_returns(
num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=10
)
assert mc_return[-1] == pytest.approx(0.0, abs=1e-6)
def test_failed_episode_terminal_reward_reflects_cfail():
"""Terminal MC return for a failed episode should be -C_fail / H."""
horizon = 100
c_fail = 50.0
_, mc_return = compute_episode_returns(
num_frames=10, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
)
assert mc_return[-1] == pytest.approx(-c_fail / horizon, abs=1e-5)
def test_is_terminal_only_last_frame():
"""Only the last frame of an episode should be marked terminal."""
is_terminal, _ = compute_episode_returns(
num_frames=20, success=True, c_fail=50.0, gamma=1.0, max_episode_length=20
)
assert is_terminal[-1] == True # noqa: E712
assert not any(is_terminal[:-1])
def test_mc_return_monotonically_increases_for_success():
"""For a successful undiscounted episode, returns should increase toward 0."""
_, mc_return = compute_episode_returns(
num_frames=50, success=True, c_fail=50.0, gamma=1.0, max_episode_length=50
)
for i in range(len(mc_return) - 1):
assert mc_return[i] <= mc_return[i + 1]
def test_mc_return_bounded_negative_to_zero():
"""MC returns for successful episodes should be in (-1, 0]."""
_, mc_return = compute_episode_returns(
num_frames=100, success=True, c_fail=50.0, gamma=1.0, max_episode_length=100
)
assert mc_return[-1] == pytest.approx(0.0, abs=1e-6)
assert all(v <= 0.0 for v in mc_return)
assert all(v >= -1.0 - 1e-6 for v in mc_return)
def test_first_frame_return_success():
"""First frame return for successful episode equals -(N-1)/H."""
num_frames = 10
horizon = 10
_, mc_return = compute_episode_returns(
num_frames=num_frames, success=True, c_fail=50.0, gamma=1.0, max_episode_length=horizon
)
expected = -(num_frames - 1) / horizon
assert mc_return[0] == pytest.approx(expected, abs=1e-5)
def test_first_frame_return_failure():
"""First frame return for failed episode includes the failure penalty."""
num_frames = 10
horizon = 100
c_fail = 50.0
_, mc_return = compute_episode_returns(
num_frames=num_frames, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
)
expected = (-(num_frames - 1) / horizon) + (-c_fail / horizon)
assert mc_return[0] == pytest.approx(expected, abs=1e-5)
def test_discount_factor_less_than_one():
"""Discount factor < 1 should make earlier frames have smaller magnitude."""
_, mc_undiscounted = compute_episode_returns(
num_frames=20, success=True, c_fail=50.0, gamma=1.0, max_episode_length=20
)
_, mc_discounted = compute_episode_returns(
num_frames=20, success=True, c_fail=50.0, gamma=0.99, max_episode_length=20
)
assert abs(mc_discounted[0]) < abs(mc_undiscounted[0])
def test_single_frame_episode_success():
"""Single-frame successful episode: return should be 0."""
is_terminal, mc_return = compute_episode_returns(
num_frames=1, success=True, c_fail=50.0, gamma=1.0, max_episode_length=1
)
assert mc_return[0] == pytest.approx(0.0, abs=1e-6)
assert is_terminal[0] == True # noqa: E712
def test_single_frame_episode_failure():
"""Single-frame failed episode: return should be -C_fail/H."""
horizon = 100
c_fail = 50.0
is_terminal, mc_return = compute_episode_returns(
num_frames=1, success=False, c_fail=c_fail, gamma=1.0, max_episode_length=horizon
)
assert mc_return[0] == pytest.approx(-c_fail / horizon, abs=1e-5)
assert is_terminal[0] == True # noqa: E712
def test_horizon_normalization_scales_returns():
"""Larger horizon should scale down the per-step penalty."""
_, mc_small_h = compute_episode_returns(
num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=10
)
_, mc_large_h = compute_episode_returns(
num_frames=10, success=True, c_fail=50.0, gamma=1.0, max_episode_length=100
)
assert abs(mc_large_h[0]) < abs(mc_small_h[0])
# ---------------------------------------------------------------------------
# Tests: _get_episode_success (in-memory PyArrow tables)
# ---------------------------------------------------------------------------
def test_default_success_overrides_column():
"""default_success should override any column value."""
table = pa.table({"next.success": [True, True, True]})
assert _get_episode_success(table, "next.success", default_success=False) is False
def test_reads_bool_column():
"""Should detect success via any() reduction over the column."""
table_success = pa.table({"next.success": [False, False, True]})
table_fail = pa.table({"next.success": [False, False, False]})
assert _get_episode_success(table_success, "next.success", None) is True
assert _get_episode_success(table_fail, "next.success", None) is False
def test_reads_int_column():
"""Should interpret integer success column (0/1) as bool via any()."""
table = pa.table({"task_success": [0, 0, 1]})
assert _get_episode_success(table, "task_success", None) is True
def test_all_zeros_means_failure():
"""An episode with all-zero success values is a failure."""
table = pa.table({"next.success": [0, 0, 0]})
assert _get_episode_success(table, "next.success", None) is False
def test_missing_column_defaults_to_true():
"""When success column is missing, assume success (demo data)."""
table = pa.table({"frame_index": [0, 1, 2]})
assert _get_episode_success(table, "next.success", None) is True
# ---------------------------------------------------------------------------
# Tests: parquet rewriting (integration, writes to disk)
# ---------------------------------------------------------------------------
def test_writes_columns_to_parquet(parquet_dataset):
"""The rewrite logic should add is_terminal and mc_return columns."""
root, parquet_path, episodes_meta = parquet_dataset
table_before = pq.read_table(parquet_path)
assert IS_TERMINAL_COL not in table_before.column_names
assert MC_RETURN_COL not in table_before.column_names
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
_rewrite_shard(parquet_path, episodes_meta, config)
table_after = pq.read_table(parquet_path)
assert IS_TERMINAL_COL in table_after.column_names
assert MC_RETURN_COL in table_after.column_names
def test_terminal_frames_correct(parquet_dataset):
"""Only the last frame of each episode should be terminal."""
root, parquet_path, episodes_meta = parquet_dataset
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
_rewrite_shard(parquet_path, episodes_meta, config)
table = pq.read_table(parquet_path)
is_terminal = table.column(IS_TERMINAL_COL).to_pylist()
terminal_indices = [i for i, v in enumerate(is_terminal) if v]
assert terminal_indices == [9, 19, 29]
def test_success_episodes_return_zero_at_terminal(tmp_path):
"""Successful episodes (ep 0) should have mc_return=0 at terminal."""
num_episodes = 2
frames_per_ep = 5
root = tmp_path / "test_dataset"
data_dir = root / "data" / "chunk-000"
meta_dir = root / "meta"
data_dir.mkdir(parents=True)
meta_dir.mkdir(parents=True)
all_rows = []
episodes_meta = []
global_idx = 0
for ep in range(num_episodes):
ep_from = global_idx
is_successful = ep % 2 == 0
for frame in range(frames_per_ep):
is_last_frame = frame == frames_per_ep - 1
all_rows.append(
{
"episode_index": ep,
"frame_index": frame,
"index": global_idx,
"next.success": is_successful and is_last_frame,
}
)
global_idx += 1
episodes_meta.append(
{
"episode_index": ep,
"length": frames_per_ep,
"dataset_from_index": ep_from,
"dataset_to_index": global_idx,
}
)
table = pa.table(
{
"episode_index": [r["episode_index"] for r in all_rows],
"frame_index": [r["frame_index"] for r in all_rows],
"index": [r["index"] for r in all_rows],
"next.success": [r["next.success"] for r in all_rows],
}
)
parquet_path = data_dir / "episode_000000.parquet"
pq.write_table(table, parquet_path)
info = {
"codebase_version": "v3.0",
"total_episodes": num_episodes,
"total_frames": global_idx,
"fps": 30,
"features": {
"episode_index": {"dtype": "int64", "shape": [1], "names": None},
"frame_index": {"dtype": "int64", "shape": [1], "names": None},
"index": {"dtype": "int64", "shape": [1], "names": None},
"next.success": {"dtype": "bool", "shape": [1], "names": None},
},
}
(meta_dir / "info.json").write_text(json.dumps(info, indent=2))
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=5, force=True)
_rewrite_shard(parquet_path, episodes_meta, config)
table = pq.read_table(parquet_path)
mc_return = table.column(MC_RETURN_COL).to_pylist()
assert mc_return[4] == pytest.approx(0.0, abs=1e-5)
def test_failed_episodes_have_negative_terminal(tmp_path):
"""Failed episodes (ep 1) should have mc_return < 0 at terminal."""
num_episodes = 2
frames_per_ep = 5
root = tmp_path / "test_dataset"
data_dir = root / "data" / "chunk-000"
meta_dir = root / "meta"
data_dir.mkdir(parents=True)
meta_dir.mkdir(parents=True)
all_rows = []
episodes_meta = []
global_idx = 0
for ep in range(num_episodes):
ep_from = global_idx
is_successful = ep % 2 == 0
for frame in range(frames_per_ep):
is_last_frame = frame == frames_per_ep - 1
all_rows.append(
{
"episode_index": ep,
"frame_index": frame,
"index": global_idx,
"next.success": is_successful and is_last_frame,
}
)
global_idx += 1
episodes_meta.append(
{
"episode_index": ep,
"length": frames_per_ep,
"dataset_from_index": ep_from,
"dataset_to_index": global_idx,
}
)
table = pa.table(
{
"episode_index": [r["episode_index"] for r in all_rows],
"frame_index": [r["frame_index"] for r in all_rows],
"index": [r["index"] for r in all_rows],
"next.success": [r["next.success"] for r in all_rows],
}
)
parquet_path = data_dir / "episode_000000.parquet"
pq.write_table(table, parquet_path)
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=5, c_fail=50.0, force=True)
_rewrite_shard(parquet_path, episodes_meta, config)
table = pq.read_table(parquet_path)
mc_return = table.column(MC_RETURN_COL).to_pylist()
assert mc_return[9] < 0.0
def test_idempotent_with_force_flag(parquet_dataset):
"""Running twice with force should produce identical results."""
root, parquet_path, episodes_meta = parquet_dataset
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
_rewrite_shard(parquet_path, episodes_meta, config)
table1 = pq.read_table(parquet_path)
mc1 = table1.column(MC_RETURN_COL).to_pylist()
_rewrite_shard(parquet_path, episodes_meta, config)
table2 = pq.read_table(parquet_path)
mc2 = table2.column(MC_RETURN_COL).to_pylist()
assert mc1 == mc2
def test_skips_if_columns_exist_without_force(parquet_dataset):
"""Without force, existing columns should not be overwritten."""
root, parquet_path, episodes_meta = parquet_dataset
config = ComputeReturnsConfig(success_key="next.success", max_episode_length=10, force=True)
_rewrite_shard(parquet_path, episodes_meta, config)
table = pq.read_table(parquet_path)
original_mc = table.column(MC_RETURN_COL).to_pylist()
config_no_force = ComputeReturnsConfig(success_key="next.success", max_episode_length=20, force=False)
_rewrite_shard(parquet_path, episodes_meta, config_no_force)
table2 = pq.read_table(parquet_path)
assert table2.column(MC_RETURN_COL).to_pylist() == original_mc
def test_updates_info_json(parquet_dataset):
"""info.json should be updated with is_terminal and mc_return features."""
from lerobot.scripts.lerobot_compute_returns import _update_info_json
root, parquet_path, episodes_meta = parquet_dataset
_update_info_json(root, None)
info_path = root / "meta" / "info.json"
info = json.loads(info_path.read_text())
assert IS_TERMINAL_COL in info["features"]
assert MC_RETURN_COL in info["features"]
assert info["features"][IS_TERMINAL_COL]["dtype"] == "bool"
assert info["features"][MC_RETURN_COL]["dtype"] == "float32"