diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index d3c811864..df8a9daa9 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -31,7 +31,12 @@ import PIL.Image import pyarrow.parquet as pq import torch -from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, DepthEncoderConfig, depth_encoder_defaults +from lerobot.configs import ( + DepthEncoderConfig, + VideoEncoderConfig, + camera_encoder_defaults, + depth_encoder_defaults, +) from .compute_stats import compute_episode_stats from .dataset_metadata import LeRobotDatasetMetadata @@ -509,7 +514,12 @@ class DatasetWriter: # Update video info (only needed when first episode is encoded) if episode_index == 0: - self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder) + self._meta.update_video_info( + video_key, + video_encoder=self._depth_encoder + if video_key in self._meta.depth_keys + else self._camera_encoder, + ) write_info(self._meta.info, self._meta.root) metadata = { diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index d5a550a4c..06467ac3a 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -330,10 +330,20 @@ def validate_feature_image_or_video( # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): - actual_shape = value.shape - c, h, w = expected_shape - if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + actual_shape = tuple(value.shape) + expected = tuple(expected_shape) + if len(expected) == 2: + # Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1) + h, w = expected + valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)} + if not valid: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n" + elif len(expected) == 3: + c, h, w = expected + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + else: + error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n" elif isinstance(value, PILImage.Image): pass else: diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 615be5df7..8ef6f22c6 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -24,7 +24,7 @@ import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.configs import VideoEncoderConfig, DepthEncoderConfig +from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 9fd29e138..dd7cf7ee7 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -37,13 +37,15 @@ from datasets.features.features import register_feature from PIL import Image from lerobot.configs import ( - VideoEncoderConfig, DepthEncoderConfig, + VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults, ) from lerobot.utils.import_utils import get_safe_default_video_backend +from .depth_utils import quantize_depth + logger = logging.getLogger(__name__) @@ -521,9 +523,7 @@ class _CameraEncoderThread(threading.Thread): self, video_path: Path, fps: int, - vcodec: str, - pix_fmt: str, - codec_options: dict[str, str], + video_encoder: VideoEncoderConfig, frame_queue: queue.Queue, result_queue: queue.Queue, stop_event: threading.Event, @@ -531,9 +531,8 @@ class _CameraEncoderThread(threading.Thread): super().__init__(daemon=True) self.video_path = video_path self.fps = fps - self.vcodec = vcodec - self.pix_fmt = pix_fmt - self.codec_options = codec_options + self.video_encoder = video_encoder + self.is_depth = isinstance(video_encoder, DepthEncoderConfig) self.frame_queue = frame_queue self.result_queue = result_queue self.stop_event = stop_event @@ -561,12 +560,12 @@ class _CameraEncoderThread(threading.Thread): # Sentinel: flush and close break - # Ensure HWC uint8 numpy array + # Ensure HWC (RGB or depth) uint8 (RGB only) numpy array if isinstance(frame_data, np.ndarray): if frame_data.ndim == 3 and frame_data.shape[0] == 3: # CHW -> HWC frame_data = frame_data.transpose(1, 2, 0) - if frame_data.dtype != np.uint8: + if not self.is_depth and frame_data.dtype != np.uint8: frame_data = (frame_data * 255).astype(np.uint8) # Open container on first frame (to get width/height) @@ -574,15 +573,29 @@ class _CameraEncoderThread(threading.Thread): height, width = frame_data.shape[:2] Path(self.video_path).parent.mkdir(parents=True, exist_ok=True) container = av.open(str(self.video_path), "w") - output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options) - output_stream.pix_fmt = self.pix_fmt + output_stream = container.add_stream( + self.video_encoder.vcodec, + self.fps, + options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True), + ) + output_stream.pix_fmt = self.video_encoder.pix_fmt output_stream.width = width output_stream.height = height output_stream.time_base = Fraction(1, self.fps) # Encode frame with explicit timestamps - pil_img = Image.fromarray(frame_data) - video_frame = av.VideoFrame.from_image(pil_img) + if not self.is_depth: + pil_img = Image.fromarray(frame_data) + video_frame = av.VideoFrame.from_image(pil_img) + else: + video_frame = quantize_depth( + frame_data, + depth_min=self.video_encoder.depth_min, + depth_max=self.video_encoder.depth_max, + shift=self.video_encoder.shift, + use_log=self.video_encoder.use_log, + video_backend=self.video_encoder.video_backend, + ) video_frame.pts = frame_count video_frame.time_base = Fraction(1, self.fps) packet = output_stream.encode(video_frame) @@ -692,13 +705,10 @@ class StreamingVideoEncoder: video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4" encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder - codec_options = encoder.get_codec_options(self._encoder_threads, as_strings=True) encoder_thread = _CameraEncoderThread( video_path=video_path, fps=self.fps, - vcodec=encoder.vcodec, - pix_fmt=encoder.pix_fmt, - codec_options=codec_options, + video_encoder=encoder, frame_queue=frame_queue, result_queue=result_queue, stop_event=stop_event, @@ -908,13 +918,13 @@ def get_audio_info(video_path: Path | str) -> dict: def get_video_info( video_path: Path | str, - camera_encoder: VideoEncoderConfig | None = None, + video_encoder: VideoEncoderConfig | None = None, ) -> dict: """Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``. Args: video_path: Path to the encoded video file to probe. - camera_encoder: If provided, record the exact encoder settings used to encode this + video_encoder: If provided, record the exact encoder settings used to encode this video. Stream-derived values take precedence — encoder fields are only written for keys not already populated from the video file itself. """ @@ -949,12 +959,13 @@ def get_video_info( video_info.update(**get_audio_info(video_path)) # Add additional encoder configuration if provided - if camera_encoder is not None: - for field_name, field_value in asdict(camera_encoder).items(): + if video_encoder is not None: + for field_name, field_value in asdict(video_encoder).items(): # vcodec is already populated from the video stream if field_name == "vcodec": continue video_info.setdefault(f"video.{field_name}", field_value) + video_info["video.is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig) return video_info diff --git a/tests/datasets/test_streaming_video_encoder.py b/tests/datasets/test_streaming_video_encoder.py index b69f24254..6af2bc797 100644 --- a/tests/datasets/test_streaming_video_encoder.py +++ b/tests/datasets/test_streaming_video_encoder.py @@ -61,9 +61,7 @@ class TestCameraEncoderThread: encoder_thread = _CameraEncoderThread( video_path=video_path, fps=fps, - vcodec=enc_cfg.vcodec, - pix_fmt=enc_cfg.pix_fmt, - codec_options=enc_cfg.get_codec_options(as_strings=True), + video_encoder=enc_cfg, frame_queue=frame_queue, result_queue=result_queue, stop_event=stop_event, @@ -112,9 +110,7 @@ class TestCameraEncoderThread: encoder_thread = _CameraEncoderThread( video_path=video_path, fps=fps, - vcodec=enc_cfg.vcodec, - pix_fmt=enc_cfg.pix_fmt, - codec_options=enc_cfg.get_codec_options(as_strings=True), + video_encoder=enc_cfg, frame_queue=frame_queue, result_queue=result_queue, stop_event=stop_event, @@ -146,9 +142,7 @@ class TestCameraEncoderThread: encoder_thread = _CameraEncoderThread( video_path=video_path, fps=fps, - vcodec=enc_cfg.vcodec, - pix_fmt=enc_cfg.pix_fmt, - codec_options=enc_cfg.get_codec_options(as_strings=True), + video_encoder=enc_cfg, frame_queue=frame_queue, result_queue=result_queue, stop_event=stop_event,