mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
feat(depth): wire StreamingVideoEncoder + writer to depth encoder
This commit is contained in:
@@ -31,7 +31,12 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, DepthEncoderConfig, depth_encoder_defaults
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
@@ -509,7 +514,12 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
self._meta.update_video_info(
|
||||
video_key,
|
||||
video_encoder=self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
|
||||
@@ -330,10 +330,20 @@ def validate_feature_image_or_video(
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
actual_shape = tuple(value.shape)
|
||||
expected = tuple(expected_shape)
|
||||
if len(expected) == 2:
|
||||
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
|
||||
h, w = expected
|
||||
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
|
||||
if not valid:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
|
||||
elif len(expected) == 3:
|
||||
c, h, w = expected
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, DepthEncoderConfig
|
||||
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
|
||||
@@ -37,13 +37,15 @@ from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.configs import (
|
||||
VideoEncoderConfig,
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import quantize_depth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -521,9 +523,7 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, str],
|
||||
video_encoder: VideoEncoderConfig,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
@@ -531,9 +531,8 @@ class _CameraEncoderThread(threading.Thread):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec_options = codec_options
|
||||
self.video_encoder = video_encoder
|
||||
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
@@ -561,12 +560,12 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
if not self.is_depth and frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
@@ -574,15 +573,29 @@ class _CameraEncoderThread(threading.Thread):
|
||||
height, width = frame_data.shape[:2]
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream = container.add_stream(
|
||||
self.video_encoder.vcodec,
|
||||
self.fps,
|
||||
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
|
||||
)
|
||||
output_stream.pix_fmt = self.video_encoder.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
if not self.is_depth:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
else:
|
||||
video_frame = quantize_depth(
|
||||
frame_data,
|
||||
depth_min=self.video_encoder.depth_min,
|
||||
depth_max=self.video_encoder.depth_max,
|
||||
shift=self.video_encoder.shift,
|
||||
use_log=self.video_encoder.use_log,
|
||||
video_backend=self.video_encoder.video_backend,
|
||||
)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
@@ -692,13 +705,10 @@ class StreamingVideoEncoder:
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
|
||||
codec_options = encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=encoder.vcodec,
|
||||
pix_fmt=encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
video_encoder=encoder,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -908,13 +918,13 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
@@ -949,12 +959,13 @@ def get_video_info(
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
if video_encoder is not None:
|
||||
for field_name, field_value in asdict(video_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
video_info["video.is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
@@ -61,9 +61,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -112,9 +110,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -146,9 +142,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
|
||||
Reference in New Issue
Block a user