Files
lerobot/src/lerobot/datasets/video_utils.py
T

678 lines
25 KiB
Python

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import importlib
import logging
import shutil
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from typing import Any, ClassVar
import av
import fsspec
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
def get_safe_default_codec():
    """Return the default video decoding backend for this platform.

    Returns:
        str: ``"torchcodec"`` when the ``torchcodec`` package can be found,
        otherwise ``"pyav"`` (a warning is logged in that case).
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    if importlib.util.find_spec("torchcodec") is not None:
        return "torchcodec"

    logging.warning(
        "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
    )
    return "pyav"
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str | None = None,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
Args:
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
Returns:
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated to the requested timestamps of a video

    The backend can be either "pyav" (default) or "video_reader".
    "video_reader" requires installing torchvision from source, see:
    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
    (note that you need to compile against ffmpeg<4.3)

    While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
    For more info on video decoding, see `benchmark/video/README.md`

    See torchvision doc for more info on these two backends:
    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.

    Args:
        video_path: Path to the video file.
        timestamps: Query timestamps, compared directly against each decoded frame's
            "pts" (presumably seconds for torchvision's VideoReader — TODO confirm).
        tolerance_s: Every query must have a decoded frame strictly closer than this.
        backend: Passed to `torchvision.set_video_backend`; "pyav" or "video_reader".
        log_loaded_timestamps: If True, log every decoded timestamp and the selected ones.

    Returns:
        torch.Tensor: One frame per query timestamp, in query order, converted to
        float32 and divided by 255.

    Raises:
        AssertionError: If any query timestamp has no decoded frame within `tolerance_s`.
    """
    video_path = str(video_path)

    # set backend
    # NOTE: `torchvision.set_video_backend` changes torchvision's backend setting for
    # the whole process, not just for this reader.
    keyframes_only = False
    torchvision.set_video_backend(backend)
    if backend == "pyav":
        keyframes_only = True  # pyav doesn't support accurate seek

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(video_path, "video")

    # set the first and last requested timestamps
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    # access closest key frame of the first requested frame
    # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
    # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
    reader.seek(first_ts, keyframes_only=keyframes_only)

    # load all frames until last requested frame
    loaded_frames = []
    loaded_ts = []
    for frame in reader:
        current_ts = frame["pts"]
        if log_loaded_timestamps:
            logging.info(f"frame loaded at timestamp={current_ts:.4f}")
        loaded_frames.append(frame["data"])
        loaded_ts.append(current_ts)
        # Stop as soon as the last requested timestamp has been reached or passed.
        if current_ts >= last_ts:
            break

    # Only the pyav backend exposes an underlying container to close explicitly.
    if backend == "pyav":
        reader.container.close()

    reader = None

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and timestamps of all loaded frames
    # ([:, None] turns both 1-D tensors into column vectors, so with p=1 the result is
    # |query - loaded| for every pair, shaped (num_queries, num_loaded))
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    # NOTE(review): this check is an `assert`, so it is skipped under `python -O`.
    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
        f"\nbackend: {backend}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization.

    One torchcodec decoder, together with the file handle it reads from, is kept
    per video path. All cache accesses are serialized by an internal lock.
    """

    def __init__(self):
        # str(video_path) -> (decoder, open file handle backing that decoder)
        self._cache: dict[str, tuple[Any, Any]] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Get a cached decoder or create a new one.

        Args:
            video_path: Path or URL of the video, opened through `fsspec`.

        Returns:
            The `torchcodec.decoders.VideoDecoder` cached for `video_path`.

        Raises:
            ImportError: If `torchcodec` is not installed.
        """
        # `import importlib` alone does not guarantee `importlib.util` is loaded.
        import importlib.util

        if importlib.util.find_spec("torchcodec") is None:
            raise ImportError("torchcodec is required but not available.")
        from torchcodec.decoders import VideoDecoder

        video_path = str(video_path)

        with self._lock:
            if video_path not in self._cache:
                # The handle must outlive this call because the decoder reads from it;
                # it is closed by `clear()`.
                file_handle = fsspec.open(video_path).open()
                try:
                    decoder = VideoDecoder(file_handle, seek_mode="approximate")
                except Exception:
                    # Do not leak the open handle when the decoder cannot be built.
                    file_handle.close()
                    raise
                self._cache[video_path] = (decoder, file_handle)

            return self._cache[video_path][0]

    def clear(self):
        """Clear the cache and close file handles."""
        with self._lock:
            for _, file_handle in self._cache.values():
                file_handle.close()
            self._cache.clear()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)
