mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
fix decord
This commit is contained in:
@@ -39,7 +39,6 @@ import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -48,6 +47,22 @@ from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
|
||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||
|
||||
try:
|
||||
import decord # type: ignore
|
||||
|
||||
_HAS_DECORD = True
|
||||
except ImportError:
|
||||
decord = None # type: ignore
|
||||
_HAS_DECORD = False
|
||||
|
||||
try:
|
||||
import av
|
||||
|
||||
_HAS_AV = True
|
||||
except ImportError:
|
||||
av = None # type: ignore
|
||||
_HAS_AV = False
|
||||
|
||||
EXAMPLES = [
|
||||
{
|
||||
"name": "soar_put_green_stick_in_brown_bowl",
|
||||
@@ -73,31 +88,72 @@ EXAMPLES = [
|
||||
]
|
||||
|
||||
|
||||
def _extract_frames_av(video_path: Path, fps: float) -> np.ndarray:
|
||||
"""Mirror upstream's ``extract_frames`` sampling logic using PyAV.
|
||||
def _extract_frames_decord(video_path: Path, fps: float) -> tuple[np.ndarray, str]:
|
||||
"""Mirror upstream's ``extract_frames`` sampling logic byte-for-byte with decord.
|
||||
|
||||
Upstream uses ``decord`` to read all frames, then samples
|
||||
``np.linspace(0, total_frames - 1, desired_frames, dtype=int)`` where
|
||||
``desired_frames = round(total_frames * (fps / native_fps))``. We do the
|
||||
same here so the per-frame outputs are directly comparable.
|
||||
Upstream code (``third_party/robometer/scripts/example_inference.py``)::
|
||||
|
||||
vr = decord.VideoReader(video_path, num_threads=1)
|
||||
total_frames = len(vr)
|
||||
native_fps = float(vr.get_avg_fps())
|
||||
desired_frames = int(round(total_frames * (fps / native_fps)))
|
||||
frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
|
||||
frames_array = vr.get_batch(frame_indices).asnumpy()
|
||||
"""
|
||||
vr = decord.VideoReader(str(video_path), num_threads=1)
|
||||
total_frames = len(vr)
|
||||
if total_frames == 0:
|
||||
raise RuntimeError(f"No decodable frames in {video_path}.")
|
||||
native_fps = float(vr.get_avg_fps()) or 1.0
|
||||
desired_frames = max(1, int(round(total_frames * (fps / native_fps))))
|
||||
desired_frames = min(desired_frames, total_frames)
|
||||
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
|
||||
frames = vr.get_batch(indices).asnumpy()
|
||||
return frames, f"decord total={total_frames} native_fps={native_fps:.3f}"
|
||||
|
||||
|
||||
def _extract_frames_av(video_path: Path, fps: float) -> tuple[np.ndarray, str]:
|
||||
"""PyAV fallback for environments without decord. NOT byte-identical to upstream.
|
||||
|
||||
Upstream uses decord; using ffmpeg-via-av can produce a different
|
||||
``total_frames`` for the same container (B-frame handling / packet timing),
|
||||
which then propagates into a different ``desired_frames`` and different
|
||||
sampled indices. Use this only for a smoke test; install ``decord`` for a
|
||||
real parity check.
|
||||
"""
|
||||
container = av.open(str(video_path))
|
||||
stream = container.streams.video[0]
|
||||
native_fps = float(stream.average_rate) if stream.average_rate else float(stream.guessed_rate or 30.0)
|
||||
|
||||
rgb_frames: list[np.ndarray] = []
|
||||
for frame in container.decode(stream):
|
||||
rgb_frames.append(frame.to_ndarray(format="rgb24"))
|
||||
container.close()
|
||||
|
||||
total_frames = len(rgb_frames)
|
||||
if total_frames == 0:
|
||||
raise RuntimeError(f"No decodable frames in {video_path}.")
|
||||
|
||||
desired_frames = max(1, int(round(total_frames * (fps / max(native_fps, 1e-6)))))
|
||||
desired_frames = min(desired_frames, total_frames)
|
||||
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int)
|
||||
return np.stack([rgb_frames[i] for i in indices])
|
||||
frames = np.stack([rgb_frames[i] for i in indices])
|
||||
return frames, f"av total={total_frames} native_fps={native_fps:.3f}"
|
||||
|
||||
|
||||
def _extract_frames(video_path: Path, fps: float, prefer: str) -> tuple[np.ndarray, str]:
|
||||
"""Pick a decoder. ``prefer`` is ``"decord"`` | ``"av"`` | ``"auto"``."""
|
||||
if prefer == "decord":
|
||||
if not _HAS_DECORD:
|
||||
raise RuntimeError("decord requested but not installed (`uv pip install decord`).")
|
||||
return _extract_frames_decord(video_path, fps)
|
||||
if prefer == "av":
|
||||
if not _HAS_AV:
|
||||
raise RuntimeError("av requested but not installed.")
|
||||
return _extract_frames_av(video_path, fps)
|
||||
# auto
|
||||
if _HAS_DECORD:
|
||||
return _extract_frames_decord(video_path, fps)
|
||||
if _HAS_AV:
|
||||
return _extract_frames_av(video_path, fps)
|
||||
raise RuntimeError("No video decoder available (install `decord` or `av`).")
|
||||
|
||||
|
||||
def _run_lerobot(
|
||||
@@ -160,10 +216,37 @@ def main() -> int:
|
||||
default=3.0,
|
||||
help="Sampling fps (default: 3, matching the upstream README).",
|
||||
)
|
||||
parser.add_argument("--atol", type=float, default=1e-3)
|
||||
parser.add_argument("--rtol", type=float, default=1e-2)
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
choices=("auto", "decord", "av"),
|
||||
default="auto",
|
||||
help=(
|
||||
"Video decoder. Default: ``auto`` prefers decord (byte-identical to upstream) "
|
||||
"and falls back to av. Force with --decoder decord for a clean parity check."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--atol",
|
||||
type=float,
|
||||
default=5e-3,
|
||||
help="Absolute tolerance for allclose. Default 5e-3 covers bf16 round-trip + sigmoid amplification.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rtol",
|
||||
type=float,
|
||||
default=1e-2,
|
||||
help="Relative tolerance for allclose.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.decoder == "av" or (args.decoder == "auto" and not _HAS_DECORD):
|
||||
print(
|
||||
"WARNING: using PyAV decoder. PyAV's total-frame count can differ from decord's, "
|
||||
"which propagates into a different number of sampled frames and breaks byte parity. "
|
||||
"Run `uv pip install decord` and re-run for a clean check.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
examples_dir = args.examples_dir.resolve()
|
||||
if not examples_dir.is_dir():
|
||||
print(f"ERROR: examples dir {examples_dir} does not exist.", file=sys.stderr)
|
||||
@@ -200,8 +283,8 @@ def main() -> int:
|
||||
|
||||
print(f"\n=== {ex['name']} ===")
|
||||
print(f" task: {ex['task']!r}")
|
||||
frames = _extract_frames_av(video_path, fps=args.fps)
|
||||
print(f" decoded {frames.shape[0]} frames @ fps={args.fps}; shape={frames.shape}")
|
||||
frames, decoder_info = _extract_frames(video_path, fps=args.fps, prefer=args.decoder)
|
||||
print(f" decoded {frames.shape[0]} frames @ fps={args.fps}; shape={frames.shape} [{decoder_info}]")
|
||||
|
||||
progress, success = _run_lerobot(model, encoder, frames, ex["task"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user