mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
feat(depth): wire StreamingVideoEncoder + writer to depth encoder
This commit is contained in:
@@ -31,7 +31,12 @@ import PIL.Image
|
|||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, DepthEncoderConfig, depth_encoder_defaults
|
from lerobot.configs import (
|
||||||
|
DepthEncoderConfig,
|
||||||
|
VideoEncoderConfig,
|
||||||
|
camera_encoder_defaults,
|
||||||
|
depth_encoder_defaults,
|
||||||
|
)
|
||||||
|
|
||||||
from .compute_stats import compute_episode_stats
|
from .compute_stats import compute_episode_stats
|
||||||
from .dataset_metadata import LeRobotDatasetMetadata
|
from .dataset_metadata import LeRobotDatasetMetadata
|
||||||
@@ -509,7 +514,12 @@ class DatasetWriter:
|
|||||||
|
|
||||||
# Update video info (only needed when first episode is encoded)
|
# Update video info (only needed when first episode is encoded)
|
||||||
if episode_index == 0:
|
if episode_index == 0:
|
||||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
self._meta.update_video_info(
|
||||||
|
video_key,
|
||||||
|
video_encoder=self._depth_encoder
|
||||||
|
if video_key in self._meta.depth_keys
|
||||||
|
else self._camera_encoder,
|
||||||
|
)
|
||||||
write_info(self._meta.info, self._meta.root)
|
write_info(self._meta.info, self._meta.root)
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
|
|||||||
@@ -330,10 +330,20 @@ def validate_feature_image_or_video(
|
|||||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||||
error_message = ""
|
error_message = ""
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
actual_shape = value.shape
|
actual_shape = tuple(value.shape)
|
||||||
c, h, w = expected_shape
|
expected = tuple(expected_shape)
|
||||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
if len(expected) == 2:
|
||||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
|
||||||
|
h, w = expected
|
||||||
|
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
|
||||||
|
if not valid:
|
||||||
|
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
|
||||||
|
elif len(expected) == 3:
|
||||||
|
c, h, w = expected
|
||||||
|
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||||
|
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||||
|
else:
|
||||||
|
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
|
||||||
elif isinstance(value, PILImage.Image):
|
elif isinstance(value, PILImage.Image):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import torch.utils
|
|||||||
from huggingface_hub import HfApi, snapshot_download
|
from huggingface_hub import HfApi, snapshot_download
|
||||||
from huggingface_hub.errors import RevisionNotFoundError
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
|
|
||||||
from lerobot.configs import VideoEncoderConfig, DepthEncoderConfig
|
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||||
|
|
||||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||||
|
|||||||
@@ -37,13 +37,15 @@ from datasets.features.features import register_feature
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from lerobot.configs import (
|
from lerobot.configs import (
|
||||||
VideoEncoderConfig,
|
|
||||||
DepthEncoderConfig,
|
DepthEncoderConfig,
|
||||||
|
VideoEncoderConfig,
|
||||||
camera_encoder_defaults,
|
camera_encoder_defaults,
|
||||||
depth_encoder_defaults,
|
depth_encoder_defaults,
|
||||||
)
|
)
|
||||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||||
|
|
||||||
|
from .depth_utils import quantize_depth
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -521,9 +523,7 @@ class _CameraEncoderThread(threading.Thread):
|
|||||||
self,
|
self,
|
||||||
video_path: Path,
|
video_path: Path,
|
||||||
fps: int,
|
fps: int,
|
||||||
vcodec: str,
|
video_encoder: VideoEncoderConfig,
|
||||||
pix_fmt: str,
|
|
||||||
codec_options: dict[str, str],
|
|
||||||
frame_queue: queue.Queue,
|
frame_queue: queue.Queue,
|
||||||
result_queue: queue.Queue,
|
result_queue: queue.Queue,
|
||||||
stop_event: threading.Event,
|
stop_event: threading.Event,
|
||||||
@@ -531,9 +531,8 @@ class _CameraEncoderThread(threading.Thread):
|
|||||||
super().__init__(daemon=True)
|
super().__init__(daemon=True)
|
||||||
self.video_path = video_path
|
self.video_path = video_path
|
||||||
self.fps = fps
|
self.fps = fps
|
||||||
self.vcodec = vcodec
|
self.video_encoder = video_encoder
|
||||||
self.pix_fmt = pix_fmt
|
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||||
self.codec_options = codec_options
|
|
||||||
self.frame_queue = frame_queue
|
self.frame_queue = frame_queue
|
||||||
self.result_queue = result_queue
|
self.result_queue = result_queue
|
||||||
self.stop_event = stop_event
|
self.stop_event = stop_event
|
||||||
@@ -561,12 +560,12 @@ class _CameraEncoderThread(threading.Thread):
|
|||||||
# Sentinel: flush and close
|
# Sentinel: flush and close
|
||||||
break
|
break
|
||||||
|
|
||||||
# Ensure HWC uint8 numpy array
|
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||||
if isinstance(frame_data, np.ndarray):
|
if isinstance(frame_data, np.ndarray):
|
||||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||||
# CHW -> HWC
|
# CHW -> HWC
|
||||||
frame_data = frame_data.transpose(1, 2, 0)
|
frame_data = frame_data.transpose(1, 2, 0)
|
||||||
if frame_data.dtype != np.uint8:
|
if not self.is_depth and frame_data.dtype != np.uint8:
|
||||||
frame_data = (frame_data * 255).astype(np.uint8)
|
frame_data = (frame_data * 255).astype(np.uint8)
|
||||||
|
|
||||||
# Open container on first frame (to get width/height)
|
# Open container on first frame (to get width/height)
|
||||||
@@ -574,15 +573,29 @@ class _CameraEncoderThread(threading.Thread):
|
|||||||
height, width = frame_data.shape[:2]
|
height, width = frame_data.shape[:2]
|
||||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
container = av.open(str(self.video_path), "w")
|
container = av.open(str(self.video_path), "w")
|
||||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
output_stream = container.add_stream(
|
||||||
output_stream.pix_fmt = self.pix_fmt
|
self.video_encoder.vcodec,
|
||||||
|
self.fps,
|
||||||
|
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
|
||||||
|
)
|
||||||
|
output_stream.pix_fmt = self.video_encoder.pix_fmt
|
||||||
output_stream.width = width
|
output_stream.width = width
|
||||||
output_stream.height = height
|
output_stream.height = height
|
||||||
output_stream.time_base = Fraction(1, self.fps)
|
output_stream.time_base = Fraction(1, self.fps)
|
||||||
|
|
||||||
# Encode frame with explicit timestamps
|
# Encode frame with explicit timestamps
|
||||||
pil_img = Image.fromarray(frame_data)
|
if not self.is_depth:
|
||||||
video_frame = av.VideoFrame.from_image(pil_img)
|
pil_img = Image.fromarray(frame_data)
|
||||||
|
video_frame = av.VideoFrame.from_image(pil_img)
|
||||||
|
else:
|
||||||
|
video_frame = quantize_depth(
|
||||||
|
frame_data,
|
||||||
|
depth_min=self.video_encoder.depth_min,
|
||||||
|
depth_max=self.video_encoder.depth_max,
|
||||||
|
shift=self.video_encoder.shift,
|
||||||
|
use_log=self.video_encoder.use_log,
|
||||||
|
video_backend=self.video_encoder.video_backend,
|
||||||
|
)
|
||||||
video_frame.pts = frame_count
|
video_frame.pts = frame_count
|
||||||
video_frame.time_base = Fraction(1, self.fps)
|
video_frame.time_base = Fraction(1, self.fps)
|
||||||
packet = output_stream.encode(video_frame)
|
packet = output_stream.encode(video_frame)
|
||||||
@@ -692,13 +705,10 @@ class StreamingVideoEncoder:
|
|||||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||||
|
|
||||||
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
|
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
|
||||||
codec_options = encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
|
||||||
encoder_thread = _CameraEncoderThread(
|
encoder_thread = _CameraEncoderThread(
|
||||||
video_path=video_path,
|
video_path=video_path,
|
||||||
fps=self.fps,
|
fps=self.fps,
|
||||||
vcodec=encoder.vcodec,
|
video_encoder=encoder,
|
||||||
pix_fmt=encoder.pix_fmt,
|
|
||||||
codec_options=codec_options,
|
|
||||||
frame_queue=frame_queue,
|
frame_queue=frame_queue,
|
||||||
result_queue=result_queue,
|
result_queue=result_queue,
|
||||||
stop_event=stop_event,
|
stop_event=stop_event,
|
||||||
@@ -908,13 +918,13 @@ def get_audio_info(video_path: Path | str) -> dict:
|
|||||||
|
|
||||||
def get_video_info(
|
def get_video_info(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
camera_encoder: VideoEncoderConfig | None = None,
|
video_encoder: VideoEncoderConfig | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path: Path to the encoded video file to probe.
|
video_path: Path to the encoded video file to probe.
|
||||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
video_encoder: If provided, record the exact encoder settings used to encode this
|
||||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||||
not already populated from the video file itself.
|
not already populated from the video file itself.
|
||||||
"""
|
"""
|
||||||
@@ -949,12 +959,13 @@ def get_video_info(
|
|||||||
video_info.update(**get_audio_info(video_path))
|
video_info.update(**get_audio_info(video_path))
|
||||||
|
|
||||||
# Add additional encoder configuration if provided
|
# Add additional encoder configuration if provided
|
||||||
if camera_encoder is not None:
|
if video_encoder is not None:
|
||||||
for field_name, field_value in asdict(camera_encoder).items():
|
for field_name, field_value in asdict(video_encoder).items():
|
||||||
# vcodec is already populated from the video stream
|
# vcodec is already populated from the video stream
|
||||||
if field_name == "vcodec":
|
if field_name == "vcodec":
|
||||||
continue
|
continue
|
||||||
video_info.setdefault(f"video.{field_name}", field_value)
|
video_info.setdefault(f"video.{field_name}", field_value)
|
||||||
|
video_info["video.is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||||
|
|
||||||
return video_info
|
return video_info
|
||||||
|
|
||||||
|
|||||||
@@ -61,9 +61,7 @@ class TestCameraEncoderThread:
|
|||||||
encoder_thread = _CameraEncoderThread(
|
encoder_thread = _CameraEncoderThread(
|
||||||
video_path=video_path,
|
video_path=video_path,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
vcodec=enc_cfg.vcodec,
|
video_encoder=enc_cfg,
|
||||||
pix_fmt=enc_cfg.pix_fmt,
|
|
||||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
|
||||||
frame_queue=frame_queue,
|
frame_queue=frame_queue,
|
||||||
result_queue=result_queue,
|
result_queue=result_queue,
|
||||||
stop_event=stop_event,
|
stop_event=stop_event,
|
||||||
@@ -112,9 +110,7 @@ class TestCameraEncoderThread:
|
|||||||
encoder_thread = _CameraEncoderThread(
|
encoder_thread = _CameraEncoderThread(
|
||||||
video_path=video_path,
|
video_path=video_path,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
vcodec=enc_cfg.vcodec,
|
video_encoder=enc_cfg,
|
||||||
pix_fmt=enc_cfg.pix_fmt,
|
|
||||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
|
||||||
frame_queue=frame_queue,
|
frame_queue=frame_queue,
|
||||||
result_queue=result_queue,
|
result_queue=result_queue,
|
||||||
stop_event=stop_event,
|
stop_event=stop_event,
|
||||||
@@ -146,9 +142,7 @@ class TestCameraEncoderThread:
|
|||||||
encoder_thread = _CameraEncoderThread(
|
encoder_thread = _CameraEncoderThread(
|
||||||
video_path=video_path,
|
video_path=video_path,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
vcodec=enc_cfg.vcodec,
|
video_encoder=enc_cfg,
|
||||||
pix_fmt=enc_cfg.pix_fmt,
|
|
||||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
|
||||||
frame_queue=frame_queue,
|
frame_queue=frame_queue,
|
||||||
result_queue=result_queue,
|
result_queue=result_queue,
|
||||||
stop_event=stop_event,
|
stop_event=stop_event,
|
||||||
|
|||||||
Reference in New Issue
Block a user