fix(is_depth): adding missing docstrings and is_depth arguments in video decoding functions

Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
This commit is contained in:
CarolinePascal
2026-06-12 18:04:45 +02:00
parent 56231b17d1
commit 971d7ea85b
+10 -8
View File
@@ -70,28 +70,29 @@ def decode_video_frames(
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
accepted for one release as an alias for "pyav" and will be removed in a future version.
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
return_uint8 (bool): For RGB videos, if True return raw uint8 frames without float32 normalization.
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
is_depth (bool): Set to True if the video is a depth map (1 channel, uint12).
Returns:
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
torch.Tensor: Decoded frames (RGB: float32 in [0,1] by default, or uint8 if return_uint8=True, Depth: uint12).
Currently supports torchcodec on cpu and pyav.
"""
if backend != "pyav" and is_depth:
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
# We do not actually return uint8 here, but we avoid the 255 normalization step.
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True)
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True)
if backend is None:
backend = get_safe_default_video_backend()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "pyav":
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth)
elif backend == "video_reader":
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth)
else:
raise ValueError(f"Unsupported video backend: {backend}")
@@ -121,8 +122,9 @@ def decode_video_frames_pyav(
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
decoded frame.
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
[0, 1] range.
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
Otherwise, return float32 in [0, 1] range.
is_depth: Set to True if the video is a depth map (1 channel, uint12).
Returns:
torch.Tensor of shape (len(timestamps), C, H, W).
@@ -201,7 +203,7 @@ def decode_video_frames_pyav(
f"number of queried timestamps ({len(timestamps)})"
)
if return_uint8:
if return_uint8 or is_depth:
return closest_frames
# convert to the pytorch format which is float32 in [0,1] range (and channel first)