mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix decord
This commit is contained in:
@@ -39,7 +39,6 @@ import argparse
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import av
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -48,6 +47,22 @@ from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
|||||||
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
|
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
|
||||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||||
|
|
||||||
|
try:
|
||||||
|
import decord # type: ignore
|
||||||
|
|
||||||
|
_HAS_DECORD = True
|
||||||
|
except ImportError:
|
||||||
|
decord = None # type: ignore
|
||||||
|
_HAS_DECORD = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import av
|
||||||
|
|
||||||
|
_HAS_AV = True
|
||||||
|
except ImportError:
|
||||||
|
av = None # type: ignore
|
||||||
|
_HAS_AV = False
|
||||||
|
|
||||||
EXAMPLES = [
|
EXAMPLES = [
|
||||||
{
|
{
|
||||||
"name": "soar_put_green_stick_in_brown_bowl",
|
"name": "soar_put_green_stick_in_brown_bowl",
|
||||||
@@ -73,31 +88,72 @@ EXAMPLES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _extract_frames_av(video_path: Path, fps: float) -> np.ndarray:
|
def _extract_frames_decord(video_path: Path, fps: float) -> tuple[np.ndarray, str]:
|
||||||
"""Mirror upstream's ``extract_frames`` sampling logic using PyAV.
|
"""Mirror upstream's ``extract_frames`` sampling logic byte-for-byte with decord.
|
||||||
|
|
||||||
Upstream uses ``decord`` to read all frames, then samples
|
Upstream code (``third_party/robometer/scripts/example_inference.py``)::
|
||||||
``np.linspace(0, total_frames - 1, desired_frames, dtype=int)`` where
|
|
||||||
``desired_frames = round(total_frames * (fps / native_fps))``. We do the
|
vr = decord.VideoReader(video_path, num_threads=1)
|
||||||
same here so the per-frame outputs are directly comparable.
|
total_frames = len(vr)
|
||||||
|
native_fps = float(vr.get_avg_fps())
|
||||||
|
desired_frames = int(round(total_frames * (fps / native_fps)))
|
||||||
|
frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
|
||||||
|
frames_array = vr.get_batch(frame_indices).asnumpy()
|
||||||
|
"""
|
||||||
|
vr = decord.VideoReader(str(video_path), num_threads=1)
|
||||||
|
total_frames = len(vr)
|
||||||
|
if total_frames == 0:
|
||||||
|
raise RuntimeError(f"No decodable frames in {video_path}.")
|
||||||
|
native_fps = float(vr.get_avg_fps()) or 1.0
|
||||||
|
desired_frames = max(1, int(round(total_frames * (fps / native_fps))))
|
||||||
|
desired_frames = min(desired_frames, total_frames)
|
||||||
|
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
|
||||||
|
frames = vr.get_batch(indices).asnumpy()
|
||||||
|
return frames, f"decord total={total_frames} native_fps={native_fps:.3f}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_frames_av(video_path: Path, fps: float) -> tuple[np.ndarray, str]:
|
||||||
|
"""PyAV fallback for environments without decord. NOT byte-identical to upstream.
|
||||||
|
|
||||||
|
Upstream uses decord; using ffmpeg-via-av can produce a different
|
||||||
|
``total_frames`` for the same container (B-frame handling / packet timing),
|
||||||
|
which then propagates into a different ``desired_frames`` and different
|
||||||
|
sampled indices. Use this only for a smoke test; install ``decord`` for a
|
||||||
|
real parity check.
|
||||||
"""
|
"""
|
||||||
container = av.open(str(video_path))
|
container = av.open(str(video_path))
|
||||||
stream = container.streams.video[0]
|
stream = container.streams.video[0]
|
||||||
native_fps = float(stream.average_rate) if stream.average_rate else float(stream.guessed_rate or 30.0)
|
native_fps = float(stream.average_rate) if stream.average_rate else float(stream.guessed_rate or 30.0)
|
||||||
|
|
||||||
rgb_frames: list[np.ndarray] = []
|
rgb_frames: list[np.ndarray] = []
|
||||||
for frame in container.decode(stream):
|
for frame in container.decode(stream):
|
||||||
rgb_frames.append(frame.to_ndarray(format="rgb24"))
|
rgb_frames.append(frame.to_ndarray(format="rgb24"))
|
||||||
container.close()
|
container.close()
|
||||||
|
|
||||||
total_frames = len(rgb_frames)
|
total_frames = len(rgb_frames)
|
||||||
if total_frames == 0:
|
if total_frames == 0:
|
||||||
raise RuntimeError(f"No decodable frames in {video_path}.")
|
raise RuntimeError(f"No decodable frames in {video_path}.")
|
||||||
|
|
||||||
desired_frames = max(1, int(round(total_frames * (fps / max(native_fps, 1e-6)))))
|
desired_frames = max(1, int(round(total_frames * (fps / max(native_fps, 1e-6)))))
|
||||||
desired_frames = min(desired_frames, total_frames)
|
desired_frames = min(desired_frames, total_frames)
|
||||||
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int)
|
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int)
|
||||||
return np.stack([rgb_frames[i] for i in indices])
|
frames = np.stack([rgb_frames[i] for i in indices])
|
||||||
|
return frames, f"av total={total_frames} native_fps={native_fps:.3f}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_frames(video_path: Path, fps: float, prefer: str) -> tuple[np.ndarray, str]:
|
||||||
|
"""Pick a decoder. ``prefer`` is ``"decord"`` | ``"av"`` | ``"auto"``."""
|
||||||
|
if prefer == "decord":
|
||||||
|
if not _HAS_DECORD:
|
||||||
|
raise RuntimeError("decord requested but not installed (`uv pip install decord`).")
|
||||||
|
return _extract_frames_decord(video_path, fps)
|
||||||
|
if prefer == "av":
|
||||||
|
if not _HAS_AV:
|
||||||
|
raise RuntimeError("av requested but not installed.")
|
||||||
|
return _extract_frames_av(video_path, fps)
|
||||||
|
# auto
|
||||||
|
if _HAS_DECORD:
|
||||||
|
return _extract_frames_decord(video_path, fps)
|
||||||
|
if _HAS_AV:
|
||||||
|
return _extract_frames_av(video_path, fps)
|
||||||
|
raise RuntimeError("No video decoder available (install `decord` or `av`).")
|
||||||
|
|
||||||
|
|
||||||
def _run_lerobot(
|
def _run_lerobot(
|
||||||
@@ -160,10 +216,37 @@ def main() -> int:
|
|||||||
default=3.0,
|
default=3.0,
|
||||||
help="Sampling fps (default: 3, matching the upstream README).",
|
help="Sampling fps (default: 3, matching the upstream README).",
|
||||||
)
|
)
|
||||||
parser.add_argument("--atol", type=float, default=1e-3)
|
parser.add_argument(
|
||||||
parser.add_argument("--rtol", type=float, default=1e-2)
|
"--decoder",
|
||||||
|
choices=("auto", "decord", "av"),
|
||||||
|
default="auto",
|
||||||
|
help=(
|
||||||
|
"Video decoder. Default: ``auto`` prefers decord (byte-identical to upstream) "
|
||||||
|
"and falls back to av. Force with --decoder decord for a clean parity check."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--atol",
|
||||||
|
type=float,
|
||||||
|
default=5e-3,
|
||||||
|
help="Absolute tolerance for allclose. Default 5e-3 covers bf16 round-trip + sigmoid amplification.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rtol",
|
||||||
|
type=float,
|
||||||
|
default=1e-2,
|
||||||
|
help="Relative tolerance for allclose.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.decoder == "av" or (args.decoder == "auto" and not _HAS_DECORD):
|
||||||
|
print(
|
||||||
|
"WARNING: using PyAV decoder. PyAV's total-frame count can differ from decord's, "
|
||||||
|
"which propagates into a different number of sampled frames and breaks byte parity. "
|
||||||
|
"Run `uv pip install decord` and re-run for a clean check.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
examples_dir = args.examples_dir.resolve()
|
examples_dir = args.examples_dir.resolve()
|
||||||
if not examples_dir.is_dir():
|
if not examples_dir.is_dir():
|
||||||
print(f"ERROR: examples dir {examples_dir} does not exist.", file=sys.stderr)
|
print(f"ERROR: examples dir {examples_dir} does not exist.", file=sys.stderr)
|
||||||
@@ -200,8 +283,8 @@ def main() -> int:
|
|||||||
|
|
||||||
print(f"\n=== {ex['name']} ===")
|
print(f"\n=== {ex['name']} ===")
|
||||||
print(f" task: {ex['task']!r}")
|
print(f" task: {ex['task']!r}")
|
||||||
frames = _extract_frames_av(video_path, fps=args.fps)
|
frames, decoder_info = _extract_frames(video_path, fps=args.fps, prefer=args.decoder)
|
||||||
print(f" decoded {frames.shape[0]} frames @ fps={args.fps}; shape={frames.shape}")
|
print(f" decoded {frames.shape[0]} frames @ fps={args.fps}; shape={frames.shape} [{decoder_info}]")
|
||||||
|
|
||||||
progress, success = _run_lerobot(model, encoder, frames, ex["task"])
|
progress, success = _run_lerobot(model, encoder, frames, ex["task"])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user