class FrameTimestampError(ValueError):
    """Raised when the decoded frames do not match the queried timestamps."""
# Module-level cache used by decode_video_frames_torchcodec whenever the caller
# does not pass its own `decoder_cache`.
_default_decoder_cache = VideoDecoderCache()
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    log_loaded_timestamps: bool = False,
    decoder_cache: VideoDecoderCache | None = None,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Args:
        video_path: Path to the video file.
        timestamps: List of timestamps (seconds) to extract frames.
        tolerance_s: Allowed deviation in seconds for frame retrieval.
        log_loaded_timestamps: Whether to log loaded timestamps.
        decoder_cache: Optional decoder cache instance. Uses default if None.

    Returns:
        torch.Tensor: One frame per query timestamp, in query order, as float32
        values divided by 255.

    Raises:
        AssertionError: If a query timestamp has no decoded frame within `tolerance_s`.
        FrameTimestampError: If the number of retrieved frames differs from the
            number of queried timestamps.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    if decoder_cache is None:
        decoder_cache = _default_decoder_cache

    # Use cached decoder instead of creating new one each time
    decoder = decoder_cache.get_decoder(str(video_path))

    # Map each timestamp to a frame index with the stream's average fps
    # (this assumes a roughly constant frame rate; the tolerance check below
    # catches frames that land too far from the query).
    average_fps = decoder.metadata.average_fps
    frame_indices = [round(ts * average_fps) for ts in timestamps]
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    loaded_frames = []
    loaded_ts = []
    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
        pts_s = pts.item()
        loaded_frames.append(frame)
        loaded_ts.append(pts_s)
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts_s:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # Pairwise |query - loaded| distances, shaped (num_queries, num_loaded).
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range
    closest_frames = (closest_frames / 255.0).type(torch.float32)

    if len(closest_frames) != len(timestamps):
        # Report the actual mismatch; a set difference between tensors and floats
        # carries no useful information.
        raise FrameTimestampError(
            f"Retrieved {len(closest_frames)} frames but {len(timestamps)} timestamps were queried"
        )
    return closest_frames
