feat(tools): adding depth support in LeRobotDataset edition tools

This commit is contained in:
CarolinePascal
2026-05-26 16:59:43 +02:00
parent a1ec48d3a9
commit 05d2a6062d
4 changed files with 81 additions and 27 deletions
+3
View File
@@ -39,6 +39,7 @@ from .video import (
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
encoder_config_from_video_info,
)
__all__ = [
@@ -63,6 +64,8 @@ __all__ = [
# Defaults
"camera_encoder_defaults",
"depth_encoder_defaults",
# Factories
"encoder_config_from_video_info",
# Constants
"VALID_VIDEO_CODECS",
"VIDEO_ENCODER_INFO_KEYS",
+21
View File
@@ -313,3 +313,24 @@ class DepthEncoderConfig(VideoEncoderConfig):
def depth_encoder_defaults() -> DepthEncoderConfig:
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
return DepthEncoderConfig()
def encoder_config_from_video_info(video_info: dict | None) -> VideoEncoderConfig:
"""Build the appropriate encoder config from a feature's ``info`` block.
Dispatches to :class:`DepthEncoderConfig` when the dict marks the feature
as a depth map and to :class:`VideoEncoderConfig`
otherwise.
Args:
video_info: A feature's ``info`` dict as persisted in ``info.json``,
or ``None`` (treated as an empty dict).
Returns:
A :class:`DepthEncoderConfig` for depth features, otherwise a
:class:`VideoEncoderConfig`.
"""
video_info = video_info or {}
is_depth = bool(video_info.get("is_depth_map") or video_info.get("video.is_depth_map"))
cls: type[VideoEncoderConfig] = DepthEncoderConfig if is_depth else VideoEncoderConfig
return cls.from_video_info(video_info)
+6 -1
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Callable
from collections.abc import Callable, Iterable
from pathlib import Path
import numpy as np
@@ -600,6 +600,7 @@ class LeRobotDatasetMetadata:
self,
video_key: str | None = None,
video_encoder: VideoEncoderConfig | None = None,
preserve_keys: Iterable[str] | None = None,
) -> None:
"""Populate per-feature video info in ``info.json``.
@@ -613,11 +614,14 @@ class LeRobotDatasetMetadata:
videos. When provided, its fields are recorded as
``video.<field>`` entries alongside the stream-derived
``video.*`` entries (see :func:`get_video_info`).
preserve_keys: Optional iterable of ``info`` keys whose existing
values must be kept as-is.
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys
preserve_set = set(preserve_keys or ())
for key in video_keys:
existing = self.features[key].get("info") or {}
# Skip only if real video info has already been written. The ``is_depth_map`` entry (created at feature creation) is not blocking.
@@ -625,6 +629,7 @@ class LeRobotDatasetMetadata:
continue
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
new_info = get_video_info(video_path, video_encoder=video_encoder)
new_info = {k: v for k, v in new_info.items() if k not in preserve_set}
self.info.features[key]["info"] = {**existing, **new_info}
def update_chunk_settings(
+51 -26
View File
@@ -36,7 +36,8 @@ import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, DepthEncoderConfig, encoder_config_from_video_info, depth_encoder_defaults
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
from lerobot.utils.utils import flatten_dict
@@ -732,7 +733,7 @@ def _copy_and_reindex_videos(
for video_key in src_dataset.meta.video_keys:
logging.info(f"Processing videos for {video_key}")
camera_encoder = VideoEncoderConfig.from_video_info(
video_encoder = encoder_config_from_video_info(
src_dataset.meta.info.features.get(video_key, {}).get("info")
)
@@ -816,7 +817,7 @@ def _copy_and_reindex_videos(
dst_video_path,
episodes_to_keep_ranges,
src_dataset.meta.fps,
camera_encoder,
video_encoder,
)
cumulative_ts = 0.0
@@ -1196,7 +1197,10 @@ def _save_batch_episodes_images(
i, item = i_item_tuple
img = item[img_key_param]
# Use global frame index for naming
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
if img_key_param in dataset.meta.depth_keys:
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.tiff"), compression="raw")
else:
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
return i
episode_durations = []
@@ -1287,7 +1291,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: list[int],
temp_dir: Path,
fps: int,
camera_encoder: VideoEncoderConfig,
video_encoder: VideoEncoderConfig,
num_calibration_frames: int = 30,
) -> float:
"""Estimate MB per frame by encoding a small calibration sample.
@@ -1301,7 +1305,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: List of episode indices being processed.
temp_dir: Temporary directory for calibration files.
fps: Frames per second for video encoding.
camera_encoder: Video encoder settings used for calibration encoding.
video_encoder: Video encoder settings used for calibration encoding.
num_calibration_frames: Number of frames to use for calibration (default: 30).
Returns:
@@ -1337,7 +1341,7 @@ def _estimate_frame_size_via_calibration(
imgs_dir=calibration_dir,
video_path=calibration_video_path,
fps=fps,
video_encoder=camera_encoder,
video_encoder=video_encoder,
overwrite=True,
)
@@ -1610,6 +1614,7 @@ def recompute_stats(
raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = []
#TODO: enable image and video stats re-computation
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
@@ -1656,6 +1661,7 @@ def convert_image_to_video_dataset(
output_dir: Path | None = None,
repo_id: str | None = None,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: VideoEncoderConfig | None = None,
episode_indices: list[int] | None = None,
num_workers: int = 4,
max_episodes_per_batch: int | None = None,
@@ -1682,6 +1688,8 @@ def convert_image_to_video_dataset(
"""
if camera_encoder is None:
camera_encoder = camera_encoder_defaults()
if depth_encoder is None:
depth_encoder = depth_encoder_defaults()
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
@@ -1707,8 +1715,7 @@ def convert_image_to_video_dataset(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}"
)
# Create new features dict, converting image features to video features
@@ -1771,6 +1778,8 @@ def convert_image_to_video_dataset(
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
for img_key in tqdm(img_keys, desc="Processing cameras"):
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
# Estimate size per frame by encoding a small calibration sample
# This provides accurate compression ratio for the specific codec parameters
size_per_frame_mb = _estimate_frame_size_via_calibration(
@@ -1779,7 +1788,7 @@ def convert_image_to_video_dataset(
episode_indices=episode_indices,
temp_dir=temp_dir,
fps=fps,
camera_encoder=camera_encoder,
video_encoder=target_encoder,
)
logging.info(f"Processing camera: {img_key}")
@@ -1821,7 +1830,7 @@ def convert_image_to_video_dataset(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
video_encoder=camera_encoder,
video_encoder=target_encoder,
overwrite=True,
)
@@ -1909,7 +1918,8 @@ def _reencode_video_worker(args: tuple) -> Path:
def reencode_dataset(
dataset: LeRobotDataset,
camera_encoder: VideoEncoderConfig,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
num_workers: int | None = None,
) -> LeRobotDataset:
@@ -1920,8 +1930,11 @@ def reencode_dataset(
Args:
dataset: An existing :class:`LeRobotDataset` whose videos will be
re-encoded.
camera_encoder: Target encoder configuration applied to every video
file.
camera_encoder: Target encoder configuration applied to every RGB video
file. If ``None``, re-encoding is skipped for RGB videos.
depth_encoder: Target encoder configuration applied to every depth video
file. If ``None``, re-encoding is skipped for depth videos.
Quantization parameters will not override the ones in the current dataset.
encoder_threads: Per-encoder thread count forwarded to
:func:`reencode_video`. ``None`` lets the codec decide.
num_workers: Number of parallel processes. ``None`` or ``0`` means
@@ -1933,23 +1946,31 @@ def reencode_dataset(
on disk.
"""
meta = dataset.meta
video_paths_list = []
video_keys_encoders_dict = {}
video_keys_paths_dict = {}
if camera_encoder is None and depth_encoder is None:
raise ValueError("Either camera_encoder or depth_encoder must be provided")
# Only re-encode if the videos are not already encoded with the given video encoding parameters
for video_key in meta.video_keys:
current_info = meta.info.features[video_key].get("info", {})
current_encoder = VideoEncoderConfig.from_video_info(current_info)
if current_encoder != camera_encoder:
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
current_encoder = encoder_config_from_video_info(current_info)
target_encoder = depth_encoder if video_key in meta.depth_keys else camera_encoder
if target_encoder is None:
logging.info(f"No encoder provided for {video_key} video. Skipping re-encoding.")
elif current_encoder != target_encoder:
video_keys_paths_dict[video_key] = (meta.root / VIDEO_DIR / video_key).rglob("*.mp4")
video_keys_encoders_dict[video_key] = target_encoder
else:
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
logging.info(f"{video_key} videos are already encoded with {target_encoder}. Nothing to do.")
if len(video_paths_list) == 0:
if len(video_keys_paths_dict) == 0:
logging.warning("Dataset has no videos to re-encode.")
return dataset
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
worker_args = [(path, encoder, encoder_threads) for video_key, encoder in video_keys_encoders_dict.items() for path in video_keys_paths_dict[video_key]]
if num_workers and num_workers > 1:
with ProcessPoolExecutor(max_workers=num_workers) as pool:
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
@@ -1963,10 +1984,14 @@ def reencode_dataset(
for args in tqdm(worker_args, desc="Re-encoding videos"):
_reencode_video_worker(args)
# Refresh video info in metadata for every video key.
for vid_key in meta.video_keys:
video_path = meta.root / meta.get_video_file_path(0, vid_key)
meta.info.features[vid_key]["info"] = get_video_info(video_path, video_encoder=camera_encoder)
# Refresh video info in metadata for every video key. For depth videos, preserve
# ``is_depth_map`` and the depth quantization parameters.
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
for video_key, encoder in video_keys_encoders_dict.items():
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else None
meta.update_video_info(
video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys
)
write_info(meta.info, meta.root)
logging.info("Dataset metadata updated.")