diff --git a/scripts/parity_robometer_upstream_examples.py b/scripts/parity_robometer_upstream_examples.py index 25dd5b906..ee10f16f5 100644 --- a/scripts/parity_robometer_upstream_examples.py +++ b/scripts/parity_robometer_upstream_examples.py @@ -39,7 +39,6 @@ import argparse import sys from pathlib import Path -import av import numpy as np import torch @@ -48,6 +47,22 @@ from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep +try: + import decord # type: ignore + + _HAS_DECORD = True +except ImportError: + decord = None # type: ignore + _HAS_DECORD = False + +try: + import av + + _HAS_AV = True +except ImportError: + av = None # type: ignore + _HAS_AV = False + EXAMPLES = [ { "name": "soar_put_green_stick_in_brown_bowl", @@ -73,31 +88,72 @@ EXAMPLES = [ ] -def _extract_frames_av(video_path: Path, fps: float) -> np.ndarray: - """Mirror upstream's ``extract_frames`` sampling logic using PyAV. +def _extract_frames_decord(video_path: Path, fps: float) -> tuple[np.ndarray, str]: + """Mirror upstream's ``extract_frames`` sampling logic byte-for-byte with decord. - Upstream uses ``decord`` to read all frames, then samples - ``np.linspace(0, total_frames - 1, desired_frames, dtype=int)`` where - ``desired_frames = round(total_frames * (fps / native_fps))``. We do the - same here so the per-frame outputs are directly comparable. + Upstream code (``third_party/robometer/scripts/example_inference.py``):: + + vr = decord.VideoReader(video_path, num_threads=1) + total_frames = len(vr) + native_fps = float(vr.get_avg_fps()) + desired_frames = int(round(total_frames * (fps / native_fps))) + frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist() + frames_array = vr.get_batch(frame_indices).asnumpy() + """ + vr = decord.VideoReader(str(video_path), num_threads=1) + total_frames = len(vr) + if total_frames == 0: + raise RuntimeError(f"No decodable frames in {video_path}.") + native_fps = float(vr.get_avg_fps()) or 1.0 + desired_frames = max(1, int(round(total_frames * (fps / native_fps)))) + desired_frames = min(desired_frames, total_frames) + indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist() + frames = vr.get_batch(indices).asnumpy() + return frames, f"decord total={total_frames} native_fps={native_fps:.3f}" + + +def _extract_frames_av(video_path: Path, fps: float) -> tuple[np.ndarray, str]: + """PyAV fallback for environments without decord. NOT byte-identical to upstream. + + Upstream uses decord; using ffmpeg-via-av can produce a different + ``total_frames`` for the same container (B-frame handling / packet timing), + which then propagates into a different ``desired_frames`` and different + sampled indices. Use this only for a smoke test; install ``decord`` for a + real parity check. """ container = av.open(str(video_path)) stream = container.streams.video[0] native_fps = float(stream.average_rate) if stream.average_rate else float(stream.guessed_rate or 30.0) - rgb_frames: list[np.ndarray] = [] for frame in container.decode(stream): rgb_frames.append(frame.to_ndarray(format="rgb24")) container.close() - total_frames = len(rgb_frames) if total_frames == 0: raise RuntimeError(f"No decodable frames in {video_path}.") - desired_frames = max(1, int(round(total_frames * (fps / max(native_fps, 1e-6))))) desired_frames = min(desired_frames, total_frames) indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int) - return np.stack([rgb_frames[i] for i in indices]) + frames = np.stack([rgb_frames[i] for i in indices]) + return frames, f"av total={total_frames} native_fps={native_fps:.3f}" + + +def _extract_frames(video_path: Path, fps: float, prefer: str) -> tuple[np.ndarray, str]: + """Pick a decoder. ``prefer`` is ``"decord"`` | ``"av"`` | ``"auto"``.""" + if prefer == "decord": + if not _HAS_DECORD: + raise RuntimeError("decord requested but not installed (`uv pip install decord`).") + return _extract_frames_decord(video_path, fps) + if prefer == "av": + if not _HAS_AV: + raise RuntimeError("av requested but not installed.") + return _extract_frames_av(video_path, fps) + # auto + if _HAS_DECORD: + return _extract_frames_decord(video_path, fps) + if _HAS_AV: + return _extract_frames_av(video_path, fps) + raise RuntimeError("No video decoder available (install `decord` or `av`).") def _run_lerobot( @@ -160,10 +216,37 @@ def main() -> int: default=3.0, help="Sampling fps (default: 3, matching the upstream README).", ) - parser.add_argument("--atol", type=float, default=1e-3) - parser.add_argument("--rtol", type=float, default=1e-2) + parser.add_argument( + "--decoder", + choices=("auto", "decord", "av"), + default="auto", + help=( + "Video decoder. Default: ``auto`` prefers decord (byte-identical to upstream) " + "and falls back to av. Force with --decoder decord for a clean parity check." + ), + ) + parser.add_argument( + "--atol", + type=float, + default=5e-3, + help="Absolute tolerance for allclose. Default 5e-3 covers bf16 round-trip + sigmoid amplification.", + ) + parser.add_argument( + "--rtol", + type=float, + default=1e-2, + help="Relative tolerance for allclose.", + ) args = parser.parse_args() + if args.decoder == "av" or (args.decoder == "auto" and not _HAS_DECORD): + print( + "WARNING: using PyAV decoder. PyAV's total-frame count can differ from decord's, " + "which propagates into a different number of sampled frames and breaks byte parity. " + "Run `uv pip install decord` and re-run for a clean check.", + file=sys.stderr, + ) + examples_dir = args.examples_dir.resolve() if not examples_dir.is_dir(): print(f"ERROR: examples dir {examples_dir} does not exist.", file=sys.stderr) @@ -200,8 +283,8 @@ def main() -> int: print(f"\n=== {ex['name']} ===") print(f" task: {ex['task']!r}") - frames = _extract_frames_av(video_path, fps=args.fps) - print(f" decoded {frames.shape[0]} frames @ fps={args.fps}; shape={frames.shape}") + frames, decoder_info = _extract_frames(video_path, fps=args.fps, prefer=args.decoder) + print(f" decoded {frames.shape[0]} frames @ fps={args.fps}; shape={frames.shape} [{decoder_info}]") progress, success = _run_lerobot(model, encoder, frames, ex["task"])