def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: int | None = av.logging.ERROR,
    overwrite: bool = True,
    preset: int | str | None = None,
) -> None:
    """Encode the ``frame-XXXXXX.png`` images of a directory into a video with PyAV.

    More info on ffmpeg arguments tuning on `benchmark/video/README.md`.

    Args:
        imgs_dir: Directory holding the frames, named ``frame-`` + six digits + ``.png``.
            Frames are encoded in numeric order of those digits.
        video_path: Output file; parent directories are created if needed.
        fps: Frame rate of the output stream.
        vcodec: One of ``"h264"``, ``"hevc"`` or ``"libsvtav1"``.
        pix_fmt: Output pixel format. ``"yuv444p"`` is replaced by ``"yuv420p"``
            for ``"libsvtav1"`` and ``"hevc"``.
        g: ``g`` (GOP size) encoder option; not set when None.
        crf: ``crf`` encoder option; not set when None.
        fast_decode: When non-zero, sets ``svtav1-params=fast-decode=<value>`` for
            libsvtav1 or ``tune=fastdecode`` for the other codecs.
        log_level: Level applied to the ``"libav"`` logger while encoding; None
            leaves logging untouched.
        overwrite: When False and ``video_path`` exists, log a warning and return.
        preset: ``preset`` encoder option. When None, ``"12"`` is used for
            ``"libsvtav1"`` and no preset is set for the other codecs.

    Raises:
        ValueError: If ``vcodec`` is not supported.
        FileNotFoundError: If no matching frame is found in ``imgs_dir``.
        OSError: If the output file does not exist after encoding.
    """
    # Check encoder availability
    if vcodec not in ["h264", "hevc", "libsvtav1"]:
        raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")

    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)

    if video_path.exists() and not overwrite:
        logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
        return

    video_path.parent.mkdir(parents=True, exist_ok=True)

    # Encoders/pixel formats incompatibility check
    if vcodec in ("libsvtav1", "hevc") and pix_fmt == "yuv444p":
        logging.warning(
            f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
        )
        pix_fmt = "yuv420p"

    # Get input frames. Path.glob only treats the pattern as a glob, so special
    # characters such as '[' in `imgs_dir` itself are matched literally.
    template = "frame-" + ("[0-9]" * 6) + ".png"
    input_list = sorted(imgs_dir.glob(template), key=lambda p: int(p.stem.split("-")[-1]))

    # Define video output frame size (assuming all input frames are the same size)
    if len(input_list) == 0:
        raise FileNotFoundError(f"No images found in {imgs_dir}.")
    with Image.open(input_list[0]) as dummy_image:
        width, height = dummy_image.size

    # Define video codec options
    video_options = {}

    if g is not None:
        video_options["g"] = str(g)

    if crf is not None:
        video_options["crf"] = str(crf)

    if preset is None and vcodec == "libsvtav1":
        # Numeric presets are an SVT-AV1 setting; x264/x265 expect named presets
        # (e.g. "medium") and reject "12", so only default it for libsvtav1.
        preset = 12
    if preset is not None:
        video_options["preset"] = str(preset)

    if fast_decode:
        key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
        video_options[key] = value

    # Set logging level
    if log_level is not None:
        # "While less efficient, it is generally preferable to modify logging with Python's logging"
        logging.getLogger("libav").setLevel(log_level)

    try:
        # Create and open output file (overwrite by default)
        with av.open(str(video_path), "w") as output:
            output_stream = output.add_stream(vcodec, fps, options=video_options)
            output_stream.pix_fmt = pix_fmt
            output_stream.width = width
            output_stream.height = height

            # Loop through input frames and encode them
            for input_data in input_list:
                with Image.open(input_data) as input_image:
                    rgb_image = input_image.convert("RGB")
                    input_frame = av.VideoFrame.from_image(rgb_image)
                    packet = output_stream.encode(input_frame)
                    if packet:
                        output.mux(packet)

            # Flush the encoder
            packet = output_stream.encode()
            if packet:
                output.mux(packet)
    finally:
        # Reset logging level, also when encoding failed
        if log_level is not None:
            av.logging.restore_default_callback()

    if not video_path.exists():
        raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):
