fix decord

This commit is contained in:
Khalil Meftah
2026-05-18 10:39:51 +02:00
parent 34274c6f70
commit 0164725af8
+98 -15
View File
@@ -39,7 +39,6 @@ import argparse
import sys
from pathlib import Path
import av
import numpy as np
import torch
@@ -48,6 +47,22 @@ from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
try:
import decord # type: ignore
_HAS_DECORD = True
except ImportError:
decord = None # type: ignore
_HAS_DECORD = False
try:
import av
_HAS_AV = True
except ImportError:
av = None # type: ignore
_HAS_AV = False
EXAMPLES = [
{
"name": "soar_put_green_stick_in_brown_bowl",
@@ -73,31 +88,72 @@ EXAMPLES = [
]
def _extract_frames_av(video_path: Path, fps: float) -> np.ndarray:
"""Mirror upstream's ``extract_frames`` sampling logic using PyAV.
def _extract_frames_decord(video_path: Path, fps: float) -> tuple[np.ndarray, str]:
"""Mirror upstream's ``extract_frames`` sampling logic byte-for-byte with decord.
Upstream uses ``decord`` to read all frames, then samples
``np.linspace(0, total_frames - 1, desired_frames, dtype=int)`` where
``desired_frames = round(total_frames * (fps / native_fps))``. We do the
same here so the per-frame outputs are directly comparable.
Upstream code (``third_party/robometer/scripts/example_inference.py``)::
vr = decord.VideoReader(video_path, num_threads=1)
total_frames = len(vr)
native_fps = float(vr.get_avg_fps())
desired_frames = int(round(total_frames * (fps / native_fps)))
frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
frames_array = vr.get_batch(frame_indices).asnumpy()
"""
vr = decord.VideoReader(str(video_path), num_threads=1)
total_frames = len(vr)
if total_frames == 0:
raise RuntimeError(f"No decodable frames in {video_path}.")
native_fps = float(vr.get_avg_fps()) or 1.0
desired_frames = max(1, int(round(total_frames * (fps / native_fps))))
desired_frames = min(desired_frames, total_frames)
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
frames = vr.get_batch(indices).asnumpy()
return frames, f"decord total={total_frames} native_fps={native_fps:.3f}"
def _extract_frames_av(video_path: Path, fps: float) -> tuple[np.ndarray, str]:
"""PyAV fallback for environments without decord. NOT byte-identical to upstream.
Upstream uses decord; using ffmpeg-via-av can produce a different
``total_frames`` for the same container (B-frame handling / packet timing),
which then propagates into a different ``desired_frames`` and different
sampled indices. Use this only for a smoke test; install ``decord`` for a
real parity check.
"""
container = av.open(str(video_path))
stream = container.streams.video[0]
native_fps = float(stream.average_rate) if stream.average_rate else float(stream.guessed_rate or 30.0)
rgb_frames: list[np.ndarray] = []
for frame in container.decode(stream):
rgb_frames.append(frame.to_ndarray(format="rgb24"))
container.close()
total_frames = len(rgb_frames)
if total_frames == 0:
raise RuntimeError(f"No decodable frames in {video_path}.")
desired_frames = max(1, int(round(total_frames * (fps / max(native_fps, 1e-6)))))
desired_frames = min(desired_frames, total_frames)
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int)
return np.stack([rgb_frames[i] for i in indices])
frames = np.stack([rgb_frames[i] for i in indices])
return frames, f"av total={total_frames} native_fps={native_fps:.3f}"
def _extract_frames(video_path: Path, fps: float, prefer: str) -> tuple[np.ndarray, str]:
"""Pick a decoder. ``prefer`` is ``"decord"`` | ``"av"`` | ``"auto"``."""
if prefer == "decord":
if not _HAS_DECORD:
raise RuntimeError("decord requested but not installed (`uv pip install decord`).")
return _extract_frames_decord(video_path, fps)
if prefer == "av":
if not _HAS_AV:
raise RuntimeError("av requested but not installed.")
return _extract_frames_av(video_path, fps)
# auto
if _HAS_DECORD:
return _extract_frames_decord(video_path, fps)
if _HAS_AV:
return _extract_frames_av(video_path, fps)
raise RuntimeError("No video decoder available (install `decord` or `av`).")
def _run_lerobot(
@@ -160,10 +216,37 @@ def main() -> int:
default=3.0,
help="Sampling fps (default: 3, matching the upstream README).",
)
parser.add_argument("--atol", type=float, default=1e-3)
parser.add_argument("--rtol", type=float, default=1e-2)
parser.add_argument(
"--decoder",
choices=("auto", "decord", "av"),
default="auto",
help=(
"Video decoder. Default: ``auto`` prefers decord (byte-identical to upstream) "
"and falls back to av. Force with --decoder decord for a clean parity check."
),
)
parser.add_argument(
"--atol",
type=float,
default=5e-3,
help="Absolute tolerance for allclose. Default 5e-3 covers bf16 round-trip + sigmoid amplification.",
)
parser.add_argument(
"--rtol",
type=float,
default=1e-2,
help="Relative tolerance for allclose.",
)
args = parser.parse_args()
if args.decoder == "av" or (args.decoder == "auto" and not _HAS_DECORD):
print(
"WARNING: using PyAV decoder. PyAV's total-frame count can differ from decord's, "
"which propagates into a different number of sampled frames and breaks byte parity. "
"Run `uv pip install decord` and re-run for a clean check.",
file=sys.stderr,
)
examples_dir = args.examples_dir.resolve()
if not examples_dir.is_dir():
print(f"ERROR: examples dir {examples_dir} does not exist.", file=sys.stderr)
@@ -200,8 +283,8 @@ def main() -> int:
print(f"\n=== {ex['name']} ===")
print(f" task: {ex['task']!r}")
frames = _extract_frames_av(video_path, fps=args.fps)
print(f" decoded {frames.shape[0]} frames @ fps={args.fps}; shape={frames.shape}")
frames, decoder_info = _extract_frames(video_path, fps=args.fps, prefer=args.decoder)
print(f" decoded {frames.shape[0]} frames @ fps={args.fps}; shape={frames.shape} [{decoder_info}]")
progress, success = _run_lerobot(model, encoder, frames, ex["task"])