mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
fix(is_depth): adding missing doctrings and is_depth arguments in video decoding functions
Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
This commit is contained in:
@@ -70,28 +70,29 @@ def decode_video_frames(
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
|
||||
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
|
||||
accepted for one release as an alias for "pyav" and will be removed in a future version.
|
||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||
return_uint8 (bool): For RGB videos, if True return raw uint8 frames without float32 normalization.
|
||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||
is_depth (bool): Set to True if the video is a depth map (1 channel, uint12).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
|
||||
torch.Tensor: Decoded frames (RGB: float32 in [0,1] by default, or uint8 if return_uint8=True, Depth: uint12).
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend != "pyav" and is_depth:
|
||||
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
|
||||
# We do not actually return uint8 here, but we avoid the 255 normalization step.
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True)
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True)
|
||||
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth)
|
||||
elif backend == "video_reader":
|
||||
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
@@ -121,8 +122,9 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||
decoded frame.
|
||||
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
|
||||
[0, 1] range.
|
||||
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
|
||||
Otherwise, return float32 in [0, 1] range.
|
||||
is_depth: Set to True if the video is a depth map (1 channel, uint12).
|
||||
|
||||
Returns:
|
||||
torch.Tensor of shape (len(timestamps), C, H, W).
|
||||
@@ -201,7 +203,7 @@ def decode_video_frames_pyav(
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
|
||||
if return_uint8:
|
||||
if return_uint8 or is_depth:
|
||||
return closest_frames
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
|
||||
Reference in New Issue
Block a user