"""
Concatenate multiple video files into a single video file using pyav.
This function takes a list of video input file paths and concatenates them into a single
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
concatenation without re-encoding.
Args:
input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
Note:
- Creates a temporary directory for intermediate files that is cleaned up after use.
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
codec, resolution, and frame rate for proper concatenation.
"""
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
return
output_video_path.parent.mkdir(parents=True, exist_ok=True)
if len(input_video_paths) == 0:
raise FileNotFoundError("No input video paths provided.")
# Create a temporary .ffconcat file to list the input video paths
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
for input_path in input_video_paths:
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
tmp_concatenate_file.flush()
tmp_concatenate_path = tmp_concatenate_file.name
# Create input and output containers
input_container = av.open(
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
) # safe = 0 allows absolute paths as well as relative paths
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
output_container = av.open(
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
) # faststart is to move the metadata to the beginning of the file to speed up loading
# Replicate input streams in output container
stream_map = {}
for input_stream in input_container.streams:
if input_stream.type in ("video", "audio", "subtitle"): # only copy compatible streams
stream_map[input_stream.index] = output_container.add_stream_from_template(
template=input_stream, opaque=True
)
# set the time base to the input stream time base (missing in the codec context)
stream_map[input_stream.index].time_base = input_stream.time_base
# Demux + remux packets (no re-encode)
for packet in input_container.demux():
# Skip packets from un-mapped streams
if packet.stream.index not in stream_map:
continue
# Skip demux flushing packets
if packet.dts is None:
continue
output_stream = stream_map[packet.stream.index]
packet.stream = output_stream
output_container.mux(packet)
input_container.close()
output_container.close()
shutil.move(tmp_output_video_path, output_video_path)
Path(tmp_concatenate_path).unlink()
@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
    """
    Feature type for a `datasets` column whose values reference a frame inside a video file.

    Each value is a struct with a ``path`` (string) and a ``timestamp`` (float32).

    Example:

    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    # Arrow storage type, shared at class level by every instance.
    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
    # Type tag; excluded from the generated __init__ and __repr__.
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        """Return the Arrow storage type of this feature."""
        return type(self).pa_type
# Make VideoFrame available as a Hugging Face `datasets` feature type, silencing
# the "experimental" UserWarning matched by the filter below for this call only.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action="ignore",
        message="'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
    """Describe the first audio stream of a media file.

    Args:
        video_path: Path to the media file.

    Returns:
        ``{"has_audio": False}`` if the file has no audio stream. Otherwise, a dict
        with ``audio.channels``, ``audio.codec``, ``audio.bit_rate``,
        ``audio.sample_rate``, ``audio.bit_depth``, ``audio.channel_layout`` and
        ``has_audio=True``.
    """
    # Raise the "libav" logger level to ERROR while probing the file
    logging.getLogger("libav").setLevel(av.logging.ERROR)

    audio_info = {}
    try:
        with av.open(str(video_path), "r") as audio_file:
            try:
                audio_stream = audio_file.streams.audio[0]
            except IndexError:
                return {"has_audio": False}

            audio_info["audio.channels"] = audio_stream.channels
            audio_info["audio.codec"] = audio_stream.codec.canonical_name
            # In an ideal lossless case: bit depth x sample rate x channels = bit rate.
            # In an actual compressed case, the bit rate is set according to the compression level: the lower the bit rate, the more compression is applied.
            audio_info["audio.bit_rate"] = audio_stream.bit_rate
            audio_info["audio.sample_rate"] = audio_stream.sample_rate  # Number of samples per second
            # In an ideal lossless case: fixed number of bits per sample.
            # In an actual compressed case: variable number of bits per sample (often reduced to match a given depth rate).
            audio_info["audio.bit_depth"] = audio_stream.format.bits
            audio_info["audio.channel_layout"] = audio_stream.layout.name
            audio_info["has_audio"] = True
    finally:
        # Reset logging level on every exit path, including errors while probing
        av.logging.restore_default_callback()

    return audio_info
def get_video_info(video_path: Path | str) -> dict:
    """Describe the first video stream of a media file, plus its audio stream info.

    Args:
        video_path: Path to the media file.

    Returns:
        ``{}`` if the file has no video stream. Otherwise, a dict with
        ``video.height``, ``video.width``, ``video.codec``, ``video.pix_fmt``,
        ``video.is_depth_map`` (always False here), ``video.fps`` and
        ``video.channels``, merged with the entries returned by ``get_audio_info``.

    Raises:
        ValueError: If the stream's pixel format is not recognised by
            ``get_video_pixel_channels``.
    """
    # Raise the "libav" logger level to ERROR while probing the file
    logging.getLogger("libav").setLevel(av.logging.ERROR)

    video_info = {}
    try:
        with av.open(str(video_path), "r") as video_file:
            try:
                video_stream = video_file.streams.video[0]
            except IndexError:
                return {}

            video_info["video.height"] = video_stream.height
            video_info["video.width"] = video_stream.width
            video_info["video.codec"] = video_stream.codec.canonical_name
            video_info["video.pix_fmt"] = video_stream.pix_fmt
            video_info["video.is_depth_map"] = False

            # Calculate fps from r_frame_rate (int() truncates fractional rates)
            video_info["video.fps"] = int(video_stream.base_rate)

            video_info["video.channels"] = get_video_pixel_channels(video_stream.pix_fmt)
    finally:
        # Reset logging level on every exit path, including an unknown pixel format
        av.logging.restore_default_callback()

    # Adding audio stream information
    video_info.update(get_audio_info(video_path))

    return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
    """Return the number of channels described by an FFmpeg pixel format name.

    * grayscale / depth / monochrome formats -> 1
    * formats with an alpha component (``rgba``, ``bgra``, ``argb``, ``abgr``,
      ``yuva*``, ``ayuv*``, planar ``gbrap*``) -> 4
    * other RGB/BGR/planar-GBR and YUV formats (including semi-planar
      ``nv12``/``nv21`` and padded layouts such as ``rgb0``) -> 3

    Args:
        pix_fmt: Pixel format name, e.g. ``"yuv420p"``.

    Raises:
        ValueError: If the format is not recognised.
    """
    if any(tag in pix_fmt for tag in ("gray", "depth", "mono")):
        return 1
    # Alpha layouts must be tested before the 3-channel ones: e.g. "argb" also
    # contains "rgb" and "ayuv64le" also contains "yuv".
    if any(tag in pix_fmt for tag in ("rgba", "bgra", "argb", "abgr", "yuva", "ayuv", "gbrap")):
        return 4
    if any(tag in pix_fmt for tag in ("rgb", "bgr", "gbrp", "yuv", "nv12", "nv21")):
        return 3
    raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float:
    """Return the duration, in seconds, of the first video stream of a file (PyAV).

    The stream's own duration (expressed in its ``time_base``) is preferred. When
    the stream does not report one, the container duration divided by
    ``av.time_base`` is used instead.

    Args:
        video_path: Path to the video file.

    Returns:
        Duration of the video in seconds.
    """
    with av.open(str(video_path)) as container:
        stream = container.streams.video[0]
        if stream.duration is None:
            # No stream-level duration: fall back to the container-level one.
            return float(container.duration / av.time_base)
        return float(stream.duration * stream.time_base)
class VideoEncodingManager:
    """
    Context manager that ensures proper video encoding and data cleanup even if exceptions occur.

    On exit (normal or due to an exception) it:
    - batch-encodes the episodes recorded since the last encoding, i.e. episode
      indices [num_episodes - episodes_since_last_encoding, num_episodes), if any;
    - calls `dataset.finalize()`;
    - if an exception occurred, deletes the image directory of episode index
      `dataset.num_episodes` for every video key (presumably the episode that was
      being recorded when interrupted — TODO confirm against the dataset class);
    - deletes `dataset.root / "images"` when no PNG file remains anywhere under it
      (the whole directory is removed with `shutil.rmtree`).
    The exception, if any, is never suppressed.

    Args:
        dataset: The LeRobotDataset instance
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __enter__(self):
        # No setup on entry; the manager itself is the context value.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Encode pending episodes, finalize the dataset and clean up image files.

        Returns:
            False, so any exception raised inside the `with` block propagates.
        """
        # Handle any remaining episodes that haven't been batch encoded
        if self.dataset.episodes_since_last_encoding > 0:
            if exc_type is not None:
                logging.info("Exception occurred. Encoding remaining episodes before exit...")
            else:
                logging.info("Recording stopped. Encoding remaining episodes...")

            # Episodes still waiting for encoding are the last `episodes_since_last_encoding` ones.
            start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
            end_ep = self.dataset.num_episodes
            logging.info(
                f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
                f"from episode {start_ep} to {end_ep - 1}"
            )
            self.dataset._batch_save_episode_video(start_ep, end_ep)

        # Finalize the dataset to properly close all writers
        self.dataset.finalize()

        # Clean up episode images if recording was interrupted
        if exc_type is not None:
            # Index `num_episodes` is one past the last counted episode.
            interrupted_episode_index = self.dataset.num_episodes
            for key in self.dataset.meta.video_keys:
                # The directory holding this camera's frames for that episode.
                img_dir = self.dataset._get_image_file_path(
                    episode_index=interrupted_episode_index, image_key=key, frame_index=0
                ).parent
                if img_dir.exists():
                    logging.debug(
                        f"Cleaned up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
                        if False
                        else f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
                    )
                    shutil.rmtree(img_dir)

        # Clean up any remaining images directory if it's empty
        img_dir = self.dataset.root / "images"
        # Check for any remaining PNG files
        png_files = list(img_dir.rglob("*.png"))
        if len(png_files) == 0:
            # Only remove the images directory if no PNG files remain
            if img_dir.exists():
                shutil.rmtree(img_dir)
                logging.debug("Cleaned up empty images directory")
        else:
            logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")

        return False  # Don't suppress the original exception