feat(depth): wire StreamingVideoEncoder + writer to depth encoder

This commit is contained in:
CarolinePascal
2026-05-19 23:23:27 +02:00
parent b4c31f0f67
commit d39698da0f
5 changed files with 62 additions and 37 deletions
+12 -2
View File
@@ -31,7 +31,12 @@ import PIL.Image
import pyarrow.parquet as pq
import torch
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, DepthEncoderConfig, depth_encoder_defaults
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from .compute_stats import compute_episode_stats
from .dataset_metadata import LeRobotDatasetMetadata
@@ -509,7 +514,12 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
self._meta.update_video_info(
video_key,
video_encoder=self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
)
write_info(self._meta.info, self._meta.root)
metadata = {
+14 -4
View File
@@ -330,10 +330,20 @@ def validate_feature_image_or_video(
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
actual_shape = tuple(value.shape)
expected = tuple(expected_shape)
if len(expected) == 2:
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
h, w = expected
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
if not valid:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
elif len(expected) == 3:
c, h, w = expected
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
else:
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
+1 -1
View File
@@ -24,7 +24,7 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.configs import VideoEncoderConfig, DepthEncoderConfig
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
+32 -21
View File
@@ -37,13 +37,15 @@ from datasets.features.features import register_feature
from PIL import Image
from lerobot.configs import (
VideoEncoderConfig,
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import quantize_depth
logger = logging.getLogger(__name__)
@@ -521,9 +523,7 @@ class _CameraEncoderThread(threading.Thread):
self,
video_path: Path,
fps: int,
vcodec: str,
pix_fmt: str,
codec_options: dict[str, str],
video_encoder: VideoEncoderConfig,
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
@@ -531,9 +531,8 @@ class _CameraEncoderThread(threading.Thread):
super().__init__(daemon=True)
self.video_path = video_path
self.fps = fps
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.codec_options = codec_options
self.video_encoder = video_encoder
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
@@ -561,12 +560,12 @@ class _CameraEncoderThread(threading.Thread):
# Sentinel: flush and close
break
# Ensure HWC uint8 numpy array
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if frame_data.dtype != np.uint8:
if not self.is_depth and frame_data.dtype != np.uint8:
frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height)
@@ -574,15 +573,29 @@ class _CameraEncoderThread(threading.Thread):
height, width = frame_data.shape[:2]
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
output_stream.pix_fmt = self.pix_fmt
output_stream = container.add_stream(
self.video_encoder.vcodec,
self.fps,
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
)
output_stream.pix_fmt = self.video_encoder.pix_fmt
output_stream.width = width
output_stream.height = height
output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
if not self.is_depth:
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
else:
video_frame = quantize_depth(
frame_data,
depth_min=self.video_encoder.depth_min,
depth_max=self.video_encoder.depth_max,
shift=self.video_encoder.shift,
use_log=self.video_encoder.use_log,
video_backend=self.video_encoder.video_backend,
)
video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame)
@@ -692,13 +705,10 @@ class StreamingVideoEncoder:
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
codec_options = encoder.get_codec_options(self._encoder_threads, as_strings=True)
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
vcodec=encoder.vcodec,
pix_fmt=encoder.pix_fmt,
codec_options=codec_options,
video_encoder=encoder,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
@@ -908,13 +918,13 @@ def get_audio_info(video_path: Path | str) -> dict:
def get_video_info(
video_path: Path | str,
camera_encoder: VideoEncoderConfig | None = None,
video_encoder: VideoEncoderConfig | None = None,
) -> dict:
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
Args:
video_path: Path to the encoded video file to probe.
camera_encoder: If provided, record the exact encoder settings used to encode this
video_encoder: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence encoder fields are only written for keys
not already populated from the video file itself.
"""
@@ -949,12 +959,13 @@ def get_video_info(
video_info.update(**get_audio_info(video_path))
# Add additional encoder configuration if provided
if camera_encoder is not None:
for field_name, field_value in asdict(camera_encoder).items():
if video_encoder is not None:
for field_name, field_value in asdict(video_encoder).items():
# vcodec is already populated from the video stream
if field_name == "vcodec":
continue
video_info.setdefault(f"video.{field_name}", field_value)
video_info["video.is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
return video_info
@@ -61,9 +61,7 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=fps,
vcodec=enc_cfg.vcodec,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
video_encoder=enc_cfg,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
@@ -112,9 +110,7 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=fps,
vcodec=enc_cfg.vcodec,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
video_encoder=enc_cfg,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
@@ -146,9 +142,7 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=fps,
vcodec=enc_cfg.vcodec,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
video_encoder=enc_cfg,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,