mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-20 17:57:14 +00:00
Compare commits
62 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 45f49b6600 | |||
| c56c6991d1 | |||
| e8f6c32623 | |||
| cc17582f71 | |||
| 529f48b540 | |||
| c128027415 | |||
| c0db93f4a0 | |||
| 1f66e6f5e4 | |||
| 7a67ce9e50 | |||
| 4f26878d8f | |||
| a694e32774 | |||
| 655338abf3 | |||
| 364d4de96f | |||
| 41a942658b | |||
| 030d9a279a | |||
| 30479cf277 | |||
| cb6b2d77bd | |||
| 76f79f3955 | |||
| 9e994baa04 | |||
| 6fd911ebb9 | |||
| f712698272 | |||
| c2416ecbcb | |||
| 6aa50cc1e5 | |||
| e17adce3ba | |||
| f7010ff66c | |||
| f7ee453de7 | |||
| ca7168f413 | |||
| ec6264d768 | |||
| d93a58a8b8 | |||
| 92497dfcd8 | |||
| 263108d6c1 | |||
| a925d20ce4 | |||
| 1f024ea3bf | |||
| d5f67cc7fc | |||
| 9ab8c98494 | |||
| a561183442 | |||
| 305b8d64b2 | |||
| 0a624a5cf5 | |||
| d044ead377 | |||
| e425fcb61a | |||
| f08a9aea71 | |||
| 7d97b55cc4 | |||
| edbd8c6f82 | |||
| 615954b80b | |||
| 1c0fdfdb4b | |||
| 1c3ebd475f | |||
| c655814788 | |||
| a72ab14f89 | |||
| 882074d707 | |||
| 4ae2f9f375 | |||
| 26099b6e03 | |||
| 6b395dfb24 | |||
| 1cbabfe9a4 | |||
| 4744f4b913 | |||
| 9568e68b28 | |||
| 10941c31f6 | |||
| a6882a048a | |||
| eb2b7d6dc3 | |||
| f7f7b8c7f8 | |||
| d58a324da4 | |||
| 287c823f13 | |||
| 58ccc01508 |
@@ -157,6 +157,14 @@ finally:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Working with depth
|
||||
|
||||
The Intel RealSense and Reachy 2 cameras can capture both color and depth in lockstep. Calling `read()` returns the **color** frame as `(H, W, 3)` `uint8`. Calling `read_depth()` returns the **depth map** as `(H, W, 1)` `uint16`, where each pixel value is the distance from the sensor expressed in **millimetres**. A pixel value of `0` typically means "no measurement available" (out-of-range, occluded, or low-confidence).
|
||||
|
||||
During recording, the control loop peeks the freshest buffered frames non-blockingly via `read_latest()` (color) and `read_latest_depth()` (depth), adding the depth map as a sibling feature (e.g. `front_depth` next to `front`).
|
||||
|
||||
For how depth streams are stored and encoded when recording a dataset, see the [Depth streams](./video_encoding_parameters#depth-streams) section of the video encoding guide.
|
||||
|
||||
## Use your phone's camera
|
||||
|
||||
<hfoptions id="use phone">
|
||||
|
||||
@@ -11,8 +11,9 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage (RGB and depth cameras are encoded with separate encoders)
|
||||
7. **Re-encode Videos** - Re-encode an existing video dataset's RGB and/or depth streams with new encoder settings
|
||||
8. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -122,6 +123,15 @@ lerobot-edit-dataset \
|
||||
--operation.camera_encoder.g 2 \
|
||||
--operation.camera_encoder.crf 30
|
||||
|
||||
# Convert a dataset that includes depth maps, customizing the depth encoder
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.depth_encoder.depth_min 0.01 \
|
||||
--operation.depth_encoder.depth_max 10.0 \
|
||||
--operation.depth_encoder.use_log true
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -147,11 +157,42 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `camera_encoder`: Video encoder settings applied to RGB cameras — all sub-fields accessible via `--operation.camera_encoder.<field>`. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `depth_encoder`: Video encoder settings applied to depth-map cameras (e.g. from an Intel RealSense). In addition to the standard encoder fields it exposes the depth quantization knobs (`depth_min`, `depth_max`, `shift`, `use_log`), accessible via `--operation.depth_encoder.<field>`. These quantization settings are persisted to the dataset metadata so depth can be dequantized back to physical units on load. See the [Depth streams](./video_encoding_parameters#depth-streams) section for details.
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). Depth-map cameras are detected automatically and routed to the `depth_encoder`, while RGB cameras use the `camera_encoder`. All episodes, stats, and tasks are preserved.
|
||||
|
||||
#### Re-encode Videos
|
||||
|
||||
Re-encode the videos of an existing video dataset with different encoder settings, without going back to raw frames. RGB videos use the `camera_encoder` and depth videos use the `depth_encoder`. Provide only the encoder(s) you want to re-encode; the other stream type is left untouched.
|
||||
|
||||
```bash
|
||||
# Re-encode all RGB videos with new settings (saves to lerobot/pusht_reencoded by default)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type reencode_videos \
|
||||
--operation.camera_encoder.vcodec h264 \
|
||||
--operation.camera_encoder.pix_fmt yuv420p \
|
||||
--operation.camera_encoder.crf 23
|
||||
|
||||
# Re-encode both RGB and depth videos in a dataset with depth maps
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_depth \
|
||||
--operation.type reencode_videos \
|
||||
--operation.camera_encoder.vcodec libx264 \
|
||||
--operation.depth_encoder.vcodec ffv1
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `camera_encoder`: Encoder settings applied to every RGB video. Omit to skip re-encoding RGB videos.
|
||||
- `depth_encoder`: Encoder settings applied to every depth video. Omit to skip re-encoding depth videos.
|
||||
- `num_workers`: Number of parallel workers for processing.
|
||||
|
||||
> [!NOTE]
|
||||
> When re-encoding depth videos, the existing depth quantization parameters (`depth_min`, `depth_max`, `shift`, `use_log`) and the `is_depth_map` flag are **preserved** — re-encoding only changes the codec/quality of the stored stream, not how depth is dequantized on load.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
|
||||
@@ -65,6 +65,76 @@ All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
|
||||
|
||||
---
|
||||
|
||||
## Depth streams
|
||||
|
||||
Depth maps (Intel RealSense, Reachy 2) are stored as their **own video streams** alongside the RGB streams. Raw depth (`uint16` millimetres or `float32` metres) can't survive an 8-bit codec, so LeRobot **quantizes** each map to a 12-bit code (`[0, 4095]`) — logarithmically by default, to match the `1/depth` error profile of depth sensors — then packs it into a high-bit-depth pixel format (`gray12le`) and encodes it with a 12-bit codec.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
A["Raw depth (uint16 mm / float32 m)"] --> B["Clip to depth_min, depth_max"]
|
||||
B --> C["Quantize to 12-bit code 0–4095 (log or linear)"]
|
||||
C --> D["Pack into gray12le"]
|
||||
D --> E["Encode video (hevc Main 12)"]
|
||||
E --> F[("MP4 + metadata: depth_min/max, shift, use_log")]
|
||||
F -. "load time (depth_output_unit)" .-> G["Dequantize to mm or m"]
|
||||
|
||||
classDef input fill:#e3f2fd,stroke:#1565c0,color:#0d47a1;
|
||||
classDef encode fill:#ede7f6,stroke:#5e35b1,color:#311b92;
|
||||
classDef store fill:#fff8e1,stroke:#f9a825,color:#e65100;
|
||||
classDef load fill:#e8f5e9,stroke:#2e7d32,color:#1b5e20;
|
||||
|
||||
class A input;
|
||||
class B,C,D,E encode;
|
||||
class F store;
|
||||
class G load;
|
||||
```
|
||||
|
||||
Configure the depth pipeline through a parallel **`depth_encoder`** block (`DepthEncoderConfig`). It inherits every `VideoEncoderConfig` field (`vcodec`, `pix_fmt`, `crf`, …) and adds four quantizer knobs, set via `--dataset.depth_encoder.<field>`:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
... \
|
||||
--dataset.depth_encoder.vcodec=hevc \
|
||||
--dataset.depth_encoder.depth_min=0.05 \
|
||||
--dataset.depth_encoder.depth_max=5.0 \
|
||||
--dataset.depth_encoder.use_log=true
|
||||
```
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| ----------- | ------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vcodec` | `str` | `"hevc"` | Defaults to HEVC Main 12 (a 12-bit-capable codec). `ffv1` is a lossless alternative. |
|
||||
| `pix_fmt` | `str` | `"gray12le"` | Single-channel 12-bit pixel format used to carry the quantized codes. |
|
||||
| `depth_min` | `float` | `0.01` | Depth in metres mapped to quantum `0`. Values below are clipped on decode. |
|
||||
| `depth_max` | `float` | `10.0` | Depth in metres mapped to quantum `4095`. Values above are clipped on decode. |
|
||||
| `shift` | `float` | `3.5` | Pre-log offset (metres) used in logarithmic quantization for numerical stability near zero. Must satisfy `depth_min + shift > 0`. |
|
||||
| `use_log` | `bool` | `True` | If `true`, quantize in log-space (recommended for typical depth sensors). Set to `false` for uniform/linear quantization. |
|
||||
|
||||
> [!TIP]
|
||||
> `depth_min`, `depth_max`, and `shift` are always interpreted in **metres**, regardless of the input depth's unit. Inputs are auto-detected: integer arrays (e.g. `uint16` millimetres straight from a RealSense) are treated as millimetres, floating arrays as metres.
|
||||
> Pick `depth_min` / `depth_max` to bracket the actual working range of your sensor — quanta outside that range saturate, which can crush detail at the boundaries.
|
||||
|
||||
Depth features are flagged with `"is_depth_map": true` in `meta/info.json`, and their quantizer settings (`video.depth_min`, `video.depth_max`, `video.shift`, `video.use_log`) are persisted — which is what lets depth be **dequantized back to physical units** on load.
|
||||
|
||||
### Output unit at load time
|
||||
|
||||
`depth_encoder` is a **record-time** concern. The unit that depth maps are dequantized to on _load_ (e.g. during training) is set separately by the read-time flag `--dataset.depth_output_unit`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.depth_output_unit=m \
|
||||
--policy.type=act
|
||||
```
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| ------------------- | ----- | ------- | -------------------------------------------------------------------------------------------- |
|
||||
| `depth_output_unit` | `str` | `"mm"` | Physical unit depth maps are dequantized to on load: `"mm"` (millimetres) or `"m"` (metres). |
|
||||
|
||||
> [!TIP]
|
||||
> This is purely a decode-time presentation choice — it does **not** alter the stored video or its metadata, so the same dataset can be read as `mm` or `m` without re-encoding. It has no effect on datasets without depth cameras.
|
||||
|
||||
---
|
||||
|
||||
## Persistence in dataset metadata
|
||||
|
||||
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
|
||||
@@ -82,7 +152,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"video.is_depth_map": false,
|
||||
"is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
@@ -97,7 +167,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -355,8 +355,6 @@ explicit = true
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "feat/hffs-cache-cdn-range-reads" }
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
@@ -423,7 +421,6 @@ exclude_dirs = [
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
[tool.typos]
|
||||
default.extend-words = { trak = "trak" }
|
||||
default.extend-ignore-re = [
|
||||
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
||||
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
|
||||
|
||||
@@ -1,903 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import resource
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import (
|
||||
EpisodeByteCache,
|
||||
EpisodeVideoManifest,
|
||||
NativeHTTPRangeFetcher,
|
||||
assert_hf_hub_range_cache_branch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
||||
|
||||
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()) / "lerobot-sidecars"
|
||||
FULL_SIDECAR_NAME = "molmoact2-full.npz"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")
|
||||
parser.add_argument("--repo-id", default=DEFAULT_REPO)
|
||||
parser.add_argument("--revision", default=DEFAULT_REVISION)
|
||||
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
choices=("both", "full", "indexed", "remote-decoder", "native-http"),
|
||||
default="both",
|
||||
help=argparse.SUPPRESS,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--range-backend",
|
||||
choices=("fsspec", "native-http"),
|
||||
default="fsspec",
|
||||
help="Range reader used by indexed/full episode-pool fetch tracks.",
|
||||
)
|
||||
parser.add_argument("--num-episodes", type=int, default=512)
|
||||
parser.add_argument(
|
||||
"--manifest-episodes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Limit manifest construction to the first N episodes for local smoke tests.",
|
||||
)
|
||||
parser.add_argument("--pool-size", type=int, default=16)
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--native-http-connections",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max HTTP connections for --range-backend native-http. Defaults to --workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--native-http-retries",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Retries per native HTTP range request.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--native-http-timeout",
|
||||
type=float,
|
||||
default=120.0,
|
||||
help="Timeout in seconds for native HTTP requests.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-interval",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help="Print episode-pool fill progress every N seconds. Set 0 to disable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-decode",
|
||||
action="store_true",
|
||||
help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
|
||||
)
|
||||
parser.add_argument("--decode-workers", type=int, default=1)
|
||||
parser.add_argument("--prefetch-ahead", type=int, default=8)
|
||||
parser.add_argument("--frames-per-episode", type=int, default=16)
|
||||
parser.add_argument("--max-probe-mb", type=int, default=64)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--byte-budget-gb", type=float, default=80)
|
||||
parser.add_argument(
|
||||
"--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
|
||||
)
|
||||
parser.add_argument("--no-hub-branch-assert", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
|
||||
rng = random.Random(seed)
|
||||
upper = min(total, requested)
|
||||
if pool_size > upper:
|
||||
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
|
||||
return rng.sample(range(upper), pool_size)
|
||||
|
||||
|
||||
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
span = manifest.lookup(ep, camera_key)
|
||||
lo = span.first_pts
|
||||
hi = max(span.last_pts, lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _timestamps_from_meta(
|
||||
meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
|
||||
) -> dict[tuple[int, str], list[float]]:
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
row = meta.episodes[ep]
|
||||
for camera_key in meta.video_keys:
|
||||
lo = float(row[f"videos/{camera_key}/from_timestamp"])
|
||||
hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
||||
total = 0
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
total += manifest.lookup(ep, camera_key).mdat_length
|
||||
return total
|
||||
|
||||
|
||||
def _decode_all(
|
||||
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
|
||||
) -> float:
|
||||
start = time.perf_counter()
|
||||
items = list(timestamps.items())
|
||||
if decode_workers <= 1:
|
||||
for (ep, camera_key), ts in items:
|
||||
cache.get_frames(ep, camera_key, ts)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _fill_cache(
|
||||
cache: EpisodeByteCache, episodes: Sequence[int], *, progress_interval: float = 10.0
|
||||
) -> float:
|
||||
start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
cache.submit_prefetch(ep)
|
||||
last_progress = start
|
||||
for idx, ep in enumerate(episodes, start=1):
|
||||
cache.ensure_ready(ep)
|
||||
now = time.perf_counter()
|
||||
if progress_interval > 0 and now - last_progress >= progress_interval:
|
||||
timings = cache.timing_summary()
|
||||
byte_count = timings.get("range_bytes", 0.0)
|
||||
elapsed = max(now - start, 1e-9)
|
||||
jobs = timings.get("jobs", 0.0)
|
||||
total_jobs = len(episodes) * len(cache.manifest.video_keys)
|
||||
_log(
|
||||
"fill_progress: "
|
||||
f"episodes_ready={idx}/{len(episodes)} "
|
||||
f"camera_jobs={jobs:.0f}/{total_jobs} "
|
||||
f"fetched={byte_count / 1024**3:.2f} GiB "
|
||||
f"fetch={byte_count / elapsed / 1024**2:.1f} MiB/s "
|
||||
f"elapsed={_format_duration(elapsed)}"
|
||||
)
|
||||
last_progress = now
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
|
||||
if elapsed_s <= 0:
|
||||
return float("inf")
|
||||
return len(episodes) * frames_per_episode / elapsed_s
|
||||
|
||||
|
||||
def _log(message: str) -> None:
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.1f}s"
|
||||
if seconds < 3600:
|
||||
return f"{seconds / 60:.1f}m"
|
||||
return f"{seconds / 3600:.1f}h"
|
||||
|
||||
|
||||
def _current_rss_mib() -> float | None:
|
||||
status_path = Path("/proc/self/status")
|
||||
if not status_path.exists():
|
||||
return None
|
||||
for line in status_path.read_text().splitlines():
|
||||
if line.startswith("VmRSS:"):
|
||||
return float(line.split()[1]) / 1024
|
||||
return None
|
||||
|
||||
|
||||
def _peak_rss_mib() -> float:
|
||||
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
# Linux reports KiB; macOS reports bytes.
|
||||
if rss > 10**8:
|
||||
return rss / 1024**2
|
||||
return rss / 1024
|
||||
|
||||
|
||||
def _memory_snapshot() -> dict[str, float | None]:
|
||||
return {"rss_mib": _current_rss_mib(), "peak_rss_mib": _peak_rss_mib()}
|
||||
|
||||
|
||||
def _print_memory_summary(start: dict[str, float | None], end: dict[str, float | None]) -> None:
|
||||
start_rss = start["rss_mib"]
|
||||
end_rss = end["rss_mib"]
|
||||
delta = None if start_rss is None or end_rss is None else end_rss - start_rss
|
||||
print()
|
||||
print("| Memory | MiB |")
|
||||
print("|---|---:|")
|
||||
if start_rss is not None:
|
||||
print(f"| rss start | {start_rss:.1f} |")
|
||||
if end_rss is not None:
|
||||
print(f"| rss end | {end_rss:.1f} |")
|
||||
if delta is not None:
|
||||
print(f"| rss delta | {delta:.1f} |")
|
||||
print(f"| peak rss | {end['peak_rss_mib']:.1f} |")
|
||||
|
||||
|
||||
def _root_join(data_root: str, relative_path: str) -> str:
|
||||
if data_root.startswith("hf://"):
|
||||
return f"{data_root.rstrip('/')}/{relative_path}"
|
||||
return str(Path(data_root) / relative_path)
|
||||
|
||||
|
||||
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
|
||||
_ = manifest_episode_count
|
||||
local = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
|
||||
if _valid_sidecar(local):
|
||||
return local
|
||||
if local.exists():
|
||||
print(f"mp4_sidecar_invalid_local: {local}")
|
||||
local.unlink()
|
||||
remote_relative = f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}"
|
||||
remote = _root_join(data_root, remote_relative)
|
||||
protocol = "hf" if data_root.startswith("hf://") else "file"
|
||||
fs = fsspec.filesystem(protocol)
|
||||
if not fs.exists(remote):
|
||||
return None
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"downloading_mp4_sidecar: {remote} -> {local}")
|
||||
if data_root.startswith("hf://"):
|
||||
_download_sidecar_native_http(data_root, remote_relative, local)
|
||||
else:
|
||||
fs.get(remote, str(local))
|
||||
return local
|
||||
|
||||
|
||||
def _valid_sidecar(path: Path) -> bool:
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
with np.load(path, allow_pickle=False) as data:
|
||||
return "manifest_json" in data
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
|
||||
fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
|
||||
tmp = local.with_suffix(local.suffix + ".tmp")
|
||||
try:
|
||||
size = fetcher.info_size(relative_path)
|
||||
chunk_size = 16 * 1024 * 1024
|
||||
ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]
|
||||
with tmp.open("wb") as out_file:
|
||||
out_file.truncate(size)
|
||||
|
||||
def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
|
||||
offset, length = offset_length
|
||||
return offset, fetcher.read_range(relative_path, offset, length)
|
||||
|
||||
start = time.perf_counter()
|
||||
done = 0
|
||||
with ThreadPoolExecutor(max_workers=8) as pool:
|
||||
futures = [pool.submit(read_chunk, item) for item in ranges]
|
||||
with tmp.open("r+b") as rw_file:
|
||||
for future in futures:
|
||||
offset, data = future.result()
|
||||
rw_file.seek(offset)
|
||||
rw_file.write(data)
|
||||
done += len(data)
|
||||
elapsed = max(time.perf_counter() - start, 1e-9)
|
||||
print(
|
||||
f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
|
||||
f"({done / elapsed / 1024**2:.1f} MiB/s)",
|
||||
flush=True,
|
||||
)
|
||||
tmp.replace(local)
|
||||
finally:
|
||||
fetcher.close()
|
||||
|
||||
|
||||
class EpisodeParquetReader:
|
||||
def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
|
||||
self.meta = meta
|
||||
self.data_root = data_root
|
||||
protocol = "hf" if data_root.startswith("hf://") else "file"
|
||||
self.fs = fsspec.filesystem(protocol)
|
||||
self._episode_row_groups = self._build_episode_row_groups()
|
||||
self._table_cache: dict[str, pa.Table] = {}
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
def read_episode(self, episode_index: int) -> None:
|
||||
relative_path = str(self.meta.get_data_file_path(episode_index))
|
||||
table = self._read_table(relative_path)
|
||||
table.filter(pc.equal(table["episode_index"], episode_index))
|
||||
|
||||
def _read_table(self, relative_path: str) -> pa.Table:
|
||||
with self._cache_lock:
|
||||
table = self._table_cache.get(relative_path)
|
||||
if table is not None:
|
||||
return table
|
||||
with self.fs.open(
|
||||
_root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none"
|
||||
) as f:
|
||||
table = pq.ParquetFile(f).read()
|
||||
with self._cache_lock:
|
||||
return self._table_cache.setdefault(relative_path, table)
|
||||
|
||||
def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
|
||||
return pool.submit(self.read_episode, episode_index)
|
||||
|
||||
def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
|
||||
start = time.perf_counter()
|
||||
if workers <= 1:
|
||||
for ep in episodes:
|
||||
self.read_episode(ep)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = [pool.submit(self.read_episode, ep) for ep in episodes]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start
|
||||
|
||||
def _build_episode_row_groups(self) -> dict[int, int]:
|
||||
counts: dict[tuple[int, int], int] = {}
|
||||
row_groups = {}
|
||||
for ep_idx in range(int(self.meta.total_episodes)):
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
key = (int(ep["data/chunk_index"]), int(ep["data/file_index"]))
|
||||
row_groups[ep_idx] = counts.get(key, 0)
|
||||
counts[key] = row_groups[ep_idx] + 1
|
||||
return row_groups
|
||||
|
||||
|
||||
def run_fetch_pool(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
range_backend: str,
|
||||
args: argparse.Namespace,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=workers,
|
||||
range_backend=range_backend,
|
||||
native_http_connections=args.native_http_connections,
|
||||
native_http_timeout=args.native_http_timeout,
|
||||
native_http_retries=args.native_http_retries,
|
||||
open_decoders=False,
|
||||
) as cache:
|
||||
elapsed = _fill_cache(cache, episodes, progress_interval=args.progress_interval)
|
||||
timings = cache.timing_summary()
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
episode_mb = byte_count / len(episodes) / 1024**2
|
||||
job_count = max(timings["jobs"], 1.0)
|
||||
result = {
|
||||
"fetch_s": elapsed,
|
||||
"fetch_mbps": byte_count / elapsed / 1024**2,
|
||||
"fetch_episodes_s": len(episodes) / elapsed,
|
||||
"episode_mb": episode_mb,
|
||||
"avg_mb_miss": byte_count / (len(episodes) * len(manifest.video_keys)) / 1024**2,
|
||||
"jobs": timings["jobs"],
|
||||
"lookup_ms": timings["lookup_s"] * 1000 / job_count,
|
||||
"range_fetch_ms": timings["fetch_s"] * 1000 / job_count,
|
||||
"synthesize_ms": timings["synthesize_s"] * 1000 / job_count,
|
||||
"store_ms": timings["store_s"] * 1000 / job_count,
|
||||
}
|
||||
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
|
||||
return result
|
||||
|
||||
|
||||
def run_parallel(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
decode_workers: int,
|
||||
frames_per_episode: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
range_backend: str,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=workers,
|
||||
range_backend=range_backend,
|
||||
open_decoders=False,
|
||||
) as cache:
|
||||
parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
|
||||
fetch_s = _fill_cache(cache, episodes)
|
||||
decoder_start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
cache.get_decoder(ep, camera_key)
|
||||
decoder_s = time.perf_counter() - decoder_start
|
||||
decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
return {
|
||||
"fetch_s": fetch_s,
|
||||
"fetch_mbps": byte_count / fetch_s / 1024**2,
|
||||
"fetch_episodes_s": len(episodes) / fetch_s,
|
||||
"parquet_s": parquet_s,
|
||||
"decoder_ms_miss": decoder_s * 1000 / (len(episodes) * len(manifest.video_keys)),
|
||||
"decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
|
||||
}
|
||||
|
||||
|
||||
def run_overlapped(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
decode_workers: int,
|
||||
frames_per_episode: int,
|
||||
prefetch_ahead: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
range_backend: str,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=workers,
|
||||
range_backend=range_backend,
|
||||
open_decoders=True,
|
||||
) as cache:
|
||||
start = time.perf_counter()
|
||||
video_wait_decode_s = 0.0
|
||||
parquet_wait_s = 0.0
|
||||
parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
|
||||
parquet_futures = {
|
||||
ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
|
||||
}
|
||||
for ep in episodes[:prefetch_ahead]:
|
||||
cache.submit_prefetch(ep)
|
||||
try:
|
||||
for idx, ep in enumerate(episodes):
|
||||
next_idx = idx + prefetch_ahead
|
||||
if next_idx < len(episodes):
|
||||
next_ep = episodes[next_idx]
|
||||
cache.submit_prefetch(next_ep)
|
||||
parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)
|
||||
|
||||
parquet_start = time.perf_counter()
|
||||
parquet_futures.pop(ep).result()
|
||||
parquet_wait_s += time.perf_counter() - parquet_start
|
||||
|
||||
video_start = time.perf_counter()
|
||||
cache.ensure_ready(ep)
|
||||
if decode_workers <= 1:
|
||||
for camera_key in manifest.video_keys:
|
||||
cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [
|
||||
pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
|
||||
for camera_key in manifest.video_keys
|
||||
]
|
||||
for future in futures:
|
||||
future.result()
|
||||
video_wait_decode_s += time.perf_counter() - video_start
|
||||
finally:
|
||||
parquet_pool.shutdown(wait=True)
|
||||
elapsed = time.perf_counter() - start
|
||||
return {
|
||||
"samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
|
||||
"video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
|
||||
"parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
|
||||
"wall_s": elapsed,
|
||||
"video_wait_decode_s": video_wait_decode_s,
|
||||
"parquet_wait_s": parquet_wait_s,
|
||||
}
|
||||
|
||||
|
||||
_remote_decoder_local = threading.local()
|
||||
|
||||
|
||||
def _remote_decoder_cache() -> VideoDecoderCache:
|
||||
cache = getattr(_remote_decoder_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = VideoDecoderCache(max_size=None)
|
||||
_remote_decoder_local.cache = cache
|
||||
return cache
|
||||
|
||||
|
||||
def _decode_remote_source(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
episode_index: int,
|
||||
camera_key: str,
|
||||
timestamps: list[float],
|
||||
):
|
||||
video_path = _root_join(data_root, str(meta.get_video_file_path(episode_index, camera_key)))
|
||||
return decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
timestamps,
|
||||
tolerance_s=1.0 / float(meta.fps),
|
||||
decoder_cache=_remote_decoder_cache(),
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
|
||||
def run_remote_decoder(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
*,
|
||||
frames_per_episode: int,
|
||||
decode_workers: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
) -> dict[str, float]:
|
||||
items = [
|
||||
(ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
|
||||
]
|
||||
|
||||
start = time.perf_counter()
|
||||
for ep, camera_key, ts in items:
|
||||
if camera_key == meta.video_keys[0]:
|
||||
parquet_reader.read_episode(ep)
|
||||
_decode_remote_source(meta, data_root, ep, camera_key, ts)
|
||||
sequential_s = time.perf_counter() - start
|
||||
|
||||
start = time.perf_counter()
|
||||
if decode_workers <= 1:
|
||||
for ep, camera_key, ts in items:
|
||||
if camera_key == meta.video_keys[0]:
|
||||
parquet_reader.read_episode(ep)
|
||||
_decode_remote_source(meta, data_root, ep, camera_key, ts)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes]
|
||||
futures = [
|
||||
pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
|
||||
for ep, camera_key, ts in items
|
||||
]
|
||||
for future in parquet_futures:
|
||||
future.result()
|
||||
for future in futures:
|
||||
future.result()
|
||||
parallel_s = time.perf_counter() - start
|
||||
|
||||
return {
|
||||
"sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
|
||||
"parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
|
||||
}
|
||||
|
||||
|
||||
def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
|
||||
range_jobs = fetch_pool.get("range_jobs", 0.0)
|
||||
if range_jobs <= 0:
|
||||
return
|
||||
|
||||
print()
|
||||
print("| Range Read Stage | avg ms/range |")
|
||||
print("|---|---:|")
|
||||
for key, label in (
|
||||
("range_open_s", "fsspec handle open/lookup"),
|
||||
("range_seek_s", "fsspec seek"),
|
||||
("range_read_s", "fsspec read"),
|
||||
("range_resolve_s", "http URL resolve"),
|
||||
("range_header_s", "http response headers"),
|
||||
("range_first_byte_s", "http first body byte"),
|
||||
("range_body_s", "http body drain"),
|
||||
("range_chunk_gap_s", "http chunk wait"),
|
||||
("range_join_s", "join response chunks"),
|
||||
("range_retry_sleep_s", "http retry sleep"),
|
||||
):
|
||||
value = fetch_pool.get(key)
|
||||
if value is not None:
|
||||
print(f"| {label} | {value * 1000 / range_jobs:.3f} |")
|
||||
if "range_retry_attempts" in fetch_pool:
|
||||
print(f"| http retries | {fetch_pool['range_retry_attempts'] / range_jobs:.3f} |")
|
||||
if fetch_pool.get("range_failed_requests"):
|
||||
print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |")
|
||||
status_counts = {
|
||||
key.removeprefix("range_status_"): value
|
||||
for key, value in fetch_pool.items()
|
||||
if key.startswith("range_status_")
|
||||
}
|
||||
if status_counts:
|
||||
summary = ", ".join(f"{status}={count:.0f}" for status, count in sorted(status_counts.items()))
|
||||
print(f"| http status counts | {summary} |")
|
||||
chunks = fetch_pool.get("range_chunks", 0.0)
|
||||
if chunks > 0:
|
||||
bytes_read = fetch_pool.get("range_bytes", 0.0)
|
||||
body_s = fetch_pool.get("range_body_s", 0.0)
|
||||
print(f"| http chunks/range | {chunks / range_jobs:.1f} |")
|
||||
print(f"| http avg KiB/chunk | {bytes_read / chunks / 1024:.1f} |")
|
||||
if body_s > 0:
|
||||
print(f"| http body MiB/s | {bytes_read / body_s / 1024**2:.1f} |")
|
||||
print(f"| range reads | {range_jobs:.0f} |")
|
||||
print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")
|
||||
|
||||
|
||||
def run_indexed_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
args: argparse.Namespace,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
*,
|
||||
range_backend: str = "fsspec",
|
||||
label: str = "indexed",
|
||||
sidecar_path: str | None = None,
|
||||
) -> None:
|
||||
_log(f"starting_strategy: {label}")
|
||||
memory_start = _memory_snapshot()
|
||||
manifest_start = time.perf_counter()
|
||||
dataset_episode_count = int(meta.total_episodes)
|
||||
manifest_episode_count = args.manifest_episodes or dataset_episode_count
|
||||
manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
|
||||
manifest = EpisodeVideoManifest.build(
|
||||
meta,
|
||||
data_root,
|
||||
episode_indices=range(manifest_episode_count),
|
||||
range_backend=range_backend,
|
||||
workers=args.workers,
|
||||
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
|
||||
sidecar_path=sidecar_path,
|
||||
)
|
||||
manifest_s = time.perf_counter() - manifest_start
|
||||
_log(f"{label}: manifest_build_s={manifest_s:.2f}")
|
||||
|
||||
benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
|
||||
episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
|
||||
byte_budget = int(args.byte_budget_gb * 1024**3)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
_log(
|
||||
f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
|
||||
f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
|
||||
)
|
||||
|
||||
_log(f"{label}: filling episode byte cache with {args.workers} workers")
|
||||
fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args)
|
||||
estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"]
|
||||
estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"]
|
||||
|
||||
print(f"manifest_build_s: {manifest_s:.2f}")
|
||||
print(f"strategy: {label}")
|
||||
print(f"range_backend: {range_backend}")
|
||||
print(f"mp4_sidecar: {sidecar_path or 'none'}")
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"dataset_episodes: {dataset_episode_count}")
|
||||
print(f"benchmark_episodes: {benchmark_episode_count}")
|
||||
print(f"pool_episodes: {len(episodes)}")
|
||||
print(f"sampled_episodes: {episodes}")
|
||||
print(f"cameras: {manifest.video_keys}")
|
||||
print()
|
||||
print(
|
||||
"| Track | fetch MB/s | fetch eps/s | wall s | est benchmark | est full dataset | avg MB/camera | notes |"
|
||||
)
|
||||
print("|---|---:|---:|---:|---:|---:|---:|---|")
|
||||
print(
|
||||
f"| EPISODE POOL FETCH | {fetch_pool['fetch_mbps']:.1f} | "
|
||||
f"{fetch_pool['fetch_episodes_s']:.2f} | {fetch_pool['fetch_s']:.2f} | "
|
||||
f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
|
||||
f"{fetch_pool['avg_mb_miss']:.1f} | {args.workers} workers, no decoder open/frame decode |"
|
||||
)
|
||||
print()
|
||||
print("| Camera Job Stage | avg ms/job |")
|
||||
print("|---|---:|")
|
||||
print(f"| manifest lookup | {fetch_pool['lookup_ms']:.3f} |")
|
||||
print(f"| remote byte-range fetch | {fetch_pool['range_fetch_ms']:.3f} |")
|
||||
print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
|
||||
print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
|
||||
print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
|
||||
_print_range_timing_summary(fetch_pool)
|
||||
_print_memory_summary(memory_start, _memory_snapshot())
|
||||
|
||||
if args.include_decode:
|
||||
timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
|
||||
_log(f"{label}: running parallel video fetch + decode-only")
|
||||
parallel = run_parallel(
|
||||
manifest,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
byte_budget,
|
||||
args.workers,
|
||||
args.decode_workers,
|
||||
args.frames_per_episode,
|
||||
parquet_reader,
|
||||
range_backend,
|
||||
)
|
||||
_log(f"{label}: running overlapped end-to-end")
|
||||
overlapped = run_overlapped(
|
||||
manifest,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
byte_budget,
|
||||
args.workers,
|
||||
args.decode_workers,
|
||||
args.frames_per_episode,
|
||||
args.prefetch_ahead,
|
||||
parquet_reader,
|
||||
range_backend,
|
||||
)
|
||||
print(
|
||||
f"| DECODE COMPARISON | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
|
||||
f"{parallel['fetch_s']:.2f} | "
|
||||
f"{_format_duration(benchmark_episode_count / parallel['fetch_episodes_s'])} | "
|
||||
f"{_format_duration(dataset_episode_count / parallel['fetch_episodes_s'])} | "
|
||||
f"{fetch_pool['avg_mb_miss']:.1f} | "
|
||||
f"decoder open {parallel['decoder_ms_miss']:.1f} ms/miss, "
|
||||
f"decode {parallel['decode_samples_s']:.1f} samples/s, parquet {parallel['parquet_s']:.2f}s |"
|
||||
)
|
||||
print(
|
||||
f"| OVERLAPPED E2E | - | - | {overlapped['wall_s']:.2f} | - | - | "
|
||||
f"{fetch_pool['avg_mb_miss']:.1f} | "
|
||||
f"{overlapped['samples_s']:.1f} samples/s; video+decode "
|
||||
f"{overlapped['video_wait_decode_s']:.2f}s, parquet {overlapped['parquet_wait_s']:.2f}s |"
|
||||
)
|
||||
|
||||
|
||||
def run_remote_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
args: argparse.Namespace,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
) -> None:
|
||||
_log("starting_strategy: remote-decoder")
|
||||
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
|
||||
timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
|
||||
_log("remote-decoder: running direct source MP4 decoder")
|
||||
result = run_remote_decoder(
|
||||
meta,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
frames_per_episode=args.frames_per_episode,
|
||||
decode_workers=args.decode_workers,
|
||||
parquet_reader=parquet_reader,
|
||||
)
|
||||
print("strategy: remote-decoder")
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"episodes: {episodes}")
|
||||
print(f"cameras: {list(meta.video_keys)}")
|
||||
print()
|
||||
print("| Track | samples/s | notes |")
|
||||
print("|---|---:|---|")
|
||||
print(f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |")
|
||||
print(
|
||||
f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | "
|
||||
f"direct source MP4 decoder, {args.decode_workers} workers |"
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.strategy == "full":
|
||||
args.strategy = "both"
|
||||
if args.strategy == "native-http":
|
||||
args.range_backend = "native-http"
|
||||
data_root = args.data_root
|
||||
if data_root.startswith("hf://") and not args.no_hub_branch_assert:
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
meta.ensure_readable()
|
||||
parquet_reader = EpisodeParquetReader(meta, data_root)
|
||||
manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
|
||||
manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
|
||||
sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
|
||||
|
||||
if sidecar_path is not None:
|
||||
print(f"using_mp4_sidecar: {sidecar_path}")
|
||||
|
||||
if sidecar_path is not None and args.strategy == "both":
|
||||
if args.include_decode:
|
||||
run_remote_strategy(meta, data_root, args, parquet_reader)
|
||||
print()
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend=args.range_backend,
|
||||
label=f"indexed-sidecar-{args.range_backend}",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "indexed":
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend=args.range_backend,
|
||||
label=f"indexed-sidecar-{args.range_backend}",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "native-http":
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="native-http",
|
||||
label="indexed-sidecar-native-http",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if args.strategy == "both":
|
||||
expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
|
||||
expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
|
||||
print(f"mp4_sidecar_missing_local: {expected_sidecar}")
|
||||
print(f"mp4_sidecar_missing_remote: {expected_remote}")
|
||||
print(
|
||||
"build_mp4_sidecar: "
|
||||
"uv run --no-sync python scripts/build_mp4_sidecar.py "
|
||||
f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
|
||||
)
|
||||
print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
|
||||
print()
|
||||
|
||||
if args.strategy in ("both", "indexed"):
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="fsspec",
|
||||
label="indexed",
|
||||
sidecar_path=None,
|
||||
)
|
||||
if args.strategy == "both":
|
||||
print()
|
||||
if args.strategy == "remote-decoder" or (args.strategy == "both" and args.include_decode):
|
||||
run_remote_strategy(meta, data_root, args, parquet_reader)
|
||||
if args.strategy == "both" and args.include_decode:
|
||||
print()
|
||||
if args.strategy in ("both", "native-http"):
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="native-http",
|
||||
label="indexed-native-http",
|
||||
sidecar_path=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,93 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch
|
||||
|
||||
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.")
|
||||
parser.add_argument("--repo-id", default=DEFAULT_REPO)
|
||||
parser.add_argument("--revision", default=DEFAULT_REVISION)
|
||||
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument("--output", required=True)
|
||||
parser.add_argument("--episodes", type=int, default=None)
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument("--range-backend", choices=("fsspec", "native-http"), default="native-http")
|
||||
parser.add_argument("--max-probe-mb", type=int, default=64)
|
||||
parser.add_argument(
|
||||
"--no-push", action="store_true", help="Do not upload the sidecar to data_root/meta/mp4-sidecars."
|
||||
)
|
||||
parser.add_argument("--no-hub-branch-assert", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def push_sidecar(local_path: str, data_root: str) -> list[str]:
|
||||
if not data_root.startswith("hf://"):
|
||||
return []
|
||||
|
||||
local = Path(local_path)
|
||||
fs = fsspec.filesystem("hf")
|
||||
remote_dir = f"{data_root.rstrip('/')}/meta/mp4-sidecars"
|
||||
remote_paths = [f"{remote_dir}/{local.name}"]
|
||||
|
||||
for remote in remote_paths:
|
||||
fs.put(str(local), remote)
|
||||
return remote_paths
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.data_root.startswith("hf://") and not args.no_hub_branch_assert:
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
meta.ensure_readable()
|
||||
total = (
|
||||
int(meta.total_episodes) if args.episodes is None else min(args.episodes, int(meta.total_episodes))
|
||||
)
|
||||
rel_paths = sorted(
|
||||
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in range(total) for key in meta.video_keys}
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
EpisodeVideoManifest.write_file_sidecar(
|
||||
args.output,
|
||||
rel_paths,
|
||||
args.data_root,
|
||||
range_backend=args.range_backend,
|
||||
workers=args.workers,
|
||||
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
print(f"wrote {args.output}")
|
||||
print(f"episodes={total} files={len(rel_paths)} elapsed_s={elapsed:.2f}")
|
||||
if args.no_push:
|
||||
print("push_skipped: --no-push")
|
||||
else:
|
||||
pushed = push_sidecar(args.output, args.data_root)
|
||||
for remote in pushed:
|
||||
print(f"pushed {remote}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -281,7 +281,7 @@ class VideoFrameProvider:
|
||||
reencode_video(
|
||||
src,
|
||||
out_path,
|
||||
camera_encoder=encoder,
|
||||
video_encoder=encoder,
|
||||
overwrite=True,
|
||||
start_time_s=from_timestamp,
|
||||
end_time_s=to_timestamp,
|
||||
|
||||
@@ -54,6 +54,7 @@ from typing import Any
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
@@ -274,12 +275,11 @@ class LanguageColumnsWriter:
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
# Re-emit one row group per episode (a bulk pq.write_table would collapse
|
||||
# them into one). Write to a sibling tmp path and atomically rename so a
|
||||
# crash mid-write can't leave a half-written shard.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
write_table_one_row_group_per_episode(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
|
||||
@@ -105,8 +105,9 @@ def raw_observation_to_observation(
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
"""Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
if image.dtype == torch.uint8:
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
@@ -436,7 +436,7 @@ class OpenCVCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
1. Reads a color frame (blocking call)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
@@ -445,8 +445,9 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
@@ -484,6 +485,8 @@ class OpenCVCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -268,13 +268,13 @@ class RealSenseCamera(Camera):
|
||||
)
|
||||
|
||||
if len(found_devices) > 1:
|
||||
serial_numbers = [dev["serial_number"] for dev in found_devices]
|
||||
serial_numbers = [dev["id"] for dev in found_devices]
|
||||
raise ValueError(
|
||||
f"Multiple RealSense cameras found with name '{name}'. "
|
||||
f"Please use a unique serial number instead. Found SNs: {serial_numbers}"
|
||||
)
|
||||
|
||||
serial_number = str(found_devices[0]["serial_number"])
|
||||
serial_number = str(found_devices[0]["id"])
|
||||
return serial_number
|
||||
|
||||
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
|
||||
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
np.ndarray: The depth map as a NumPy array (height, width, 1)
|
||||
of type `np.uint16` (raw depth values in millimeters).
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
1. Reads a color/depth frame (blocking call with 10s timeout)
|
||||
2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -474,8 +474,9 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
@@ -486,6 +487,8 @@ class RealSenseCamera(Camera):
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
|
||||
processed_depth_frame = processed_depth_frame[..., np.newaxis]
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -522,6 +525,8 @@ class RealSenseCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive(): # pragma: no cover
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
@@ -532,7 +537,6 @@ class RealSenseCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -575,7 +579,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
@@ -611,6 +614,73 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[np.uint16]:
|
||||
"""Read the latest depth frame asynchronously, in millimeters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``, where each
|
||||
pixel is the distance from the sensor in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
the background read thread is not running.
|
||||
TimeoutError: If no frame becomes available within ``timeout_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if depth_frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
|
||||
|
||||
return depth_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent depth frame in millimeters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.uint16`` of shape ``(H, W, 1)``, where each pixel is the
|
||||
distance from the sensor in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
no depth frame has been captured yet.
|
||||
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if depth_frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any depth frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return depth_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
@@ -249,8 +249,9 @@ class ZMQCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
@@ -292,6 +293,8 @@ class ZMQCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -35,8 +35,11 @@ from .types import (
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
encoder_config_from_video_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -57,8 +60,12 @@ __all__ = [
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
"DepthEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
"depth_encoder_defaults",
|
||||
# Factories
|
||||
"encoder_config_from_video_info",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -60,6 +60,8 @@ class DatasetRecordConfig:
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
|
||||
@@ -35,12 +35,17 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# When True, RGB video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
# Physical unit depth maps are dequantized to at load time: "mm" (millimetres) or "m" (metres).
|
||||
# Has no effect on datasets without depth cameras.
|
||||
depth_output_unit: str = "mm"
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.depth_output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
|
||||
if self.episodes is not None:
|
||||
if any(ep < 0 for ep in self.episodes):
|
||||
raise ValueError(
|
||||
|
||||
@@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar, Self
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
@@ -36,11 +36,12 @@ HW_VIDEO_CODECS = [
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
|
||||
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
|
||||
)
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
@@ -52,6 +53,19 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
# Default depth quantization and encoding parameters.
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
|
||||
DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
|
||||
|
||||
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
|
||||
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
@@ -86,6 +100,10 @@ class VideoEncoderConfig:
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Source-data channel count this encoder is expected to handle (3 for RGB,
|
||||
# 1 for depth, etc.)
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 3
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
@@ -94,9 +112,9 @@ class VideoEncoderConfig:
|
||||
self.validate()
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Parse the ``video.*`` keys of a feature ``info`` block into
|
||||
constructor kwargs.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
@@ -115,7 +133,15 @@ class VideoEncoderConfig:
|
||||
continue
|
||||
kwargs[field_name] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> Self:
|
||||
"""Reconstruct an encoder config from a video feature's ``info`` block.
|
||||
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
return cls(**cls._kwargs_from_video_info(video_info))
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of available encoders based on the specified video backend.
|
||||
@@ -138,7 +164,9 @@ class VideoEncoderConfig:
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
check_video_encoder_parameters_pyav(
|
||||
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
|
||||
)
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
@@ -218,6 +246,10 @@ class VideoEncoderConfig:
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "ffv1":
|
||||
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
|
||||
# are not meaningful.
|
||||
set_if("threads", encoder_threads)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
@@ -233,3 +265,75 @@ class VideoEncoderConfig:
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""Encoder configuration for depth-map streams.
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantizer.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"gray12le"``.
|
||||
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
by quantum ``0``.
|
||||
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
|
||||
shift: Pre-log offset for numerical stability near zero.
|
||||
use_log: ``True`` for logarithmic quantization (default; matches
|
||||
sensor error profile), ``False`` for linear.
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "gray12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 1
|
||||
|
||||
@classmethod
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Layer the depth-specific tuning (``depth_min`` / ``depth_max`` /
|
||||
``shift`` / ``use_log``) on top of the base parser. Missing keys
|
||||
fall back to the class defaults.
|
||||
"""
|
||||
kwargs = super()._kwargs_from_video_info(video_info)
|
||||
video_info = video_info or {}
|
||||
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{name}")
|
||||
if value is not None:
|
||||
kwargs[name] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
|
||||
return DepthEncoderConfig()
|
||||
|
||||
|
||||
def encoder_config_from_video_info(video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Build the appropriate encoder config from a feature's ``info`` block.
|
||||
|
||||
Dispatches to :class:`DepthEncoderConfig` when the dict marks the feature
|
||||
as a depth map and to :class:`VideoEncoderConfig`
|
||||
otherwise.
|
||||
|
||||
Args:
|
||||
video_info: A feature's ``info`` dict as persisted in ``info.json``,
|
||||
or ``None`` (treated as an empty dict).
|
||||
|
||||
Returns:
|
||||
A :class:`DepthEncoderConfig` for depth features, otherwise a
|
||||
:class:`VideoEncoderConfig`.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
is_depth = bool(video_info.get("is_depth_map") or video_info.get("video.is_depth_map"))
|
||||
cls: type[VideoEncoderConfig] = DepthEncoderConfig if is_depth else VideoEncoderConfig
|
||||
return cls.from_video_info(video_info)
|
||||
|
||||
@@ -32,6 +32,7 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
|
||||
from .io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_one_row_group_per_episode,
|
||||
to_parquet_with_hf_images,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -551,6 +552,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
concatenate=concatenate_data,
|
||||
one_row_group_per_episode=True,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
@@ -628,6 +630,7 @@ def append_or_create_parquet_file(
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
concatenate: bool = True,
|
||||
one_row_group_per_episode: bool = False,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -645,6 +648,8 @@ def append_or_create_parquet_file(
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
|
||||
the episodes-metadata parquet (already one row per episode).
|
||||
|
||||
Returns:
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
@@ -657,6 +662,8 @@ def append_or_create_parquet_file(
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
elif one_row_group_per_episode:
|
||||
to_parquet_one_row_group_per_episode(df, dst_path)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx, (dst_chunk, dst_file)
|
||||
@@ -683,6 +690,8 @@ def append_or_create_parquet_file(
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
elif one_row_group_per_episode:
|
||||
to_parquet_one_row_group_per_episode(final_df, target_path)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
|
||||
@@ -506,8 +506,10 @@ def compute_episode_stats(
|
||||
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
|
||||
|
||||
Note:
|
||||
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
|
||||
per-channel values when dtype is 'image' or 'video'.
|
||||
For 'image'/'video' features, stats are computed per channel and kept with a
|
||||
leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by
|
||||
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip
|
||||
this rescaling and remain in their stored units.
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
@@ -531,8 +533,12 @@ def compute_episode_stats(
|
||||
)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
normalization_factor = (
|
||||
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
|
||||
)
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
@@ -552,8 +558,10 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
|
||||
raise ValueError(
|
||||
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
|
||||
)
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterable
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -337,6 +338,25 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos or images.
|
||||
|
||||
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
|
||||
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
|
||||
"""
|
||||
|
||||
def _is_depth(ft: dict) -> bool:
|
||||
info = ft.get("info") or {}
|
||||
video_info = ft.get("video_info") or {}
|
||||
return (
|
||||
info.get("is_depth_map", False)
|
||||
or info.get("video.is_depth_map", False)
|
||||
or video_info.get("video.is_depth_map", False)
|
||||
)
|
||||
|
||||
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
@@ -580,29 +600,41 @@ class LeRobotDatasetMetadata:
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
preserve_keys: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
"""Populate or refresh per-feature video info in ``info.json``.
|
||||
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Always re-probes the videos and overwrites existing info for every recomputed
|
||||
key. ``preserve_keys`` lists keys whose existing values must be kept (e.g.
|
||||
data-intrinsic entries like ``is_depth_map`` and depth quantization params)
|
||||
instead of being recomputed.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder: Encoder configuration used to produce the
|
||||
video_encoder: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
preserve_keys: Keys whose existing values are kept instead of being
|
||||
recomputed. ``None`` (default) recomputes every key.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
preserve_set = set(preserve_keys or ())
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
existing = self.features[key].get("info") or {}
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
new_info = get_video_info(video_path, video_encoder=video_encoder)
|
||||
# Drop preserved keys so the existing values win on merge.
|
||||
new_info = {k: v for k, v in new_info.items() if k not in preserve_set}
|
||||
self.info.features[key]["info"] = {**existing, **new_info}
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -709,7 +741,7 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
features = {**deepcopy(features), **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
|
||||
@@ -22,7 +22,10 @@ from pathlib import Path
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import DepthEncoderConfig
|
||||
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .depth_utils import dequantize_depth
|
||||
from .feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
@@ -51,6 +54,7 @@ class DatasetReader:
|
||||
delta_timestamps: dict[str, list[float]] | None,
|
||||
image_transforms: Callable | None,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = "mm",
|
||||
):
|
||||
"""Initialize the reader with metadata, filtering, and transform config.
|
||||
|
||||
@@ -68,6 +72,10 @@ class DatasetReader:
|
||||
relative timestamp offsets for temporal context windows.
|
||||
image_transforms: Optional torchvision v2 transform applied to
|
||||
visual features.
|
||||
return_uint8: If True, return RGB video frames as raw uint8 tensors
|
||||
instead of normalized float32.
|
||||
depth_output_unit: Physical unit depth maps are dequantized to
|
||||
(``"m"`` or ``"mm"``). Defaults to ``"mm"``.
|
||||
"""
|
||||
self._meta = meta
|
||||
self.root = root
|
||||
@@ -76,6 +84,7 @@ class DatasetReader:
|
||||
self._video_backend = video_backend
|
||||
self._image_transforms = image_transforms
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
|
||||
self.hf_dataset: datasets.Dataset | None = None
|
||||
self._absolute_to_relative_idx: dict[int, int] | None = None
|
||||
@@ -86,6 +95,12 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
|
||||
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
|
||||
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
|
||||
for vid_key in self._meta.depth_keys
|
||||
}
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
@@ -247,7 +262,18 @@ class DatasetReader:
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
is_depth=vid_key in self._meta.depth_keys,
|
||||
)
|
||||
if vid_key in self._meta.depth_keys:
|
||||
depth_encoder = self._depth_encoder_configs[vid_key]
|
||||
frames = dequantize_depth(
|
||||
frames,
|
||||
depth_min=depth_encoder.depth_min,
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
output_unit=self._depth_output_unit,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -27,6 +27,7 @@ import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -36,7 +37,14 @@ import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
encoder_config_from_video_info,
|
||||
)
|
||||
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
|
||||
@@ -47,6 +55,7 @@ from .compute_stats import (
|
||||
compute_relative_action_stats,
|
||||
)
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .image_writer import write_image
|
||||
from .io_utils import (
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
@@ -61,12 +70,13 @@ from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEPTH_FILE_PATTERN,
|
||||
IMAGE_FILE_PATTERN,
|
||||
VIDEO_DIR,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import (
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
|
||||
@@ -600,7 +610,7 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -614,7 +624,7 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
camera_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
video_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
"""
|
||||
from fractions import Fraction
|
||||
|
||||
@@ -639,13 +649,13 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
codec_options = camera_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
codec_options = video_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(video_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = camera_encoder.pix_fmt
|
||||
v_out.pix_fmt = video_encoder.pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -732,7 +742,7 @@ def _copy_and_reindex_videos(
|
||||
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
logging.info(f"Processing videos for {video_key}")
|
||||
camera_encoder = VideoEncoderConfig.from_video_info(
|
||||
video_encoder = encoder_config_from_video_info(
|
||||
src_dataset.meta.info.features.get(video_key, {}).get("info")
|
||||
)
|
||||
|
||||
@@ -816,7 +826,7 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
camera_encoder,
|
||||
video_encoder,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1101,7 +1111,9 @@ def _copy_episodes_metadata_and_stats(
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
dst_meta.info.features[key]["info"] = deepcopy(
|
||||
src_dataset.meta.info.features[key].get("info", {})
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
@@ -1150,15 +1162,15 @@ def _save_episode_images_for_video(
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
write_image(item[img_key], imgs_dir / frame_pattern.format(frame_index=i))
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
@@ -1190,13 +1202,14 @@ def _save_batch_episodes_images(
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
|
||||
# Define function to save a single image with global frame index
|
||||
# Defined once outside the loop to avoid repeated closure creation
|
||||
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key_param]
|
||||
# Use global frame index for naming
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
write_image(item[img_key_param], imgs_dir / frame_pattern.format(frame_index=base_frame_idx + i))
|
||||
return i
|
||||
|
||||
episode_durations = []
|
||||
@@ -1287,7 +1300,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1301,7 +1314,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
camera_encoder: Video encoder settings used for calibration encoding.
|
||||
video_encoder: Video encoder settings used for calibration encoding.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1326,10 +1339,11 @@ def _estimate_frame_size_via_calibration(
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
sample_indices = range(from_idx, from_idx + num_frames)
|
||||
|
||||
# Save calibration frames
|
||||
# Save calibration frames using the suffix/format the encoder expects.
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
for i, idx in enumerate(sample_indices):
|
||||
img = hf_dataset[idx][img_key]
|
||||
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
write_image(hf_dataset[idx][img_key], calibration_dir / frame_pattern.format(frame_index=i))
|
||||
|
||||
# Encode calibration video
|
||||
calibration_video_path = calibration_dir / "calibration.mp4"
|
||||
@@ -1337,7 +1351,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1610,6 +1624,7 @@ def recompute_stats(
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
# TODO: enable image and video stats re-computation
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
@@ -1656,6 +1671,7 @@ def convert_image_to_video_dataset(
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1667,21 +1683,32 @@ def convert_image_to_video_dataset(
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
camera_encoder: Video encoder settings
|
||||
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
|
||||
dataset: The source LeRobot dataset with images.
|
||||
output_dir: Root directory where the converted dataset will be stored. When
|
||||
``None``, defaults to ``$HF_LEROBOT_HOME/repo_id``. Equivalent to
|
||||
``new_root`` in ``EditDatasetConfig``.
|
||||
repo_id: Converted dataset identifier. Equivalent to ``new_repo_id`` in
|
||||
``EditDatasetConfig``.
|
||||
camera_encoder: Video encoder settings applied to RGB cameras. When ``None``,
|
||||
:func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings applied to depth-map cameras, including
|
||||
the quantization parameters persisted to the dataset metadata. When
|
||||
``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
episode_indices: Episode indices to convert. When ``None``, all episodes are
|
||||
converted.
|
||||
num_workers: Number of threads for parallel processing.
|
||||
max_episodes_per_batch: Maximum episodes per video batch, to bound memory use.
|
||||
``None`` means no limit.
|
||||
max_frames_per_batch: Maximum frames per video batch, to bound memory use.
|
||||
``None`` means no limit.
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
A new :class:`LeRobotDataset` with images encoded as videos.
|
||||
"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
if depth_encoder is None:
|
||||
depth_encoder = depth_encoder_defaults()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
@@ -1706,10 +1733,7 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
|
||||
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
|
||||
)
|
||||
logging.info(f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1771,6 +1795,8 @@ def convert_image_to_video_dataset(
|
||||
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
|
||||
|
||||
for img_key in tqdm(img_keys, desc="Processing cameras"):
|
||||
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
|
||||
|
||||
# Estimate size per frame by encoding a small calibration sample
|
||||
# This provides accurate compression ratio for the specific codec parameters
|
||||
size_per_frame_mb = _estimate_frame_size_via_calibration(
|
||||
@@ -1779,7 +1805,7 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=target_encoder,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1821,7 +1847,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=target_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1860,16 +1886,11 @@ def convert_image_to_video_dataset(
|
||||
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
# Update video info for all image keys (now videos). They are registered as
|
||||
# video features above, so update_video_info populates their (still-empty) info.
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder=camera_encoder
|
||||
)
|
||||
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
|
||||
new_meta.update_video_info(video_key=img_key, video_encoder=target_encoder)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
@@ -1896,11 +1917,11 @@ def convert_image_to_video_dataset(
|
||||
|
||||
def _reencode_video_worker(args: tuple) -> Path:
|
||||
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
|
||||
video_path, camera_encoder, encoder_threads = args
|
||||
video_path, video_encoder, encoder_threads = args
|
||||
reencode_video(
|
||||
input_video_path=video_path,
|
||||
output_video_path=video_path,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -1909,7 +1930,8 @@ def _reencode_video_worker(args: tuple) -> Path:
|
||||
|
||||
def reencode_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
@@ -1920,8 +1942,11 @@ def reencode_dataset(
|
||||
Args:
|
||||
dataset: An existing :class:`LeRobotDataset` whose videos will be
|
||||
re-encoded.
|
||||
camera_encoder: Target encoder configuration applied to every video
|
||||
file.
|
||||
camera_encoder: Target encoder configuration applied to every RGB video
|
||||
file. If ``None``, re-encoding is skipped for RGB videos.
|
||||
depth_encoder: Target encoder configuration applied to every depth video
|
||||
file. If ``None``, re-encoding is skipped for depth videos.
|
||||
Quantization parameters will not override the ones in the current dataset.
|
||||
encoder_threads: Per-encoder thread count forwarded to
|
||||
:func:`reencode_video`. ``None`` lets the codec decide.
|
||||
num_workers: Number of parallel processes. ``None`` or ``0`` means
|
||||
@@ -1933,23 +1958,35 @@ def reencode_dataset(
|
||||
on disk.
|
||||
"""
|
||||
meta = dataset.meta
|
||||
video_paths_list = []
|
||||
video_keys_encoders_dict = {}
|
||||
video_keys_paths_dict = {}
|
||||
|
||||
if camera_encoder is None and depth_encoder is None:
|
||||
raise ValueError("Either camera_encoder or depth_encoder must be provided")
|
||||
|
||||
# Only re-encode if the videos are not already encoded with the given video encoding parameters
|
||||
for video_key in meta.video_keys:
|
||||
current_info = meta.info.features[video_key].get("info", {})
|
||||
current_encoder = VideoEncoderConfig.from_video_info(current_info)
|
||||
if current_encoder != camera_encoder:
|
||||
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
current_encoder = encoder_config_from_video_info(current_info)
|
||||
target_encoder = depth_encoder if video_key in meta.depth_keys else camera_encoder
|
||||
if target_encoder is None:
|
||||
logging.info(f"No encoder provided for {video_key} video. Skipping re-encoding.")
|
||||
elif current_encoder != target_encoder:
|
||||
video_keys_paths_dict[video_key] = list((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
video_keys_encoders_dict[video_key] = target_encoder
|
||||
else:
|
||||
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
|
||||
logging.info(f"{video_key} videos are already encoded with {target_encoder}. Nothing to do.")
|
||||
|
||||
if len(video_paths_list) == 0:
|
||||
if len(video_keys_paths_dict) == 0:
|
||||
logging.warning("Dataset has no videos to re-encode.")
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
|
||||
logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
|
||||
|
||||
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
|
||||
worker_args = [
|
||||
(path, encoder, encoder_threads)
|
||||
for video_key, encoder in video_keys_encoders_dict.items()
|
||||
for path in video_keys_paths_dict[video_key]
|
||||
]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
@@ -1963,10 +2000,15 @@ def reencode_dataset(
|
||||
for args in tqdm(worker_args, desc="Re-encoding videos"):
|
||||
_reencode_video_worker(args)
|
||||
|
||||
# Refresh video info in metadata for every video key.
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(0, vid_key)
|
||||
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
# Refresh video info in metadata for every re-encoded key. Re-encoding only
|
||||
# changes codec/container params, so for depth videos we preserve ``is_depth_map``
|
||||
# and the depth quantization params (``video.depth_min`` / ``video.depth_max`` /
|
||||
# ...), which describe the data rather than the codec and must survive a transcode.
|
||||
# RGB videos pass an empty set: still a refresh, but nothing to preserve.
|
||||
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
|
||||
for video_key, encoder in video_keys_encoders_dict.items():
|
||||
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else set()
|
||||
meta.update_video_info(video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
@@ -31,7 +31,12 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
@@ -48,6 +53,7 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
@@ -67,17 +73,22 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
path_template = (
|
||||
DEFAULT_DEPTH_PATH
|
||||
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
|
||||
else DEFAULT_IMAGE_PATH
|
||||
)
|
||||
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -97,6 +108,7 @@ class DatasetWriter:
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
@@ -108,8 +120,11 @@ class DatasetWriter:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
|
||||
camera_encoder: Video encoder settings applied to RGB cameras. When
|
||||
``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings applied to depth cameras, including
|
||||
the quantization parameters. When ``None``,
|
||||
:func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
@@ -121,6 +136,7 @@ class DatasetWriter:
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -145,7 +161,8 @@ class DatasetWriter:
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
@@ -195,6 +212,7 @@ class DatasetWriter:
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self._meta.video_keys),
|
||||
depth_video_keys=list(self._meta.depth_keys),
|
||||
temp_dir=self._root,
|
||||
)
|
||||
|
||||
@@ -282,10 +300,13 @@ class DatasetWriter:
|
||||
if use_streaming:
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self._meta.video_keys:
|
||||
normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
k: v
|
||||
if k == "count"
|
||||
else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
@@ -300,7 +321,9 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -511,7 +534,12 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
self._meta.update_video_info(
|
||||
video_key,
|
||||
video_encoder=self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -578,13 +606,14 @@ class DatasetWriter:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
"""Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
|
||||
is_depth = video_key in self._meta.depth_keys
|
||||
return _encode_video_worker(
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._depth_encoder if is_depth else self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_PIX_FMT,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
|
||||
from .pyav_utils import write_u16_plane
|
||||
|
||||
_MM_PER_METRE = 1000.0
|
||||
_UINT16_MAX = 65535
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
if depth_min + shift <= 0:
|
||||
raise ValueError(
|
||||
f"depth_min + shift must be positive for logarithmic quantization, "
|
||||
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
|
||||
)
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.integer] | NDArray[np.floating],
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
|
||||
resolved_unit = (
|
||||
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
|
||||
)
|
||||
return depth.astype(np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
video_backend: str | None = "pyav",
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16] | av.VideoFrame:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
|
||||
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
|
||||
Logarithmic quantization is the default because it allocates more quanta
|
||||
to near-range depth, which matches the (1/depth) error profile of typical
|
||||
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
|
||||
|
||||
**Input units**:
|
||||
|
||||
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
|
||||
- ``input_unit="mm"``: interpret input values as millimetres.
|
||||
- ``input_unit="m"``: interpret input values as metres.
|
||||
|
||||
Quantization math runs in the **resolved input unit**.
|
||||
|
||||
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
|
||||
|
||||
Args:
|
||||
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
|
||||
depth_min: Depth (metres) at quantum ``0``.
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
video_backend: Video backend to use for encoding. Defaults to "pyav".
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
|
||||
``[0, DEPTH_QMAX]``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.detach().cpu().numpy()
|
||||
|
||||
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
|
||||
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
|
||||
depth = depth.squeeze()
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
# Normalization and quantization is performed in the resolved input unit.
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
log_max = math.log(float(depth_max_u + shift_u))
|
||||
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
|
||||
|
||||
if video_backend == "pyav":
|
||||
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
|
||||
write_u16_plane(frame.planes[0], quantized)
|
||||
return frame
|
||||
else:
|
||||
return quantized
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | av.VideoFrame | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
output_tensor: bool = True,
|
||||
output_channel_last: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit. Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Accepted input layouts :
|
||||
|
||||
- ``(H, W, 1)`` or ``(H, W)`` — single frame with channel-last.
|
||||
- ``(..., 1, H, W)`` — batched frames with channel-first.
|
||||
- ``(..., H, W, 1)`` — batched frames with channel-last.
|
||||
Output layout is determined by ``output_channel_last``.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes in ``[0, DEPTH_QMAX]``. ``np.ndarray``,
|
||||
``av.VideoFrame``, or ``torch.Tensor`` (any integer or float dtype).
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
pix_fmt: Pixel format used to extract the plane from an ``av.VideoFrame``.
|
||||
output_unit: ``"mm"`` returns ``uint16`` millimetres (rint, clip
|
||||
``[0, 65535]``) when returning a numpy array, or ``float32`` mm when
|
||||
``output_tensor=True``. ``"m"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
output_tensor: If True, return a ``torch.Tensor`` instead of a numpy array.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
|
||||
if isinstance(quantized, av.VideoFrame):
|
||||
quantized = quantized.to_ndarray(format=pix_fmt)
|
||||
|
||||
# Compute the scale and offset first.
|
||||
depth_min_m = float(depth_min)
|
||||
depth_max_m = float(depth_max)
|
||||
shift_m = float(shift)
|
||||
if use_log:
|
||||
log_min = math.log(depth_min_m + shift_m)
|
||||
log_max = math.log(depth_max_m + shift_m)
|
||||
scale = (log_max - log_min) / DEPTH_QMAX
|
||||
offset = log_min
|
||||
else:
|
||||
scale = (depth_max_m - depth_min_m) / DEPTH_QMAX
|
||||
offset = depth_min_m
|
||||
|
||||
# ── Torch path: stay on the input device, single fp32 allocation. ────────
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
if quantized.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
|
||||
|
||||
# Single allocation we own; everything else is in-place.
|
||||
buf = quantized.to(dtype=torch.float32, copy=True)
|
||||
buf.mul_(scale).add_(offset)
|
||||
if use_log:
|
||||
buf.exp_().sub_(shift_m)
|
||||
buf.clamp_(depth_min_m, depth_max_m)
|
||||
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return buf if output_tensor else buf.cpu().numpy()
|
||||
|
||||
# mm path: round + clamp in float32, skipping the uint16 round-trip
|
||||
# when returning a tensor (torch.uint16 is poorly supported).
|
||||
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
|
||||
if output_tensor:
|
||||
return buf
|
||||
return buf.cpu().numpy().astype(np.uint16, copy=False)
|
||||
|
||||
# ── NumPy path: single fp32 allocation, ``out=`` for in-place math. ─────
|
||||
arr = np.asarray(quantized)
|
||||
if arr.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
arr = np.squeeze(arr, axis=-3) if arr.shape[-3] == 1 else np.squeeze(arr, axis=-1)
|
||||
|
||||
buf = np.empty(arr.shape, dtype=np.float32)
|
||||
np.multiply(arr, scale, out=buf)
|
||||
np.add(buf, offset, out=buf)
|
||||
if use_log:
|
||||
np.exp(buf, out=buf)
|
||||
np.subtract(buf, shift_m, out=buf)
|
||||
np.clip(buf, depth_min_m, depth_max_m, out=buf)
|
||||
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return torch.from_numpy(buf) if output_tensor else buf
|
||||
|
||||
np.multiply(buf, _MM_PER_METRE, out=buf)
|
||||
np.rint(buf, out=buf)
|
||||
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
|
||||
if output_tensor:
|
||||
# torch.uint16 support is very limited; return float32 millimetres.
|
||||
return torch.from_numpy(buf)
|
||||
return buf.astype(np.uint16, copy=False)
|
||||
@@ -1,903 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from importlib import metadata
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import quote, urljoin, urlparse
|
||||
|
||||
import fsspec
|
||||
import httpx
|
||||
import numpy as np
|
||||
from huggingface_hub import HfApi, HfFileSystem, constants
|
||||
from huggingface_hub.utils import hf_raise_for_status
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.mp4 import Mp4Index, Mp4SampleSlice, fetch_mp4_index, synthesize_mp4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EpisodeVideoSpan:
|
||||
file_id: int
|
||||
mdat_offset: int
|
||||
mdat_length: int
|
||||
first_pts: float
|
||||
last_pts: float
|
||||
frame_count: int
|
||||
sample_lo: int
|
||||
sample_hi: int
|
||||
source_start_pts: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VideoFileRecord:
|
||||
file_path: str
|
||||
file_size: int
|
||||
mp4: Mp4Index
|
||||
|
||||
|
||||
class ThreadLocalRangeFetcher:
|
||||
"""Range reader that gives each worker thread independent file handles."""
|
||||
|
||||
def __init__(self, data_root: str | Path, *, block_size: int = 2**20, cache_type: str = "none"):
|
||||
self.data_root = str(data_root).rstrip("/")
|
||||
protocol = "hf" if self.data_root.startswith("hf://") else "file"
|
||||
self.fs = fsspec.filesystem(protocol)
|
||||
self.block_size = block_size
|
||||
self.cache_type = cache_type
|
||||
self._local = threading.local()
|
||||
self._timing_lock = threading.Lock()
|
||||
self._timing_totals = {
|
||||
"range_jobs": 0.0,
|
||||
"range_bytes": 0.0,
|
||||
"range_open_s": 0.0,
|
||||
"range_seek_s": 0.0,
|
||||
"range_read_s": 0.0,
|
||||
}
|
||||
|
||||
def _url(self, relative_path: str) -> str:
|
||||
if self.data_root.startswith("hf://"):
|
||||
return f"{self.data_root}/{relative_path}"
|
||||
return str(Path(self.data_root) / relative_path)
|
||||
|
||||
def _handle(self, relative_path: str):
|
||||
handles = getattr(self._local, "handles", None)
|
||||
if handles is None:
|
||||
handles = {}
|
||||
self._local.handles = handles
|
||||
handle = handles.get(relative_path)
|
||||
if handle is None or getattr(handle, "closed", False):
|
||||
handle = self.fs.open(
|
||||
self._url(relative_path), "rb", block_size=self.block_size, cache_type=self.cache_type
|
||||
)
|
||||
handles[relative_path] = handle
|
||||
return handle
|
||||
|
||||
def info_size(self, relative_path: str) -> int:
|
||||
return int(self.fs.info(self._url(relative_path))["size"])
|
||||
|
||||
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||
open_start = time.perf_counter()
|
||||
handle = self._handle(relative_path)
|
||||
open_s = time.perf_counter() - open_start
|
||||
seek_start = time.perf_counter()
|
||||
handle.seek(offset)
|
||||
seek_s = time.perf_counter() - seek_start
|
||||
read_start = time.perf_counter()
|
||||
data = handle.read(length)
|
||||
read_s = time.perf_counter() - read_start
|
||||
self._record_timing(
|
||||
range_jobs=1.0,
|
||||
range_bytes=float(len(data)),
|
||||
range_open_s=open_s,
|
||||
range_seek_s=seek_s,
|
||||
range_read_s=read_s,
|
||||
)
|
||||
return data
|
||||
|
||||
def _record_timing(self, **kwargs: float) -> None:
|
||||
with self._timing_lock:
|
||||
for key, value in kwargs.items():
|
||||
self._timing_totals[key] = self._timing_totals.get(key, 0.0) + value
|
||||
|
||||
def timing_summary(self) -> dict[str, float]:
|
||||
with self._timing_lock:
|
||||
return dict(self._timing_totals)
|
||||
|
||||
def close(self) -> None:
|
||||
handles = getattr(self._local, "handles", None)
|
||||
if handles is None:
|
||||
return
|
||||
for handle in handles.values():
|
||||
with contextlib.suppress(Exception):
|
||||
handle.close()
|
||||
handles.clear()
|
||||
|
||||
|
||||
class NativeHTTPRangeFetcher:
|
||||
"""Direct pooled HTTP range reader for hf:// paths."""
|
||||
|
||||
_GLOBAL_SOURCE_URLS: dict[tuple[str, str], str] = {}
|
||||
_GLOBAL_RESOLVED_URLS: dict[tuple[str, str], str] = {}
|
||||
_GLOBAL_SIZES: dict[tuple[str, str], int] = {}
|
||||
_GLOBAL_LOCK = threading.Lock()
|
||||
|
||||
_RETRYABLE_EXCEPTIONS = (
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.PoolTimeout,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_root: str | Path,
|
||||
*,
|
||||
max_connections: int = 32,
|
||||
timeout: float = 60.0,
|
||||
max_retries: int = 4,
|
||||
):
|
||||
self.data_root = str(data_root).rstrip("/")
|
||||
if not self.data_root.startswith("hf://"):
|
||||
raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
|
||||
self.max_retries = max_retries
|
||||
self.api = HfApi()
|
||||
self.fs: HfFileSystem | None = None
|
||||
self._bucket_id: str | None = None
|
||||
self._bucket_prefix = ""
|
||||
if self.data_root.startswith("hf://buckets/"):
|
||||
bucket_root = self.data_root.removeprefix("hf://buckets/")
|
||||
parts = bucket_root.split("/", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid bucket root: {self.data_root}")
|
||||
self._bucket_id = f"{parts[0]}/{parts[1]}"
|
||||
self._bucket_prefix = parts[2].strip("/") if len(parts) == 3 else ""
|
||||
else:
|
||||
self.fs = HfFileSystem()
|
||||
self.client = httpx.Client(
|
||||
timeout=timeout,
|
||||
limits=httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections),
|
||||
follow_redirects=False,
|
||||
)
|
||||
self._resolved_urls: dict[str, str] = {}
|
||||
self._source_urls: dict[str, str] = {}
|
||||
self._sizes: dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._timing_lock = threading.Lock()
|
||||
self._timing_totals = {
|
||||
"range_jobs": 0.0,
|
||||
"range_bytes": 0.0,
|
||||
"range_resolve_s": 0.0,
|
||||
"range_header_s": 0.0,
|
||||
"range_first_byte_s": 0.0,
|
||||
"range_body_s": 0.0,
|
||||
"range_retry_attempts": 0.0,
|
||||
"range_retry_sleep_s": 0.0,
|
||||
"range_failed_requests": 0.0,
|
||||
}
|
||||
|
||||
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self.client.request(method, url, **kwargs)
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_retries:
|
||||
break
|
||||
time.sleep(min(0.5 * 2**attempt, 5.0))
|
||||
if last_exc is None:
|
||||
raise RuntimeError("HTTP request failed without an exception")
|
||||
raise last_exc
|
||||
|
||||
def _cache_key(self, relative_path: str) -> tuple[str, str]:
|
||||
return self.data_root, relative_path
|
||||
|
||||
def _path(self, relative_path: str) -> str:
|
||||
return f"{self.data_root}/{relative_path}"
|
||||
|
||||
def _bucket_path(self, relative_path: str) -> str:
|
||||
if self._bucket_prefix:
|
||||
return f"{self._bucket_prefix}/{relative_path}"
|
||||
return relative_path
|
||||
|
||||
def _headers_for(self, request_url: str, source_url: str) -> dict[str, str]:
|
||||
headers = self.api._build_hf_headers()
|
||||
if urlparse(request_url).netloc != urlparse(source_url).netloc:
|
||||
headers.pop("authorization", None)
|
||||
headers.pop("Authorization", None)
|
||||
return headers
|
||||
|
||||
def _source_url(self, relative_path: str) -> str:
|
||||
with self._lock:
|
||||
source = self._source_urls.get(relative_path)
|
||||
if source is not None:
|
||||
return source
|
||||
key = self._cache_key(relative_path)
|
||||
with self._GLOBAL_LOCK:
|
||||
source = self._GLOBAL_SOURCE_URLS.get(key)
|
||||
if source is None:
|
||||
if self._bucket_id is not None:
|
||||
source = (
|
||||
f"{constants.ENDPOINT}/buckets/{self._bucket_id}/resolve/"
|
||||
f"{quote(self._bucket_path(relative_path))}"
|
||||
)
|
||||
else:
|
||||
if self.fs is None:
|
||||
raise RuntimeError("HfFileSystem fallback was not initialized")
|
||||
source = self.fs.url(self._path(relative_path))
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_SOURCE_URLS[key] = source
|
||||
with self._lock:
|
||||
self._source_urls[relative_path] = source
|
||||
return source
|
||||
|
||||
def _resolve_url(self, relative_path: str, *, refresh: bool = False) -> str:
|
||||
with self._lock:
|
||||
if not refresh and relative_path in self._resolved_urls:
|
||||
return self._resolved_urls[relative_path]
|
||||
key = self._cache_key(relative_path)
|
||||
if not refresh:
|
||||
with self._GLOBAL_LOCK:
|
||||
resolved = self._GLOBAL_RESOLVED_URLS.get(key)
|
||||
size = self._GLOBAL_SIZES.get(key)
|
||||
if resolved is not None:
|
||||
with self._lock:
|
||||
self._resolved_urls[relative_path] = resolved
|
||||
if size is not None:
|
||||
self._sizes[relative_path] = size
|
||||
return resolved
|
||||
|
||||
source = self._source_url(relative_path)
|
||||
response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
location = response.headers.get("Location")
|
||||
resolved = urljoin(source, location) if location else source
|
||||
with self._lock:
|
||||
self._resolved_urls[relative_path] = resolved
|
||||
if "Content-Length" in response.headers:
|
||||
self._sizes[relative_path] = int(response.headers["Content-Length"])
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_RESOLVED_URLS[key] = resolved
|
||||
if "Content-Length" in response.headers:
|
||||
self._GLOBAL_SIZES[key] = int(response.headers["Content-Length"])
|
||||
return resolved
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def info_size(self, relative_path: str) -> int:
|
||||
with self._lock:
|
||||
size = self._sizes.get(relative_path)
|
||||
if size is not None:
|
||||
return size
|
||||
key = self._cache_key(relative_path)
|
||||
with self._GLOBAL_LOCK:
|
||||
size = self._GLOBAL_SIZES.get(key)
|
||||
if size is not None:
|
||||
with self._lock:
|
||||
self._sizes[relative_path] = size
|
||||
return size
|
||||
|
||||
resolved = self._resolve_url(relative_path)
|
||||
source = self._source_url(relative_path)
|
||||
response = self._request(
|
||||
"HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True
|
||||
)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
size = int(response.headers["Content-Length"])
|
||||
with self._lock:
|
||||
self._sizes[relative_path] = size
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_SIZES[key] = size
|
||||
return size
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||
resolve_start = time.perf_counter()
|
||||
resolved = self._resolve_url(relative_path)
|
||||
source = self._source_url(relative_path)
|
||||
resolve_s = time.perf_counter() - resolve_start
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
payload, status_code, timings = self._read_range_response(resolved, headers)
|
||||
if status_code == 403:
|
||||
refresh_start = time.perf_counter()
|
||||
resolved = self._resolve_url(relative_path, refresh=True)
|
||||
resolve_s += time.perf_counter() - refresh_start
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
payload, status_code, retry_timings = self._read_range_response(resolved, headers)
|
||||
for key, value in retry_timings.items():
|
||||
timings[key] += value
|
||||
if status_code == 403:
|
||||
raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")
|
||||
self._record_timing(
|
||||
range_jobs=1.0,
|
||||
range_bytes=float(len(payload)),
|
||||
range_resolve_s=resolve_s,
|
||||
**{f"range_status_{status_code}": 1.0},
|
||||
**timings,
|
||||
)
|
||||
return payload
|
||||
|
||||
def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
|
||||
last_exc: Exception | None = None
|
||||
retry_attempts = 0.0
|
||||
retry_sleep_s = 0.0
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
payload, status_code, timings = self._read_range_response_once(url, headers)
|
||||
timings["range_retry_attempts"] = retry_attempts
|
||||
timings["range_retry_sleep_s"] = retry_sleep_s
|
||||
return payload, status_code, timings
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_retries:
|
||||
break
|
||||
retry_attempts += 1.0
|
||||
sleep_s = min(0.5 * 2**attempt, 5.0)
|
||||
retry_sleep_s += sleep_s
|
||||
time.sleep(sleep_s)
|
||||
self._record_timing(
|
||||
range_failed_requests=1.0,
|
||||
range_retry_attempts=retry_attempts,
|
||||
range_retry_sleep_s=retry_sleep_s,
|
||||
)
|
||||
if last_exc is None:
|
||||
raise RuntimeError("HTTP range request failed without an exception")
|
||||
raise last_exc
|
||||
|
||||
def _read_range_response_once(
|
||||
self, url: str, headers: dict[str, str]
|
||||
) -> tuple[bytes, int, dict[str, float]]:
|
||||
header_start = time.perf_counter()
|
||||
with self.client.stream("GET", url, headers=headers) as response:
|
||||
header_s = time.perf_counter() - header_start
|
||||
if response.status_code == 403:
|
||||
return (
|
||||
b"",
|
||||
response.status_code,
|
||||
{
|
||||
"range_header_s": header_s,
|
||||
"range_first_byte_s": 0.0,
|
||||
"range_body_s": 0.0,
|
||||
},
|
||||
)
|
||||
hf_raise_for_status(response)
|
||||
chunks = []
|
||||
first_byte_s = 0.0
|
||||
first_chunk = True
|
||||
chunk_gap_s = 0.0
|
||||
chunk_count = 0.0
|
||||
previous_chunk_at = body_start = time.perf_counter()
|
||||
for chunk in response.iter_bytes():
|
||||
now = time.perf_counter()
|
||||
if first_chunk:
|
||||
first_byte_s = now - body_start
|
||||
first_chunk = False
|
||||
chunk_gap_s += now - previous_chunk_at
|
||||
previous_chunk_at = now
|
||||
chunk_count += 1.0
|
||||
chunks.append(chunk)
|
||||
body_s = time.perf_counter() - body_start
|
||||
join_start = time.perf_counter()
|
||||
payload = b"".join(chunks)
|
||||
join_s = time.perf_counter() - join_start
|
||||
return (
|
||||
payload,
|
||||
response.status_code,
|
||||
{
|
||||
"range_header_s": header_s,
|
||||
"range_first_byte_s": first_byte_s,
|
||||
"range_body_s": body_s,
|
||||
"range_join_s": join_s,
|
||||
"range_chunks": chunk_count,
|
||||
"range_chunk_gap_s": chunk_gap_s,
|
||||
},
|
||||
)
|
||||
|
||||
def _record_timing(self, **kwargs: float) -> None:
|
||||
with self._timing_lock:
|
||||
for key, value in kwargs.items():
|
||||
self._timing_totals[key] = self._timing_totals.get(key, 0.0) + value
|
||||
|
||||
def timing_summary(self) -> dict[str, float]:
|
||||
with self._timing_lock:
|
||||
return dict(self._timing_totals)
|
||||
|
||||
def close(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
|
||||
def make_range_fetcher(
|
||||
data_root: str | Path,
|
||||
*,
|
||||
range_backend: str,
|
||||
workers: int,
|
||||
native_http_connections: int | None = None,
|
||||
native_http_timeout: float = 60.0,
|
||||
native_http_retries: int = 4,
|
||||
):
|
||||
if range_backend == "fsspec":
|
||||
return ThreadLocalRangeFetcher(data_root)
|
||||
if range_backend == "native-http":
|
||||
max_connections = native_http_connections or max(8, workers)
|
||||
return NativeHTTPRangeFetcher(
|
||||
data_root,
|
||||
max_connections=max_connections,
|
||||
timeout=native_http_timeout,
|
||||
max_retries=native_http_retries,
|
||||
)
|
||||
raise ValueError(f"Unknown range backend: {range_backend}")
|
||||
|
||||
|
||||
class EpisodeVideoManifest:
|
||||
_FILE_SIDECAR_CACHE: dict[str, dict[str, VideoFileRecord]] = {}
|
||||
_FILE_SIDECAR_CACHE_LOCK = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
video_keys: list[str],
|
||||
files: list[VideoFileRecord],
|
||||
spans: dict[str, np.ndarray],
|
||||
):
|
||||
self.video_keys = list(video_keys)
|
||||
self._camera_to_id = {key: idx for idx, key in enumerate(self.video_keys)}
|
||||
self.files = files
|
||||
self.spans = spans
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str | Path,
|
||||
*,
|
||||
episode_indices: list[int] | range | None = None,
|
||||
range_backend: str = "fsspec",
|
||||
workers: int = 8,
|
||||
header_probe_bytes: int = 4 * 1024 * 1024,
|
||||
max_probe_bytes: int = 64 * 1024 * 1024,
|
||||
keyframe_pad_s: float = 0.1,
|
||||
keyframe_pad_fraction: float = 0.05,
|
||||
sidecar_path: str | Path | None = None,
|
||||
) -> EpisodeVideoManifest:
|
||||
meta.ensure_readable()
|
||||
video_keys = list(meta.video_keys)
|
||||
if episode_indices is None:
|
||||
episode_indices = range(int(meta.total_episodes))
|
||||
rel_paths = sorted(
|
||||
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in episode_indices for key in video_keys}
|
||||
)
|
||||
path_to_id = {path: idx for idx, path in enumerate(rel_paths)}
|
||||
if sidecar_path is None:
|
||||
files = cls._build_file_records(
|
||||
rel_paths,
|
||||
data_root,
|
||||
range_backend=range_backend,
|
||||
workers=workers,
|
||||
header_probe_bytes=header_probe_bytes,
|
||||
max_probe_bytes=max_probe_bytes,
|
||||
)
|
||||
else:
|
||||
records = cls.load_file_sidecar(sidecar_path)
|
||||
missing = [path for path in rel_paths if path not in records]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"Sidecar {sidecar_path} is missing {len(missing)} files, first: {missing[0]}"
|
||||
)
|
||||
files = [records[path] for path in rel_paths]
|
||||
|
||||
total = int(meta.total_episodes)
|
||||
num_cameras = len(video_keys)
|
||||
spans: dict[str, np.ndarray] = {
|
||||
"file_id": np.zeros((total, num_cameras), dtype=np.int32),
|
||||
"mdat_offset": np.zeros((total, num_cameras), dtype=np.int64),
|
||||
"mdat_length": np.zeros((total, num_cameras), dtype=np.int64),
|
||||
"first_pts": np.zeros((total, num_cameras), dtype=np.float64),
|
||||
"last_pts": np.zeros((total, num_cameras), dtype=np.float64),
|
||||
"frame_count": np.zeros((total, num_cameras), dtype=np.int32),
|
||||
"sample_lo": np.zeros((total, num_cameras), dtype=np.int32),
|
||||
"sample_hi": np.zeros((total, num_cameras), dtype=np.int32),
|
||||
"source_start_pts": np.zeros((total, num_cameras), dtype=np.float64),
|
||||
}
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
ep = meta.episodes[ep_idx]
|
||||
for cam_idx, key in enumerate(video_keys):
|
||||
rel_path = str(meta.get_video_file_path(ep_idx, key))
|
||||
file_id = path_to_id[rel_path]
|
||||
mp4 = files[file_id].mp4
|
||||
from_ts = float(ep[f"videos/{key}/from_timestamp"])
|
||||
to_ts = float(ep[f"videos/{key}/to_timestamp"])
|
||||
sample_slice = mp4.sample_slice(
|
||||
from_ts,
|
||||
to_ts,
|
||||
keyframe_pad_s=keyframe_pad_s,
|
||||
keyframe_pad_fraction=keyframe_pad_fraction,
|
||||
file_size=files[file_id].file_size,
|
||||
)
|
||||
spans["file_id"][ep_idx, cam_idx] = file_id
|
||||
spans["mdat_offset"][ep_idx, cam_idx] = sample_slice.byte_offset
|
||||
spans["mdat_length"][ep_idx, cam_idx] = sample_slice.byte_length
|
||||
spans["first_pts"][ep_idx, cam_idx] = from_ts
|
||||
spans["last_pts"][ep_idx, cam_idx] = to_ts
|
||||
spans["frame_count"][ep_idx, cam_idx] = sample_slice.sample_hi - sample_slice.sample_lo + 1
|
||||
spans["sample_lo"][ep_idx, cam_idx] = sample_slice.sample_lo
|
||||
spans["sample_hi"][ep_idx, cam_idx] = sample_slice.sample_hi
|
||||
spans["source_start_pts"][ep_idx, cam_idx] = sample_slice.source_start_pts
|
||||
|
||||
return cls(video_keys=video_keys, files=files, spans=spans)
|
||||
|
||||
@staticmethod
|
||||
def _build_file_records(
|
||||
rel_paths: list[str],
|
||||
data_root: str | Path,
|
||||
*,
|
||||
range_backend: str,
|
||||
workers: int,
|
||||
header_probe_bytes: int,
|
||||
max_probe_bytes: int,
|
||||
) -> list[VideoFileRecord]:
|
||||
fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)
|
||||
|
||||
def build_file(path: str) -> VideoFileRecord:
|
||||
file_size = fetcher.info_size(path)
|
||||
mp4 = fetch_mp4_index(
|
||||
path,
|
||||
fetcher.read_range,
|
||||
file_size=file_size,
|
||||
header_probe_bytes=header_probe_bytes,
|
||||
max_probe_bytes=max_probe_bytes,
|
||||
)
|
||||
return VideoFileRecord(path, file_size, mp4)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
return list(pool.map(build_file, rel_paths))
|
||||
finally:
|
||||
fetcher.close()
|
||||
|
||||
@classmethod
|
||||
def write_file_sidecar(
|
||||
cls,
|
||||
sidecar_path: str | Path,
|
||||
rel_paths: list[str],
|
||||
data_root: str | Path,
|
||||
*,
|
||||
range_backend: str = "native-http",
|
||||
workers: int = 8,
|
||||
header_probe_bytes: int = 4 * 1024 * 1024,
|
||||
max_probe_bytes: int = 64 * 1024 * 1024,
|
||||
) -> None:
|
||||
records = cls._build_file_records(
|
||||
sorted(set(rel_paths)),
|
||||
data_root,
|
||||
range_backend=range_backend,
|
||||
workers=workers,
|
||||
header_probe_bytes=header_probe_bytes,
|
||||
max_probe_bytes=max_probe_bytes,
|
||||
)
|
||||
cls.save_file_sidecar(sidecar_path, records)
|
||||
|
||||
@staticmethod
|
||||
def save_file_sidecar(sidecar_path: str | Path, records: list[VideoFileRecord]) -> None:
|
||||
sidecar_path = Path(sidecar_path)
|
||||
sidecar_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
"version": 1,
|
||||
"files": [
|
||||
{"file_path": record.file_path, "file_size": record.file_size, "mp4": record.mp4.to_dict()}
|
||||
for record in records
|
||||
],
|
||||
}
|
||||
arrays = {}
|
||||
for file_idx, record in enumerate(records):
|
||||
arrays[f"{file_idx}/sample_pts"] = record.mp4.sample_pts
|
||||
arrays[f"{file_idx}/sample_durations"] = record.mp4.sample_durations
|
||||
arrays[f"{file_idx}/sample_sizes"] = record.mp4.sample_sizes
|
||||
arrays[f"{file_idx}/sample_offsets"] = record.mp4.sample_offsets
|
||||
arrays[f"{file_idx}/sync_samples"] = record.mp4.sync_samples
|
||||
np.savez_compressed(sidecar_path, manifest_json=json.dumps(payload).encode("utf-8"), **arrays)
|
||||
|
||||
@staticmethod
|
||||
def load_file_sidecar(sidecar_path: str | Path) -> dict[str, VideoFileRecord]:
|
||||
cache_key = str(Path(sidecar_path).expanduser())
|
||||
with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
|
||||
cached = EpisodeVideoManifest._FILE_SIDECAR_CACHE.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
with np.load(sidecar_path, allow_pickle=False) as data:
|
||||
payload = json.loads(bytes(data["manifest_json"]).decode("utf-8"))
|
||||
records = {}
|
||||
for file_idx, item in enumerate(payload["files"]):
|
||||
arrays = {
|
||||
name: data[f"{file_idx}/{name}"]
|
||||
for name in [
|
||||
"sample_pts",
|
||||
"sample_durations",
|
||||
"sample_sizes",
|
||||
"sample_offsets",
|
||||
"sync_samples",
|
||||
]
|
||||
}
|
||||
mp4 = Mp4Index.from_dict(item["mp4"], arrays)
|
||||
records[item["file_path"]] = VideoFileRecord(item["file_path"], int(item["file_size"]), mp4)
|
||||
with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
|
||||
EpisodeVideoManifest._FILE_SIDECAR_CACHE[cache_key] = records
|
||||
return records
|
||||
|
||||
def camera_id(self, camera_key: str) -> int:
|
||||
return self._camera_to_id[camera_key]
|
||||
|
||||
def lookup(self, episode_index: int, camera_key: str) -> EpisodeVideoSpan:
|
||||
cam = self.camera_id(camera_key)
|
||||
return EpisodeVideoSpan(
|
||||
file_id=int(self.spans["file_id"][episode_index, cam]),
|
||||
mdat_offset=int(self.spans["mdat_offset"][episode_index, cam]),
|
||||
mdat_length=int(self.spans["mdat_length"][episode_index, cam]),
|
||||
first_pts=float(self.spans["first_pts"][episode_index, cam]),
|
||||
last_pts=float(self.spans["last_pts"][episode_index, cam]),
|
||||
frame_count=int(self.spans["frame_count"][episode_index, cam]),
|
||||
sample_lo=int(self.spans["sample_lo"][episode_index, cam]),
|
||||
sample_hi=int(self.spans["sample_hi"][episode_index, cam]),
|
||||
source_start_pts=float(self.spans["source_start_pts"][episode_index, cam]),
|
||||
)
|
||||
|
||||
def file_lookup(self, file_id: int) -> VideoFileRecord:
|
||||
return self.files[file_id]
|
||||
|
||||
def mp4_index(self, episode_index: int, camera_key: str) -> Mp4Index:
|
||||
return self.files[self.lookup(episode_index, camera_key).file_id].mp4
|
||||
|
||||
def sample_slice(self, episode_index: int, camera_key: str) -> Mp4SampleSlice:
|
||||
span = self.lookup(episode_index, camera_key)
|
||||
return Mp4SampleSlice(
|
||||
sample_lo=span.sample_lo,
|
||||
sample_hi=span.sample_hi,
|
||||
byte_offset=span.mdat_offset,
|
||||
byte_length=span.mdat_length,
|
||||
source_start_pts=span.source_start_pts,
|
||||
)
|
||||
|
||||
|
||||
class EpisodeByteCache:
|
||||
def __init__(
|
||||
self,
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str | Path,
|
||||
*,
|
||||
byte_budget: int = 80 * 1024**3,
|
||||
workers: int = 8,
|
||||
range_backend: str = "fsspec",
|
||||
native_http_connections: int | None = None,
|
||||
native_http_timeout: float = 60.0,
|
||||
native_http_retries: int = 4,
|
||||
open_decoders: bool = True,
|
||||
):
|
||||
self.manifest = manifest
|
||||
self.fetcher = make_range_fetcher(
|
||||
data_root,
|
||||
range_backend=range_backend,
|
||||
workers=workers,
|
||||
native_http_connections=native_http_connections,
|
||||
native_http_timeout=native_http_timeout,
|
||||
native_http_retries=native_http_retries,
|
||||
)
|
||||
self.byte_budget = byte_budget
|
||||
self.open_decoders = open_decoders
|
||||
self._pool = ThreadPoolExecutor(max_workers=workers)
|
||||
self._cache: OrderedDict[tuple[int, str], dict[str, Any]] = OrderedDict()
|
||||
self._futures: dict[tuple[int, str], Future[dict[str, Any]]] = {}
|
||||
self._bytes = 0
|
||||
self._lock = threading.Lock()
|
||||
self._timing_totals = {
|
||||
"lookup_s": 0.0,
|
||||
"fetch_s": 0.0,
|
||||
"synthesize_s": 0.0,
|
||||
"store_s": 0.0,
|
||||
"jobs": 0.0,
|
||||
}
|
||||
|
||||
def close(self) -> None:
|
||||
self._pool.shutdown(wait=True)
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
self._futures.clear()
|
||||
self._bytes = 0
|
||||
self.fetcher.close()
|
||||
|
||||
def __enter__(self) -> EpisodeByteCache:
|
||||
return self
|
||||
|
||||
def __exit__(self, *_exc) -> None:
|
||||
self.close()
|
||||
|
||||
def submit_prefetch(self, episode_index: int) -> None:
|
||||
for camera_key in self.manifest.video_keys:
|
||||
self._submit(episode_index, camera_key)
|
||||
|
||||
def ensure_ready(self, episode_index: int) -> None:
|
||||
for camera_key in self.manifest.video_keys:
|
||||
self.get_bytes(episode_index, camera_key)
|
||||
|
||||
def get_bytes(self, episode_index: int, camera_key: str) -> bytes:
|
||||
return self._get_entry(episode_index, camera_key)["bytes"]
|
||||
|
||||
def get_decoder(self, episode_index: int, camera_key: str):
|
||||
entry = self._get_entry(episode_index, camera_key)
|
||||
decoder = entry.get("decoder")
|
||||
if decoder is None:
|
||||
decoder = open_video_decoder(io.BytesIO(entry["bytes"]))
|
||||
entry["decoder"] = decoder
|
||||
return decoder
|
||||
|
||||
def get_frames(self, episode_index: int, camera_key: str, timestamps: list[float]):
|
||||
span = self.manifest.lookup(episode_index, camera_key)
|
||||
local_ts = [ts - span.source_start_pts for ts in timestamps]
|
||||
decoder = self.get_decoder(episode_index, camera_key)
|
||||
if hasattr(decoder, "get_frames_played_at"):
|
||||
return decoder.get_frames_played_at(local_ts).data
|
||||
metadata = decoder.metadata
|
||||
fps = getattr(metadata, "average_fps", None)
|
||||
if fps is None:
|
||||
duration = max(getattr(metadata, "end_stream_seconds", 0.0), 1e-9)
|
||||
fps = metadata.num_frames / duration
|
||||
return decoder.get_frames_at(indices=[round(ts * fps) for ts in local_ts]).data
|
||||
|
||||
def timing_summary(self) -> dict[str, float]:
|
||||
with self._lock:
|
||||
summary = dict(self._timing_totals)
|
||||
fetcher_summary = getattr(self.fetcher, "timing_summary", None)
|
||||
if fetcher_summary is not None:
|
||||
summary.update(fetcher_summary())
|
||||
return summary
|
||||
|
||||
def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
|
||||
key = (episode_index, camera_key)
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
future: Future[dict[str, Any]] = Future()
|
||||
future.set_result(self._cache[key])
|
||||
return future
|
||||
future = self._futures.get(key)
|
||||
if future is None:
|
||||
future = self._pool.submit(self._fetch_and_synthesize, episode_index, camera_key)
|
||||
self._futures[key] = future
|
||||
return future
|
||||
|
||||
def _get_entry(self, episode_index: int, camera_key: str) -> dict[str, Any]:
|
||||
key = (episode_index, camera_key)
|
||||
with self._lock:
|
||||
entry = self._cache.get(key)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(key)
|
||||
return entry
|
||||
future = self._submit(episode_index, camera_key)
|
||||
entry = future.result()
|
||||
store_start = time.perf_counter()
|
||||
with self._lock:
|
||||
self._futures.pop(key, None)
|
||||
existing = self._cache.get(key)
|
||||
if existing is not None:
|
||||
self._cache.move_to_end(key)
|
||||
return existing
|
||||
self._cache[key] = entry
|
||||
self._bytes += len(entry["bytes"])
|
||||
self._evict_locked()
|
||||
timings = entry.pop("_timings", None)
|
||||
if timings is not None:
|
||||
self._timing_totals["lookup_s"] += timings["lookup_s"]
|
||||
self._timing_totals["fetch_s"] += timings["fetch_s"]
|
||||
self._timing_totals["synthesize_s"] += timings["synthesize_s"]
|
||||
self._timing_totals["store_s"] += time.perf_counter() - store_start
|
||||
self._timing_totals["jobs"] += 1
|
||||
return entry
|
||||
|
||||
def _evict_locked(self) -> None:
|
||||
while self._bytes > self.byte_budget and self._cache:
|
||||
_key, entry = self._cache.popitem(last=False)
|
||||
self._bytes -= len(entry["bytes"])
|
||||
|
||||
def _fetch_and_synthesize(self, episode_index: int, camera_key: str) -> dict[str, Any]:
|
||||
lookup_start = time.perf_counter()
|
||||
span = self.manifest.lookup(episode_index, camera_key)
|
||||
file_record = self.manifest.file_lookup(span.file_id)
|
||||
sample_slice = Mp4SampleSlice(
|
||||
sample_lo=span.sample_lo,
|
||||
sample_hi=span.sample_hi,
|
||||
byte_offset=span.mdat_offset,
|
||||
byte_length=span.mdat_length,
|
||||
source_start_pts=span.source_start_pts,
|
||||
)
|
||||
lookup_s = time.perf_counter() - lookup_start
|
||||
fetch_start = time.perf_counter()
|
||||
payload = self.fetcher.read_range(file_record.file_path, span.mdat_offset, span.mdat_length)
|
||||
fetch_s = time.perf_counter() - fetch_start
|
||||
if len(payload) != span.mdat_length:
|
||||
raise OSError(
|
||||
f"Short read for {file_record.file_path}: expected {span.mdat_length}, got {len(payload)}"
|
||||
)
|
||||
synthesize_start = time.perf_counter()
|
||||
mp4_bytes = synthesize_mp4(file_record.mp4, sample_slice, payload)
|
||||
synthesize_s = time.perf_counter() - synthesize_start
|
||||
entry: dict[str, Any] = {
|
||||
"bytes": mp4_bytes,
|
||||
"decoder": None,
|
||||
"_timings": {
|
||||
"lookup_s": lookup_s,
|
||||
"fetch_s": fetch_s,
|
||||
"synthesize_s": synthesize_s,
|
||||
},
|
||||
}
|
||||
if self.open_decoders:
|
||||
entry["decoder"] = open_video_decoder(io.BytesIO(mp4_bytes))
|
||||
return entry
|
||||
|
||||
|
||||
def open_video_decoder(file_like_or_bytesio, frame_mappings=None):
|
||||
if frame_mappings is not None:
|
||||
raise ValueError("Synthesized episode videos use a local timeline; pass frame_mappings=None.")
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
return VideoDecoder(file_like_or_bytesio, seek_mode="approximate")
|
||||
|
||||
|
||||
def assert_hf_hub_range_cache_branch() -> None:
|
||||
"""Fail unless huggingface_hub was installed from the required range-cache branch."""
|
||||
|
||||
try:
|
||||
dist = metadata.distribution("huggingface_hub")
|
||||
except metadata.PackageNotFoundError as exc:
|
||||
raise AssertionError("huggingface_hub is not installed") from exc
|
||||
|
||||
candidates = []
|
||||
direct_url = dist.read_text("direct_url.json")
|
||||
if direct_url:
|
||||
candidates.append(direct_url)
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
parsed = json.loads(direct_url)
|
||||
candidates.append(str(parsed.get("url", "")))
|
||||
candidates.append(str(parsed.get("vcs_info", {}).get("requested_revision", "")))
|
||||
candidates.append(str(parsed.get("vcs_info", {}).get("commit_id", "")))
|
||||
|
||||
text = "\n".join(candidates)
|
||||
if "feat/hffs-cache-cdn-range-reads" not in text:
|
||||
raise AssertionError(
|
||||
"huggingface_hub must be installed from "
|
||||
"git+https://github.com/huggingface/huggingface_hub.git@feat/hffs-cache-cdn-range-reads"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageTimer:
|
||||
fetch_ms: float = 0.0
|
||||
decode_ms: float = 0.0
|
||||
bytes_read: int = 0
|
||||
misses: int = 0
|
||||
|
||||
def record_fetch(self, start: float, byte_count: int) -> None:
|
||||
self.fetch_ms += (time.perf_counter() - start) * 1000
|
||||
self.bytes_read += byte_count
|
||||
self.misses += 1
|
||||
@@ -96,6 +96,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
return_uint8=True,
|
||||
depth_output_unit=cfg.dataset.depth_output_unit,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -336,7 +336,7 @@ def validate_feature_image_or_video(
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -42,10 +42,41 @@ def safe_stop_image_writer(func):
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
# TODO(CarolinePascal): 4 dimensions RGB-D images
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in [np.uint16, np.float32]:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
|
||||
)
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
@@ -71,13 +102,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation.
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -101,7 +147,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import datasets
|
||||
import numpy as np
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset as pa_ds
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
@@ -270,21 +271,49 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
return items_dict
|
||||
|
||||
|
||||
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
|
||||
"""Write ``table`` with one parquet row group per episode (in episode order).
|
||||
|
||||
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
|
||||
mirroring the recording writer. ``table`` must carry a contiguous
|
||||
``episode_index`` column.
|
||||
"""
|
||||
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
|
||||
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
|
||||
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
|
||||
try:
|
||||
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
|
||||
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
Images are embedded into the arrow table first (``ParquetWriter.write_table``
|
||||
does not embed external image files like ``Dataset.to_parquet`` does).
|
||||
``features`` types image columns as ``Image()`` in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
ds = embed_images(ds)
|
||||
table = ds.with_format("arrow")[:]
|
||||
if "episode_index" in table.column_names:
|
||||
write_table_one_row_group_per_episode(table, path)
|
||||
else:
|
||||
# No episode boundaries to align row groups to — keep a single write.
|
||||
pq.write_table(table, str(path))
|
||||
|
||||
|
||||
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
|
||||
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
if "episode_index" in table.column_names:
|
||||
write_table_one_row_group_per_episode(table, path)
|
||||
else:
|
||||
pq.write_table(table, str(path))
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
@@ -58,8 +58,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = "mm",
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -186,6 +188,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
|
||||
is used by the writer.
|
||||
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults`
|
||||
is used by the writer.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
@@ -208,6 +213,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
@@ -248,6 +254,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
depth_output_unit=self._depth_output_unit,
|
||||
)
|
||||
|
||||
# Load actual data
|
||||
@@ -273,6 +280,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps,
|
||||
camera_encoder,
|
||||
depth_encoder,
|
||||
encoder_queue_maxsize,
|
||||
encoder_threads,
|
||||
)
|
||||
@@ -280,6 +288,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -315,6 +324,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=self.delta_timestamps,
|
||||
image_transforms=self.image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
depth_output_unit=self._depth_output_unit,
|
||||
)
|
||||
return self.reader
|
||||
|
||||
@@ -322,12 +332,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
@@ -646,6 +658,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -678,6 +691,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
@@ -712,6 +727,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._depth_output_unit = "mm"
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
@@ -721,12 +737,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -750,6 +767,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
@@ -779,6 +797,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
@@ -806,6 +826,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._depth_output_unit = "mm"
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
|
||||
if obj._requested_root is not None:
|
||||
@@ -825,12 +846,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -1,666 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Box:
|
||||
type: bytes
|
||||
start: int
|
||||
header_size: int
|
||||
end: int
|
||||
|
||||
@property
|
||||
def payload_start(self) -> int:
|
||||
return self.start + self.header_size
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self.end - self.start
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Mp4SampleSlice:
|
||||
sample_lo: int
|
||||
sample_hi: int
|
||||
byte_offset: int
|
||||
byte_length: int
|
||||
source_start_pts: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Mp4Index:
|
||||
file_path: str
|
||||
file_size: int
|
||||
ftyp: bytes
|
||||
moov_offset: int
|
||||
mdat_offset: int
|
||||
mdat_payload_offset: int
|
||||
mdat_payload_size: int
|
||||
faststart: bool
|
||||
codec: str
|
||||
timescale: int
|
||||
duration: int
|
||||
track_id: int
|
||||
width: int
|
||||
height: int
|
||||
stsd_body: bytes
|
||||
sample_pts: np.ndarray
|
||||
sample_durations: np.ndarray
|
||||
sample_sizes: np.ndarray
|
||||
sample_offsets: np.ndarray
|
||||
sync_samples: np.ndarray
|
||||
|
||||
def sample_slice(
|
||||
self,
|
||||
from_ts: float,
|
||||
to_ts: float,
|
||||
*,
|
||||
keyframe_pad_s: float = 0.1,
|
||||
keyframe_pad_fraction: float = 0.05,
|
||||
file_size: int | None = None,
|
||||
) -> Mp4SampleSlice:
|
||||
if to_ts < from_ts:
|
||||
raise ValueError(f"Invalid timestamp span: {from_ts=} {to_ts=}")
|
||||
if len(self.sample_pts) == 0:
|
||||
raise ValueError(f"{self.file_path} contains no indexed samples")
|
||||
|
||||
pad = max(keyframe_pad_s, (to_ts - from_ts) * keyframe_pad_fraction)
|
||||
lo_ts = max(0.0, from_ts - pad)
|
||||
hi_ts = to_ts + pad
|
||||
lo = int(np.searchsorted(self.sample_pts, lo_ts, side="left"))
|
||||
hi = int(np.searchsorted(self.sample_pts, hi_ts, side="right")) - 1
|
||||
lo = min(max(lo, 0), len(self.sample_pts) - 1)
|
||||
hi = min(max(hi, lo), len(self.sample_pts) - 1)
|
||||
|
||||
if len(self.sync_samples):
|
||||
prev_sync = self.sync_samples[self.sync_samples <= lo]
|
||||
if len(prev_sync):
|
||||
lo = int(prev_sync[-1])
|
||||
else:
|
||||
lo = int(self.sync_samples[0])
|
||||
if lo > hi:
|
||||
hi = lo
|
||||
|
||||
offsets = self.sample_offsets[lo : hi + 1]
|
||||
sizes = self.sample_sizes[lo : hi + 1]
|
||||
slice_lo = int(offsets.min())
|
||||
slice_hi = int((offsets + sizes).max())
|
||||
if file_size is not None:
|
||||
slice_hi = min(slice_hi, int(file_size))
|
||||
return Mp4SampleSlice(
|
||||
sample_lo=lo,
|
||||
sample_hi=hi,
|
||||
byte_offset=slice_lo,
|
||||
byte_length=slice_hi - slice_lo,
|
||||
source_start_pts=float(self.sample_pts[lo]),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"file_path": self.file_path,
|
||||
"file_size": self.file_size,
|
||||
"ftyp": self.ftyp.hex(),
|
||||
"moov_offset": self.moov_offset,
|
||||
"mdat_offset": self.mdat_offset,
|
||||
"mdat_payload_offset": self.mdat_payload_offset,
|
||||
"mdat_payload_size": self.mdat_payload_size,
|
||||
"faststart": self.faststart,
|
||||
"codec": self.codec,
|
||||
"timescale": self.timescale,
|
||||
"duration": self.duration,
|
||||
"track_id": self.track_id,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"stsd_body": self.stsd_body.hex(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict, arrays: dict[str, np.ndarray]) -> Mp4Index:
|
||||
return cls(
|
||||
file_path=data["file_path"],
|
||||
file_size=int(data["file_size"]),
|
||||
ftyp=bytes.fromhex(data["ftyp"]),
|
||||
moov_offset=int(data["moov_offset"]),
|
||||
mdat_offset=int(data["mdat_offset"]),
|
||||
mdat_payload_offset=int(data["mdat_payload_offset"]),
|
||||
mdat_payload_size=int(data["mdat_payload_size"]),
|
||||
faststart=bool(data["faststart"]),
|
||||
codec=data["codec"],
|
||||
timescale=int(data["timescale"]),
|
||||
duration=int(data["duration"]),
|
||||
track_id=int(data["track_id"]),
|
||||
width=int(data["width"]),
|
||||
height=int(data["height"]),
|
||||
stsd_body=bytes.fromhex(data["stsd_body"]),
|
||||
sample_pts=arrays["sample_pts"],
|
||||
sample_durations=arrays["sample_durations"],
|
||||
sample_sizes=arrays["sample_sizes"],
|
||||
sample_offsets=arrays["sample_offsets"],
|
||||
sync_samples=arrays["sync_samples"],
|
||||
)
|
||||
|
||||
|
||||
def fetch_mp4_index(
|
||||
path: str,
|
||||
read_range: Callable[[str, int, int], bytes],
|
||||
*,
|
||||
file_size: int,
|
||||
header_probe_bytes: int = 4 * 1024 * 1024,
|
||||
max_probe_bytes: int = 64 * 1024 * 1024,
|
||||
) -> Mp4Index:
|
||||
probe_size = min(header_probe_bytes, file_size)
|
||||
while True:
|
||||
data = read_range(path, 0, probe_size)
|
||||
top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
|
||||
has_mdat = any(box.type == b"mdat" for box in top)
|
||||
has_moov = any(box.type == b"moov" and box.end <= len(data) for box in top)
|
||||
if has_mdat and has_moov:
|
||||
return parse_mp4_index(path, data, file_size=file_size)
|
||||
if probe_size >= min(max_probe_bytes, file_size):
|
||||
if has_mdat and not has_moov:
|
||||
tail_index = _fetch_tail_moov_index(path, read_range, data, top, file_size, max_probe_bytes)
|
||||
if tail_index is not None:
|
||||
return tail_index
|
||||
missing = []
|
||||
if not has_mdat:
|
||||
missing.append("mdat")
|
||||
if not has_moov:
|
||||
missing.append("moov")
|
||||
raise ValueError(
|
||||
f"Could not find complete {'/'.join(missing)} in first {probe_size} bytes of {path}"
|
||||
)
|
||||
probe_size = min(probe_size * 2, max_probe_bytes, file_size)
|
||||
|
||||
|
||||
def _fetch_tail_moov_index(
|
||||
path: str,
|
||||
read_range: Callable[[str, int, int], bytes],
|
||||
prefix: bytes,
|
||||
top_boxes: list[Box],
|
||||
file_size: int,
|
||||
max_probe_bytes: int,
|
||||
) -> Mp4Index | None:
|
||||
mdat_box = _one(top_boxes, b"mdat")
|
||||
if mdat_box is None or mdat_box.end >= file_size:
|
||||
return None
|
||||
tail_offset = mdat_box.end
|
||||
tail_length = min(max_probe_bytes, file_size - tail_offset)
|
||||
tail = read_range(path, tail_offset, tail_length)
|
||||
tail_boxes = list(iter_boxes(tail, 0, len(tail), absolute_base=tail_offset, allow_truncated=True))
|
||||
moov_box = next(
|
||||
(box for box in tail_boxes if box.type == b"moov" and box.end <= tail_offset + len(tail)), None
|
||||
)
|
||||
if moov_box is None:
|
||||
return None
|
||||
ftyp_box = _one(top_boxes, b"ftyp", required=False)
|
||||
ftyp = (
|
||||
prefix[ftyp_box.start : ftyp_box.end]
|
||||
if ftyp_box is not None
|
||||
else _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
|
||||
)
|
||||
moov_start = moov_box.payload_start - tail_offset
|
||||
moov_end = moov_box.end - tail_offset
|
||||
return _parse_mp4_index_from_layout(
|
||||
path,
|
||||
file_size=file_size,
|
||||
ftyp=ftyp,
|
||||
moov_offset=moov_box.start,
|
||||
moov=tail[moov_start:moov_end],
|
||||
mdat_box=mdat_box,
|
||||
)
|
||||
|
||||
|
||||
def parse_mp4_index(path: str, data: bytes, *, file_size: int | None = None) -> Mp4Index:
|
||||
if file_size is None:
|
||||
file_size = len(data)
|
||||
top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
|
||||
ftyp_box = _one(top, b"ftyp", required=False)
|
||||
moov_box = _one(top, b"moov")
|
||||
mdat_box = _one(top, b"mdat")
|
||||
if moov_box.end > len(data):
|
||||
raise ValueError(f"{path}: moov box is truncated")
|
||||
|
||||
moov = data[moov_box.payload_start : moov_box.end]
|
||||
ftyp = (
|
||||
data[ftyp_box.start : ftyp_box.end]
|
||||
if ftyp_box is not None
|
||||
else _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
|
||||
)
|
||||
return _parse_mp4_index_from_layout(
|
||||
path,
|
||||
file_size=file_size,
|
||||
ftyp=ftyp,
|
||||
moov_offset=moov_box.start,
|
||||
moov=moov,
|
||||
mdat_box=mdat_box,
|
||||
)
|
||||
|
||||
|
||||
def _parse_mp4_index_from_layout(
|
||||
path: str,
|
||||
*,
|
||||
file_size: int,
|
||||
ftyp: bytes,
|
||||
moov_offset: int,
|
||||
moov: bytes,
|
||||
mdat_box: Box,
|
||||
) -> Mp4Index:
|
||||
mvhd_timescale, mvhd_duration = _parse_mvhd(_find_descendant(moov, [b"mvhd"]))
|
||||
trak_box, trak_payload = _find_video_trak(moov)
|
||||
_ = trak_box
|
||||
tkhd = _parse_tkhd(_find_descendant(trak_payload, [b"tkhd"]))
|
||||
mdhd_timescale, mdhd_duration = _parse_mdhd(_find_descendant(trak_payload, [b"mdia", b"mdhd"]))
|
||||
stbl = _find_descendant(trak_payload, [b"mdia", b"minf", b"stbl"])
|
||||
|
||||
stsd = _find_child(stbl, b"stsd")
|
||||
stsd_body = stbl[stsd.payload_start : stsd.end]
|
||||
codec = _parse_stsd_codec(stsd_body)
|
||||
stts = _parse_stts(_payload(stbl, b"stts"))
|
||||
sample_sizes = _parse_stsz(_payload(stbl, b"stsz"))
|
||||
stsc = _parse_stsc(_payload(stbl, b"stsc"))
|
||||
chunk_offsets = _parse_chunk_offsets(stbl)
|
||||
sync_samples = _parse_stss(stbl, len(sample_sizes))
|
||||
|
||||
sample_durations = _expand_stts(stts, len(sample_sizes))
|
||||
sample_pts_units = np.empty(len(sample_durations), dtype=np.int64)
|
||||
if len(sample_durations):
|
||||
sample_pts_units[0] = 0
|
||||
if len(sample_durations) > 1:
|
||||
sample_pts_units[1:] = np.cumsum(sample_durations[:-1], dtype=np.int64)
|
||||
sample_pts = sample_pts_units.astype(np.float64) / float(mdhd_timescale)
|
||||
sample_offsets = _sample_offsets(stsc, chunk_offsets, sample_sizes)
|
||||
|
||||
return Mp4Index(
|
||||
file_path=path,
|
||||
file_size=file_size,
|
||||
ftyp=ftyp,
|
||||
moov_offset=moov_offset,
|
||||
mdat_offset=mdat_box.start,
|
||||
mdat_payload_offset=mdat_box.payload_start,
|
||||
mdat_payload_size=mdat_box.end - mdat_box.payload_start
|
||||
if mdat_box.end <= file_size
|
||||
else file_size - mdat_box.payload_start,
|
||||
faststart=moov_offset < mdat_box.start,
|
||||
codec=codec,
|
||||
timescale=mdhd_timescale,
|
||||
duration=mdhd_duration or mvhd_duration,
|
||||
track_id=tkhd["track_id"],
|
||||
width=tkhd["width"],
|
||||
height=tkhd["height"],
|
||||
stsd_body=stsd_body,
|
||||
sample_pts=sample_pts,
|
||||
sample_durations=sample_durations,
|
||||
sample_sizes=sample_sizes,
|
||||
sample_offsets=sample_offsets,
|
||||
sync_samples=sync_samples,
|
||||
)
|
||||
|
||||
|
||||
def synthesize_mp4(index: Mp4Index, sample_slice: Mp4SampleSlice, mdat_payload: bytes) -> bytes:
|
||||
lo = sample_slice.sample_lo
|
||||
hi = sample_slice.sample_hi + 1
|
||||
if lo < 0 or hi > len(index.sample_sizes) or lo >= hi:
|
||||
raise ValueError(f"Invalid sample range [{lo}, {hi}) for {index.file_path}")
|
||||
|
||||
offsets = index.sample_offsets[lo:hi]
|
||||
sizes = index.sample_sizes[lo:hi]
|
||||
rel_offsets = offsets - sample_slice.byte_offset
|
||||
if int(rel_offsets.min()) != 0:
|
||||
raise ValueError("Sample slice must start at the minimum referenced sample offset")
|
||||
if int((rel_offsets + sizes).max()) > len(mdat_payload):
|
||||
raise ValueError("Sample slice does not cover all referenced samples")
|
||||
|
||||
durations = index.sample_durations[lo:hi]
|
||||
sync = index.sync_samples[(index.sync_samples >= lo) & (index.sync_samples < hi)] - lo + 1
|
||||
moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=0)
|
||||
header_size = len(index.ftyp) + len(moov)
|
||||
moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=header_size + 8)
|
||||
return index.ftyp + moov + _box(b"mdat", mdat_payload)
|
||||
|
||||
|
||||
def iter_boxes(
|
||||
data: bytes,
|
||||
start: int,
|
||||
end: int,
|
||||
*,
|
||||
absolute_base: int = 0,
|
||||
allow_truncated: bool = False,
|
||||
) -> Iterable[Box]:
|
||||
pos = start
|
||||
while pos + 8 <= end:
|
||||
size = struct.unpack_from(">I", data, pos)[0]
|
||||
typ = data[pos + 4 : pos + 8]
|
||||
header_size = 8
|
||||
if size == 1:
|
||||
if pos + 16 > end:
|
||||
break
|
||||
size = struct.unpack_from(">Q", data, pos + 8)[0]
|
||||
header_size = 16
|
||||
elif size == 0:
|
||||
size = end - pos
|
||||
if size < header_size:
|
||||
break
|
||||
box_end = pos + size
|
||||
if box_end > end and not allow_truncated:
|
||||
break
|
||||
yield Box(typ, absolute_base + pos, header_size, absolute_base + box_end)
|
||||
pos = box_end
|
||||
|
||||
|
||||
def _find_video_trak(moov: bytes) -> tuple[Box, bytes]:
|
||||
for trak in _children(moov, 0, len(moov)):
|
||||
if trak.type != b"trak":
|
||||
continue
|
||||
payload = moov[trak.payload_start : trak.end]
|
||||
hdlr = _find_descendant(payload, [b"mdia", b"hdlr"])
|
||||
if hdlr[8:12] == b"vide":
|
||||
return trak, payload
|
||||
raise ValueError("No video track found")
|
||||
|
||||
|
||||
def _find_descendant(data: bytes, path: list[bytes]) -> bytes:
|
||||
current = data
|
||||
for typ in path:
|
||||
box = _find_child(current, typ)
|
||||
current = current[box.payload_start : box.end]
|
||||
return current
|
||||
|
||||
|
||||
def _find_child(data: bytes, typ: bytes) -> Box:
|
||||
for box in _children(data, 0, len(data)):
|
||||
if box.type == typ:
|
||||
return box
|
||||
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
|
||||
|
||||
|
||||
def _children(data: bytes, start: int, end: int) -> Iterable[Box]:
|
||||
return iter_boxes(data, start, end, absolute_base=0)
|
||||
|
||||
|
||||
def _one(boxes: list[Box], typ: bytes, *, required: bool = True) -> Box | None:
|
||||
matches = [box for box in boxes if box.type == typ]
|
||||
if not matches and required:
|
||||
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
|
||||
return matches[0] if matches else None
|
||||
|
||||
|
||||
def _payload(parent: bytes, typ: bytes) -> bytes:
|
||||
box = _find_child(parent, typ)
|
||||
return parent[box.payload_start : box.end]
|
||||
|
||||
|
||||
def _parse_mvhd(payload: bytes) -> tuple[int, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
return struct.unpack_from(">IQ", payload, 20)
|
||||
return struct.unpack_from(">II", payload, 12)
|
||||
|
||||
|
||||
def _parse_mdhd(payload: bytes) -> tuple[int, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
return struct.unpack_from(">IQ", payload, 20)
|
||||
return struct.unpack_from(">II", payload, 12)
|
||||
|
||||
|
||||
def _parse_tkhd(payload: bytes) -> dict[str, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
track_id = struct.unpack_from(">I", payload, 20)[0]
|
||||
duration = struct.unpack_from(">Q", payload, 28)[0]
|
||||
width, height = struct.unpack_from(">II", payload, 88)
|
||||
else:
|
||||
track_id = struct.unpack_from(">I", payload, 12)[0]
|
||||
duration = struct.unpack_from(">I", payload, 20)[0]
|
||||
width, height = struct.unpack_from(">II", payload, 76)
|
||||
return {"track_id": track_id, "duration": duration, "width": width >> 16, "height": height >> 16}
|
||||
|
||||
|
||||
def _parse_stsd_codec(stsd_body: bytes) -> str:
|
||||
if len(stsd_body) < 16:
|
||||
return "unknown"
|
||||
return stsd_body[12:16].decode("latin1")
|
||||
|
||||
|
||||
def _parse_stts(payload: bytes) -> list[tuple[int, int]]:
|
||||
count = struct.unpack_from(">I", payload, 4)[0]
|
||||
out = []
|
||||
offset = 8
|
||||
for _ in range(count):
|
||||
out.append(struct.unpack_from(">II", payload, offset))
|
||||
offset += 8
|
||||
return out
|
||||
|
||||
|
||||
def _expand_stts(entries: list[tuple[int, int]], sample_count: int) -> np.ndarray:
|
||||
values = np.empty(sample_count, dtype=np.int64)
|
||||
pos = 0
|
||||
for count, delta in entries:
|
||||
values[pos : pos + count] = delta
|
||||
pos += count
|
||||
if pos != sample_count:
|
||||
raise ValueError(f"stts describes {pos} samples, stsz describes {sample_count}")
|
||||
return values
|
||||
|
||||
|
||||
def _parse_stsz(payload: bytes) -> np.ndarray:
|
||||
sample_size, sample_count = struct.unpack_from(">II", payload, 4)
|
||||
if sample_size:
|
||||
return np.full(sample_count, sample_size, dtype=np.int64)
|
||||
offset = 12
|
||||
values = np.empty(sample_count, dtype=np.int64)
|
||||
for idx in range(sample_count):
|
||||
values[idx] = struct.unpack_from(">I", payload, offset)[0]
|
||||
offset += 4
|
||||
return values
|
||||
|
||||
|
||||
def _parse_stsc(payload: bytes) -> list[tuple[int, int, int]]:
|
||||
count = struct.unpack_from(">I", payload, 4)[0]
|
||||
out = []
|
||||
offset = 8
|
||||
for _ in range(count):
|
||||
out.append(struct.unpack_from(">III", payload, offset))
|
||||
offset += 12
|
||||
return out
|
||||
|
||||
|
||||
def _parse_chunk_offsets(stbl: bytes) -> np.ndarray:
|
||||
with_stco = None
|
||||
with_co64 = None
|
||||
for box in _children(stbl, 0, len(stbl)):
|
||||
if box.type == b"stco":
|
||||
with_stco = stbl[box.payload_start : box.end]
|
||||
elif box.type == b"co64":
|
||||
with_co64 = stbl[box.payload_start : box.end]
|
||||
if with_co64 is not None:
|
||||
count = struct.unpack_from(">I", with_co64, 4)[0]
|
||||
return np.array(
|
||||
[struct.unpack_from(">Q", with_co64, 8 + idx * 8)[0] for idx in range(count)], dtype=np.int64
|
||||
)
|
||||
if with_stco is None:
|
||||
raise ValueError("Missing stco/co64 chunk offsets")
|
||||
count = struct.unpack_from(">I", with_stco, 4)[0]
|
||||
return np.array(
|
||||
[struct.unpack_from(">I", with_stco, 8 + idx * 4)[0] for idx in range(count)], dtype=np.int64
|
||||
)
|
||||
|
||||
|
||||
def _parse_stss(stbl: bytes, sample_count: int) -> np.ndarray:
|
||||
for box in _children(stbl, 0, len(stbl)):
|
||||
if box.type == b"stss":
|
||||
payload = stbl[box.payload_start : box.end]
|
||||
count = struct.unpack_from(">I", payload, 4)[0]
|
||||
return np.array(
|
||||
[struct.unpack_from(">I", payload, 8 + idx * 4)[0] - 1 for idx in range(count)],
|
||||
dtype=np.int64,
|
||||
)
|
||||
return np.arange(sample_count, dtype=np.int64)
|
||||
|
||||
|
||||
def _sample_offsets(
|
||||
stsc: list[tuple[int, int, int]], chunk_offsets: np.ndarray, sample_sizes: np.ndarray
|
||||
) -> np.ndarray:
|
||||
if not stsc:
|
||||
raise ValueError("stsc is empty")
|
||||
offsets = np.empty(len(sample_sizes), dtype=np.int64)
|
||||
sample_idx = 0
|
||||
for entry_idx, (first_chunk, samples_per_chunk, _desc_idx) in enumerate(stsc):
|
||||
next_first = stsc[entry_idx + 1][0] if entry_idx + 1 < len(stsc) else len(chunk_offsets) + 1
|
||||
for chunk_number in range(first_chunk, next_first):
|
||||
if chunk_number < 1 or chunk_number > len(chunk_offsets):
|
||||
raise ValueError("stsc references a chunk outside stco/co64")
|
||||
chunk_pos = int(chunk_offsets[chunk_number - 1])
|
||||
for _ in range(samples_per_chunk):
|
||||
if sample_idx >= len(sample_sizes):
|
||||
return offsets
|
||||
offsets[sample_idx] = chunk_pos
|
||||
chunk_pos += int(sample_sizes[sample_idx])
|
||||
sample_idx += 1
|
||||
if sample_idx != len(sample_sizes):
|
||||
raise ValueError(f"stsc describes {sample_idx} samples, stsz describes {len(sample_sizes)}")
|
||||
return offsets
|
||||
|
||||
|
||||
def _make_moov(
|
||||
index: Mp4Index,
|
||||
durations: np.ndarray,
|
||||
sizes: np.ndarray,
|
||||
rel_offsets: np.ndarray,
|
||||
sync_samples: np.ndarray,
|
||||
*,
|
||||
mdat_data_offset: int,
|
||||
) -> bytes:
|
||||
duration = int(durations.sum())
|
||||
stco_values = [int(mdat_data_offset + value) for value in rel_offsets]
|
||||
if any(value > 0xFFFFFFFF for value in stco_values):
|
||||
offset_box = _co64(stco_values)
|
||||
else:
|
||||
offset_box = _stco(stco_values)
|
||||
stbl = _box(
|
||||
b"stbl",
|
||||
_box(b"stsd", index.stsd_body)
|
||||
+ _stts(durations)
|
||||
+ _stsc_one_sample_per_chunk(len(sizes))
|
||||
+ _stsz(sizes)
|
||||
+ offset_box
|
||||
+ (_stss(sync_samples) if len(sync_samples) else b""),
|
||||
)
|
||||
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
|
||||
mdia = _box(b"mdia", _mdhd(index.timescale, duration) + _hdlr() + minf)
|
||||
trak = _box(b"trak", _tkhd(index.track_id, duration, index.width, index.height) + mdia)
|
||||
return _box(b"moov", _mvhd(index.timescale, duration, index.track_id + 1) + trak)
|
||||
|
||||
|
||||
def _full_box(typ: bytes, version: int, flags: int, payload: bytes = b"") -> bytes:
|
||||
return _box(typ, bytes([version]) + flags.to_bytes(3, "big") + payload)
|
||||
|
||||
|
||||
def _box(typ: bytes, payload: bytes) -> bytes:
|
||||
size = len(payload) + 8
|
||||
if size <= 0xFFFFFFFF:
|
||||
return struct.pack(">I4s", size, typ) + payload
|
||||
return struct.pack(">I4sQ", 1, typ, size + 8) + payload
|
||||
|
||||
|
||||
def _mvhd(timescale: int, duration: int, next_track_id: int) -> bytes:
|
||||
matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
|
||||
payload = (
|
||||
struct.pack(">IIII", 0, 0, timescale, duration)
|
||||
+ struct.pack(">IHH", 0x00010000, 0x0100, 0)
|
||||
+ b"\0" * 8
|
||||
+ matrix
|
||||
+ b"\0" * 24
|
||||
+ struct.pack(">I", next_track_id)
|
||||
)
|
||||
return _full_box(b"mvhd", 0, 0, payload)
|
||||
|
||||
|
||||
def _tkhd(track_id: int, duration: int, width: int, height: int) -> bytes:
|
||||
matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
|
||||
payload = (
|
||||
struct.pack(">IIIII", 0, 0, track_id, 0, duration)
|
||||
+ b"\0" * 8
|
||||
+ struct.pack(">hhhh", 0, 0, 0, 0)
|
||||
+ matrix
|
||||
+ struct.pack(">II", width << 16, height << 16)
|
||||
)
|
||||
return _full_box(b"tkhd", 0, 7, payload)
|
||||
|
||||
|
||||
def _mdhd(timescale: int, duration: int) -> bytes:
|
||||
return _full_box(b"mdhd", 0, 0, struct.pack(">IIIIH", 0, 0, timescale, duration, 0x55C4) + b"\0\0")
|
||||
|
||||
|
||||
def _hdlr() -> bytes:
|
||||
return _full_box(b"hdlr", 0, 0, b"\0" * 4 + b"vide" + b"\0" * 12 + b"VideoHandler\0")
|
||||
|
||||
|
||||
def _vmhd() -> bytes:
|
||||
return _full_box(b"vmhd", 0, 1, struct.pack(">HHHH", 0, 0, 0, 0))
|
||||
|
||||
|
||||
def _dinf() -> bytes:
|
||||
url = _full_box(b"url ", 0, 1)
|
||||
dref = _full_box(b"dref", 0, 0, struct.pack(">I", 1) + url)
|
||||
return _box(b"dinf", dref)
|
||||
|
||||
|
||||
def _stts(durations: np.ndarray) -> bytes:
|
||||
runs = []
|
||||
for duration in durations.tolist():
|
||||
if runs and runs[-1][1] == int(duration):
|
||||
runs[-1][0] += 1
|
||||
else:
|
||||
runs.append([1, int(duration)])
|
||||
payload = struct.pack(">I", len(runs)) + b"".join(
|
||||
struct.pack(">II", count, delta) for count, delta in runs
|
||||
)
|
||||
return _full_box(b"stts", 0, 0, payload)
|
||||
|
||||
|
||||
def _stsc_one_sample_per_chunk(sample_count: int) -> bytes:
|
||||
return _full_box(b"stsc", 0, 0, struct.pack(">IIII", 1, 1, 1, 1))
|
||||
|
||||
|
||||
def _stsz(sizes: np.ndarray) -> bytes:
|
||||
return _full_box(
|
||||
b"stsz",
|
||||
0,
|
||||
0,
|
||||
struct.pack(">II", 0, len(sizes)) + b"".join(struct.pack(">I", int(size)) for size in sizes.tolist()),
|
||||
)
|
||||
|
||||
|
||||
def _stco(values: list[int]) -> bytes:
|
||||
return _full_box(
|
||||
b"stco", 0, 0, struct.pack(">I", len(values)) + b"".join(struct.pack(">I", v) for v in values)
|
||||
)
|
||||
|
||||
|
||||
def _co64(values: list[int]) -> bytes:
|
||||
return _full_box(
|
||||
b"co64", 0, 0, struct.pack(">I", len(values)) + b"".join(struct.pack(">Q", v) for v in values)
|
||||
)
|
||||
|
||||
|
||||
def _stss(values: np.ndarray) -> bytes:
|
||||
return _full_box(
|
||||
b"stss",
|
||||
0,
|
||||
0,
|
||||
struct.pack(">I", len(values)) + b"".join(struct.pack(">I", int(value)) for value in values.tolist()),
|
||||
)
|
||||
@@ -24,6 +24,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,6 +32,22 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_pix_fmt_channels(pix_fmt: str) -> int:
|
||||
"""Return the number of components (channels) for *pix_fmt*."""
|
||||
return len(av.VideoFormat(pix_fmt).components)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
@@ -92,7 +109,7 @@ def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Opti
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = value
|
||||
num_val = float(value)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
@@ -142,6 +159,16 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
|
||||
"""Ensure *pix_fmt* can carry at least *channels* components."""
|
||||
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
|
||||
if pix_fmt_channels < channels:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
|
||||
f"but the source data has {channels} channel(s)."
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
@@ -156,12 +183,18 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
_check_option_value(vcodec, key, value, supported_options[key])
|
||||
|
||||
|
||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
||||
def check_video_encoder_parameters_pyav(
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, Any],
|
||||
channels: int | None = None,
|
||||
) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
@@ -171,4 +204,6 @@ def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options
|
||||
if not options:
|
||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||
_check_pixel_format(vcodec, pix_fmt)
|
||||
if channels is not None:
|
||||
_check_pix_fmt_channels(pix_fmt, channels)
|
||||
_check_codec_options(vcodec, codec_options)
|
||||
|
||||
@@ -87,11 +87,14 @@ DATA_DIR = "data"
|
||||
VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
IMAGE_FILE_PATTERN = "frame-{frame_index:06d}.png"
|
||||
DEPTH_FILE_PATTERN = "frame-{frame_index:06d}.tiff"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/" + IMAGE_FILE_PATTERN
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/" + DEPTH_FILE_PATTERN
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
|
||||
@@ -39,11 +39,16 @@ from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import quantize_depth
|
||||
from .pyav_utils import get_pix_fmt_channels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -53,6 +58,7 @@ def decode_video_frames(
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
@@ -64,23 +70,35 @@ def decode_video_frames(
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
|
||||
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
|
||||
accepted for one release as an alias for "pyav" and will be removed in a future version.
|
||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||
return_uint8 (bool): For RGB videos, if True return raw uint8 frames without float32 normalization.
|
||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||
is_depth (bool): Set to True if the video is a depth map (1 channel, uint12).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
|
||||
torch.Tensor: Decoded frames (RGB: float32 in [0,1] by default, or uint8 if return_uint8=True, Depth: uint12).
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend != "pyav" and is_depth:
|
||||
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
|
||||
# We do not actually return uint8 here, but we avoid the 255 normalization step.
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True
|
||||
)
|
||||
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
|
||||
)
|
||||
elif backend == "video_reader":
|
||||
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
@@ -91,6 +109,7 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: float,
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
@@ -109,8 +128,9 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||
decoded frame.
|
||||
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
|
||||
[0, 1] range.
|
||||
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
|
||||
Otherwise, return float32 in [0, 1] range.
|
||||
is_depth: Set to True if the video is a depth map (1 channel, uint12).
|
||||
|
||||
Returns:
|
||||
torch.Tensor of shape (len(timestamps), C, H, W).
|
||||
@@ -140,9 +160,13 @@ def decode_video_frames_pyav(
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
if is_depth:
|
||||
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
|
||||
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
|
||||
else:
|
||||
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
@@ -185,7 +209,7 @@ def decode_video_frames_pyav(
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
|
||||
if return_uint8:
|
||||
if return_uint8 or is_depth:
|
||||
return closest_frames
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
@@ -406,17 +430,38 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
"""Encode a directory of image frames into an MP4 video.
|
||||
|
||||
When ``video_encoder`` is a :class:`~lerobot.configs.video.DepthEncoderConfig`,
|
||||
frames are read from ``.tiff`` files and quantized to 12-bit depth codes using the
|
||||
encoder's ``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``; otherwise ``.png``
|
||||
RGB frames are encoded directly.
|
||||
|
||||
Args:
|
||||
imgs_dir: Directory containing the frames to encode, named ``frame-000000``
|
||||
onwards (``.png`` for RGB, ``.tiff`` for depth).
|
||||
video_path: Output path for the encoded ``.mp4`` file.
|
||||
fps: Frame rate of the output video.
|
||||
video_encoder: Encoder settings (codec, pixel format, quality, ...). When
|
||||
``None``, :func:`camera_encoder_defaults` is used. Pass a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig` to encode depth frames.
|
||||
encoder_threads: Per-encoder thread count forwarded to the codec. ``None``
|
||||
lets the codec decide.
|
||||
log_level: libav log level to set while encoding, or ``None`` to leave the
|
||||
current logging configuration unchanged.
|
||||
overwrite: When ``False`` and ``video_path`` already exists, skip encoding and
|
||||
log a warning. When ``True``, re-encode and replace the existing file.
|
||||
"""
|
||||
if video_encoder is None:
|
||||
video_encoder = camera_encoder_defaults()
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -428,17 +473,19 @@ def encode_video_frames(
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
suffix = ".png" if not is_depth else ".tiff"
|
||||
template = "frame-" + ("[0-9]" * 6) + suffix
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
raise FileNotFoundError(f"No images with suffix {suffix} found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -455,8 +502,19 @@ def encode_video_frames(
|
||||
# Loop through input frames and encode them
|
||||
for input_data in input_list:
|
||||
with Image.open(input_data) as input_image:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
if is_depth:
|
||||
input_frame = quantize_depth(
|
||||
np.array(input_image),
|
||||
depth_min=video_encoder.depth_min,
|
||||
depth_max=video_encoder.depth_max,
|
||||
shift=video_encoder.shift,
|
||||
use_log=video_encoder.use_log,
|
||||
pix_fmt=video_encoder.pix_fmt,
|
||||
video_backend="pyav",
|
||||
)
|
||||
else:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
@@ -477,7 +535,7 @@ def encode_video_frames(
|
||||
def reencode_video(
|
||||
input_video_path: Path | str,
|
||||
output_video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
@@ -489,7 +547,7 @@ def reencode_video(
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
output_video_path: Path for the re-encoded file.
|
||||
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
video_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
@@ -497,7 +555,7 @@ def reencode_video(
|
||||
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
video_encoder = video_encoder or camera_encoder_defaults()
|
||||
|
||||
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
|
||||
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
|
||||
@@ -512,9 +570,9 @@ def reencode_video(
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
@@ -696,22 +754,21 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, str],
|
||||
video_encoder: VideoEncoderConfig,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec_options = codec_options
|
||||
self.video_encoder = video_encoder
|
||||
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -736,12 +793,12 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
if not self.is_depth and frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
@@ -749,15 +806,29 @@ class _CameraEncoderThread(threading.Thread):
|
||||
height, width = frame_data.shape[:2]
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream = container.add_stream(
|
||||
self.video_encoder.vcodec,
|
||||
self.fps,
|
||||
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
|
||||
)
|
||||
output_stream.pix_fmt = self.video_encoder.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
if not self.is_depth:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
else:
|
||||
video_frame = quantize_depth(
|
||||
frame_data,
|
||||
depth_min=self.video_encoder.depth_min,
|
||||
depth_max=self.video_encoder.depth_max,
|
||||
shift=self.video_encoder.shift,
|
||||
use_log=self.video_encoder.use_log,
|
||||
video_backend=self.video_encoder.video_backend,
|
||||
)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
@@ -816,21 +887,26 @@ class StreamingVideoEncoder:
|
||||
self,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
camera_encoder: Video encoder settings applied to all RGB cameras.
|
||||
When ``None``, :func:`camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
depth_encoder: Video encoder settings applied to all depth cameras,
|
||||
including the depth quantization parameters. When ``None``,
|
||||
:func:`depth_encoder_defaults` is used.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
"""
|
||||
self.fps = fps
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
|
||||
@@ -843,18 +919,25 @@ class StreamingVideoEncoder:
|
||||
self._episode_active = False
|
||||
self._closed = False
|
||||
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
def start_episode(
|
||||
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
|
||||
) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
depth_video_keys: List of video or image feature keys that carry depth maps (e.g.
|
||||
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
|
||||
"""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
self._dropped_frames.clear()
|
||||
|
||||
if depth_video_keys is None:
|
||||
depth_video_keys = []
|
||||
|
||||
for video_key in video_keys:
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
@@ -863,17 +946,15 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
vcodec = self._camera_encoder.vcodec
|
||||
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
video_encoder=encoder,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self._encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -1080,15 +1161,23 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
not already populated from the video file itself. When a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig` is passed, the depth
|
||||
quantization parameters (``depth_min`` / ``depth_max`` / ``shift`` /
|
||||
``use_log``) are recorded so frames can be dequantized on read.
|
||||
|
||||
Returns:
|
||||
The ``video.*`` / ``audio.*`` info dict, including ``is_depth_map`` which is
|
||||
``True`` only when ``video_encoder`` is a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig`.
|
||||
"""
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
@@ -1106,13 +1195,10 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
|
||||
video_info["video.channels"] = pixel_channels
|
||||
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
@@ -1121,27 +1207,18 @@ def get_video_info(
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
if video_encoder is not None:
|
||||
for field_name, field_value in asdict(video_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
@@ -1202,10 +1279,13 @@ class VideoEncodingManager:
|
||||
img_dir = self.dataset.root / "images"
|
||||
if img_dir.exists():
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
tiff_files = list(img_dir.rglob("*.tiff"))
|
||||
if len(png_files) == 0 and len(tiff_files) == 0:
|
||||
shutil.rmtree(img_dir)
|
||||
logger.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
logger.debug(
|
||||
f"Images directory is not empty, containing {len(png_files)} PNG and {len(tiff_files)} TIFF files"
|
||||
)
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -126,7 +126,8 @@ def prepare_observation_for_inference(
|
||||
for name in observation:
|
||||
observation[name] = torch.from_numpy(observation[name])
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
if observation[name].dtype == torch.uint8:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
@@ -68,9 +68,12 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
features: dict[str, tuple] = {}
|
||||
for cam in self.cameras:
|
||||
features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3)
|
||||
if getattr(self.cameras[cam], "use_depth", False):
|
||||
features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width, 1)
|
||||
return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -190,6 +193,12 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
if getattr(cam, "use_depth", False):
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -333,6 +333,7 @@ def build_rollout_context(
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
@@ -368,6 +369,7 @@ def build_rollout_context(
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
|
||||
@@ -133,6 +133,15 @@ Convert image dataset to video format and save locally:
|
||||
--new_root /path/to/output/pusht_video \
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert image dataset (with depth maps) to video format, customizing the depth encoder:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_root /path/to/output/pusht_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.depth_encoder.depth_min 0.01 \
|
||||
--operation.depth_encoder.depth_max 10.0 \
|
||||
--operation.depth_encoder.use_log true
|
||||
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -211,6 +220,13 @@ Re-encode videos in-place (overwrites original dataset):
|
||||
--operation.camera_encoder.vcodec h264 \
|
||||
--operation.overwrite true
|
||||
|
||||
Re-encode both RGB and depth videos in a dataset (depth quantization params are preserved):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_depth \
|
||||
--operation.type reencode_videos \
|
||||
--operation.camera_encoder.vcodec libx264 \
|
||||
--operation.depth_encoder.vcodec ffv1
|
||||
|
||||
Using JSON config file:
|
||||
lerobot-edit-dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
@@ -225,7 +241,13 @@ from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, parser
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
parser,
|
||||
)
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
convert_image_to_video_dataset,
|
||||
@@ -288,6 +310,7 @@ class ModifyTasksConfig(OperationConfig):
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
output_dir: str | None = None
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
max_episodes_per_batch: int | None = None
|
||||
@@ -309,6 +332,7 @@ class RecomputeStatsConfig(OperationConfig):
|
||||
@dataclass
|
||||
class ReencodeVideosConfig(OperationConfig):
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
num_workers: int = 0
|
||||
encoder_threads: int | None = None
|
||||
overwrite: bool = False
|
||||
@@ -602,6 +626,7 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
camera_encoder=getattr(cfg.operation, "camera_encoder", None) or camera_encoder_defaults(),
|
||||
depth_encoder=getattr(cfg.operation, "depth_encoder", None) or depth_encoder_defaults(),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
|
||||
@@ -719,10 +744,14 @@ def handle_reencode_videos(cfg: EditDatasetConfig) -> None:
|
||||
shutil.copytree(input_root, output_root)
|
||||
dataset = LeRobotDataset(output_repo_id, root=output_root)
|
||||
|
||||
logging.info(f"Re-encoding videos in {output_repo_id} with {cfg.operation.camera_encoder}")
|
||||
logging.info(
|
||||
f"Re-encoding videos in {output_repo_id} with RGB encoder {cfg.operation.camera_encoder} "
|
||||
f"and depth encoder {cfg.operation.depth_encoder}"
|
||||
)
|
||||
reencode_dataset(
|
||||
dataset,
|
||||
camera_encoder=cfg.operation.camera_encoder,
|
||||
depth_encoder=cfg.operation.depth_encoder,
|
||||
encoder_threads=cfg.operation.encoder_threads,
|
||||
num_workers=cfg.operation.num_workers,
|
||||
)
|
||||
|
||||
@@ -404,6 +404,7 @@ def record(
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
@@ -433,6 +434,7 @@ def record(
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
|
||||
@@ -51,7 +51,9 @@ def hw_to_dataset_features(
|
||||
|
||||
This function takes a dictionary describing hardware outputs (like joint states
|
||||
or camera image shapes) and formats it into the standard LeRobot feature
|
||||
specification.
|
||||
specification. Single-channel cameras (shape ``(H, W, 1)``) are flagged as depth
|
||||
maps via ``info["is_depth_map"] = True``; three-channel cameras ``(H, W, 3)`` are
|
||||
treated as RGB.
|
||||
|
||||
Args:
|
||||
hw_features (dict): Dictionary mapping feature names to their type (float for
|
||||
@@ -61,7 +63,7 @@ def hw_to_dataset_features(
|
||||
use_video (bool): If True, image features are marked as "video", otherwise "image".
|
||||
|
||||
Returns:
|
||||
dict: A LeRobot features dictionary.
|
||||
dict: A LeRobot features dictionary. Depth cameras carry ``info["is_depth_map"] = True``.
|
||||
"""
|
||||
features = {}
|
||||
joint_fts = {
|
||||
@@ -69,6 +71,7 @@ def hw_to_dataset_features(
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
# TODO(CarolinePascal): we should not rely on the shape to determine if a feature is a camera !
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == ACTION:
|
||||
@@ -86,11 +89,19 @@ def hw_to_dataset_features(
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
dtype = "video" if use_video else "image"
|
||||
if len(shape) == 3 and shape[2] in (1, 3):
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": dtype,
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"is_depth_map": shape[2] == 1},
|
||||
}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Camera feature '{key}' has shape {shape}. "
|
||||
f"Expected a 3-tuple (H, W, C), e.g. (480, 640, 3) for RGB or (480, 640, 1) for depth."
|
||||
)
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
@@ -149,11 +160,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
else:
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
|
||||
@@ -107,7 +107,14 @@ def log_rerun_data(
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
if arr.shape[-1] == 1:
|
||||
img_entity = (
|
||||
rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis).compress()
|
||||
if compress_images
|
||||
else rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
|
||||
)
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity, static=True)
|
||||
|
||||
if action:
|
||||
|
||||
@@ -208,14 +208,14 @@ def test_episode_clip_path_trims_via_reencode_video(tmp_path: Path, monkeypatch)
|
||||
def fake_reencode(
|
||||
input_video_path,
|
||||
output_video_path,
|
||||
camera_encoder=None,
|
||||
video_encoder=None,
|
||||
overwrite=False,
|
||||
start_time_s=None,
|
||||
end_time_s=None,
|
||||
):
|
||||
captured.update(
|
||||
src=Path(input_video_path),
|
||||
encoder=camera_encoder,
|
||||
encoder=video_encoder,
|
||||
start_time_s=start_time_s,
|
||||
end_time_s=end_time_s,
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pandas as pd # noqa: E402
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||
@@ -344,6 +345,78 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||
assert len(dataset) == 24
|
||||
|
||||
|
||||
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
|
||||
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
|
||||
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
|
||||
from lerobot.datasets.io_utils import write_tasks
|
||||
from lerobot.utils.io_utils import write_json
|
||||
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
|
||||
for ep, length in enumerate(episode_lengths):
|
||||
episode_index += [ep] * length
|
||||
frame_index += list(range(length))
|
||||
timestamp += [round(i / fps, 6) for i in range(length)]
|
||||
task_index += [0] * length
|
||||
subtask_index += [0] * length # legacy column the writer must drop
|
||||
pd.DataFrame(
|
||||
{
|
||||
"episode_index": episode_index,
|
||||
"frame_index": frame_index,
|
||||
"timestamp": timestamp,
|
||||
"task_index": task_index,
|
||||
"subtask_index": subtask_index,
|
||||
}
|
||||
).to_parquet(data_dir / "file-000.parquet", index=False)
|
||||
|
||||
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
|
||||
write_tasks(tasks_df, root)
|
||||
write_json(
|
||||
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
|
||||
root / "meta" / "info.json",
|
||||
)
|
||||
return root
|
||||
|
||||
|
||||
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
|
||||
"""Rewriting a packed shard must keep one row group per episode, not collapse
|
||||
every episode into a single giant row group."""
|
||||
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
|
||||
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
|
||||
shard = root / "data" / "chunk-000" / "file-000.parquet"
|
||||
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
|
||||
|
||||
staging_dir = tmp_path / "stage"
|
||||
for ep in range(len(episode_lengths)):
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
ep,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"subtask for ep {ep}",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
records = list(iter_episodes(root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, root)
|
||||
|
||||
# One row group per episode, with row counts matching the episode lengths.
|
||||
md = pq.ParquetFile(shard).metadata
|
||||
assert md.num_row_groups == len(episode_lengths)
|
||||
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
|
||||
# Language columns are still present after the per-episode rewrite.
|
||||
table = pq.read_table(shard)
|
||||
assert "language_persistent" in table.column_names
|
||||
assert "language_events" in table.column_names
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||
assert atom["role"] == "assistant"
|
||||
|
||||
@@ -29,7 +29,30 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.feature_utils import features_equal_for_merge
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import (
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
|
||||
|
||||
def assert_data_shards_one_row_group_per_episode(root):
|
||||
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
shards = sorted((root / "data").rglob("*.parquet"))
|
||||
assert shards, f"no data shards found under {root}/data"
|
||||
n_episodes = 0
|
||||
for shard in shards:
|
||||
pf = pq.ParquetFile(shard)
|
||||
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
|
||||
assert pf.metadata.num_row_groups == len(set(episodes)), shard
|
||||
for i in range(pf.metadata.num_row_groups):
|
||||
rg_episodes = set(
|
||||
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
|
||||
)
|
||||
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
|
||||
n_episodes += len(set(episodes))
|
||||
return n_episodes
|
||||
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||
@@ -191,6 +214,26 @@ def assert_dataset_iteration_works(aggr_ds):
|
||||
pass
|
||||
|
||||
|
||||
def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
|
||||
"""Test that depth keys are correctly preserved after aggregation.
|
||||
|
||||
Ensures that the ``is_depth_map`` marker on visual features survives
|
||||
aggregation, so that downstream consumers (e.g. the dataset reader's
|
||||
depth decoding path) keep working on the merged dataset.
|
||||
"""
|
||||
expected_depth_keys = set(ds_0.meta.depth_keys)
|
||||
assert expected_depth_keys == set(ds_1.meta.depth_keys), (
|
||||
"Source datasets disagree on depth_keys; test setup is inconsistent"
|
||||
)
|
||||
actual_depth_keys = set(aggr_ds.meta.depth_keys)
|
||||
assert actual_depth_keys == expected_depth_keys, (
|
||||
f"Expected depth_keys {expected_depth_keys}, got {actual_depth_keys}"
|
||||
)
|
||||
for key in expected_depth_keys:
|
||||
info = aggr_ds.meta.info.features[key].get("info") or {}
|
||||
assert info.get("is_depth_map") is True, f"Depth marker lost on feature {key!r} after aggregation"
|
||||
|
||||
|
||||
def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
"""Test that all video timestamps are within valid bounds for their respective video files.
|
||||
|
||||
@@ -240,7 +283,11 @@ def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test basic aggregation functionality with standard parameters."""
|
||||
"""Test basic aggregation functionality with standard parameters.
|
||||
|
||||
Source datasets include both RGB and depth video features so the same
|
||||
aggregation flow is exercised on the ``is_depth_map`` branch.
|
||||
"""
|
||||
ds_0_num_frames = 400
|
||||
ds_1_num_frames = 800
|
||||
ds_0_num_episodes = 10
|
||||
@@ -252,14 +299,21 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
# Confirm depth was actually wired into the source datasets so the
|
||||
# rest of the assertions exercise the depth aggregation path.
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
@@ -286,6 +340,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
@@ -403,7 +458,11 @@ def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
|
||||
|
||||
|
||||
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding."""
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding.
|
||||
|
||||
Depth video features are included to verify that file rotation/concat
|
||||
correctly handles depth-marked features alongside regular RGB ones.
|
||||
"""
|
||||
ds_0_num_episodes = ds_1_num_episodes = 10
|
||||
ds_0_num_frames = ds_1_num_frames = 400
|
||||
|
||||
@@ -412,14 +471,19 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_small_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "small_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_small_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
# Use the new configurable parameters to force file rotation
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
@@ -450,6 +514,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
# Check that multiple files were actually created due to small size limits
|
||||
@@ -469,7 +534,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
"""Regression test for video timestamp bug when merging datasets.
|
||||
|
||||
This test specifically checks that video timestamps are correctly calculated
|
||||
and accumulated when merging multiple datasets.
|
||||
and accumulated when merging multiple datasets. Depth video features are
|
||||
included so depth timestamps are also covered by the regression.
|
||||
"""
|
||||
datasets = []
|
||||
for i in range(3):
|
||||
@@ -478,9 +544,13 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
|
||||
total_episodes=2,
|
||||
total_frames=100,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
datasets.append(ds)
|
||||
|
||||
for i, ds in enumerate(datasets):
|
||||
assert len(ds.meta.depth_keys) > 0, f"Dataset {i} should expose at least one depth key"
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds.repo_id for ds in datasets],
|
||||
roots=[ds.root for ds in datasets],
|
||||
@@ -497,12 +567,21 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
|
||||
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
# Depth keys must survive the merge for the regression to cover the
|
||||
# ``is_depth_map`` decoding branch.
|
||||
assert set(aggr_ds.meta.depth_keys) == set(datasets[0].meta.depth_keys)
|
||||
|
||||
depth_keys = set(aggr_ds.meta.depth_keys)
|
||||
for i in range(len(aggr_ds)):
|
||||
item = aggr_ds[i]
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert key in item, f"Video key {key} missing from item {i}"
|
||||
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||
# Depth frames are single-channel (1, H, W) after dequantization;
|
||||
# standard RGB frames keep the 3-channel layout.
|
||||
expected_channels = 1 if key in depth_keys else 3
|
||||
assert item[key].shape[0] == expected_channels, (
|
||||
f"Expected {expected_channels} channels for video key {key}, got {item[key].shape}"
|
||||
)
|
||||
|
||||
|
||||
def assert_image_schema_preserved(aggr_ds):
|
||||
@@ -566,6 +645,41 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
|
||||
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
|
||||
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
|
||||
|
||||
Covers both the non-image (``df.to_parquet``) and image
|
||||
(``to_parquet_with_hf_images``) write branches, including the merge-into-
|
||||
existing-file branch via a low file-size threshold that forces packing.
|
||||
"""
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "rg_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_rg_0",
|
||||
total_episodes=3,
|
||||
total_frames=60,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "rg_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_rg_1",
|
||||
total_episodes=4,
|
||||
total_frames=80,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
|
||||
aggr_root = tmp_path / "rg_aggr"
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
|
||||
aggr_root=aggr_root,
|
||||
)
|
||||
|
||||
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
|
||||
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
|
||||
|
||||
|
||||
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||
|
||||
@@ -584,25 +698,31 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
ds_0_num_episodes = 2
|
||||
ds_1_num_episodes = 3
|
||||
|
||||
# Create two image-based datasets (use_videos=False)
|
||||
# Create two image-based datasets (use_videos=False) with a mix of RGB
|
||||
# and depth-marked cameras so the depth path is exercised in image mode.
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
use_videos=False,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
use_videos=False,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
# Verify source datasets have image keys
|
||||
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
|
||||
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
|
||||
# And that the depth marker actually made it onto an image feature.
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
# Aggregate the datasets
|
||||
aggregate_datasets(
|
||||
@@ -637,6 +757,7 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
# Image-specific assertions
|
||||
assert_image_schema_preserved(aggr_ds)
|
||||
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Verify images can be accessed and have correct shape
|
||||
sample_item = aggr_ds[0]
|
||||
|
||||
@@ -59,11 +59,13 @@ def _make_dummy_stats(features: dict) -> dict:
|
||||
stats = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] in ("image", "video"):
|
||||
channels = ft["shape"][-1]
|
||||
stat_shape = (channels, 1, 1)
|
||||
stats[key] = {
|
||||
"max": np.ones((3, 1, 1), dtype=np.float32),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
|
||||
"min": np.zeros((3, 1, 1), dtype=np.float32),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
|
||||
"max": np.ones(stat_shape, dtype=np.float32),
|
||||
"mean": np.full(stat_shape, 0.5, dtype=np.float32),
|
||||
"min": np.zeros(stat_shape, dtype=np.float32),
|
||||
"std": np.full(stat_shape, 0.25, dtype=np.float32),
|
||||
"count": np.array([5]),
|
||||
}
|
||||
elif ft["dtype"] in ("float32", "float64", "int64"):
|
||||
@@ -142,6 +144,45 @@ def test_create_without_videos_has_no_video_path(tmp_path):
|
||||
assert meta.video_keys == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("marker_field", "marker_key"),
|
||||
[
|
||||
("info", "is_depth_map"),
|
||||
("info", "video.is_depth_map"),
|
||||
("video_info", "video.is_depth_map"),
|
||||
],
|
||||
ids=["info.is_depth_map", "info.video.is_depth_map_legacy", "video_info.video.is_depth_map_legacy"],
|
||||
)
|
||||
def test_depth_keys_property_filters_by_marker(tmp_path, marker_field, marker_key):
|
||||
"""``depth_keys`` recognises the canonical and the two legacy marker variants."""
|
||||
depth_feature = {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
marker_field: {marker_key: True},
|
||||
}
|
||||
features = {
|
||||
**VIDEO_FEATURES,
|
||||
"observation.images.laptop_depth": depth_feature,
|
||||
}
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/depth_keys",
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / f"depth_keys_{marker_field}_{marker_key.replace('.', '_')}",
|
||||
)
|
||||
|
||||
assert set(meta.video_keys) == {"observation.images.laptop", "observation.images.laptop_depth"}
|
||||
assert meta.depth_keys == ["observation.images.laptop_depth"]
|
||||
|
||||
|
||||
def test_depth_keys_empty_when_no_marker(tmp_path):
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/no_depth", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=tmp_path / "no_depth"
|
||||
)
|
||||
assert meta.depth_keys == []
|
||||
|
||||
|
||||
def test_create_raises_on_existing_directory(tmp_path):
|
||||
"""create() raises if root directory already exists."""
|
||||
root = tmp_path / "existing"
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_features,
|
||||
convert_image_to_video_dataset,
|
||||
@@ -37,7 +37,9 @@ from lerobot.datasets.dataset_tools import (
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.io_utils import load_info
|
||||
from tests.datasets.test_video_encoding import _add_frames, require_h264, require_libsvtav1
|
||||
from tests.datasets.test_video_encoding import require_h264, require_hevc, require_libsvtav1
|
||||
from tests.fixtures.constants import DUMMY_DEPTH_FEATURES, DUMMY_DEPTH_KEY
|
||||
from tests.fixtures.dataset_factories import add_frames
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1332,9 +1334,131 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
|
||||
@require_libsvtav1
|
||||
@require_hevc
|
||||
def test_convert_image_to_video_dataset_depth(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Depth image features convert to depth videos using the depth encoder.
|
||||
|
||||
Mirrors :func:`test_convert_image_to_video_dataset` but with a small local
|
||||
image dataset that mixes an RGB camera with a depth camera, so the
|
||||
``depth_keys`` → ``depth_encoder`` routing and ``is_depth_map`` preservation
|
||||
are exercised end-to-end.
|
||||
"""
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
"observation.images.cam": {
|
||||
"dtype": "image",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"observation.images.depth": {
|
||||
"dtype": "image",
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"is_depth_map": True},
|
||||
},
|
||||
}
|
||||
source_dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "img_ds",
|
||||
features=features,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
add_frames(source_dataset, num_frames=4)
|
||||
source_dataset.save_episode()
|
||||
source_dataset.finalize()
|
||||
|
||||
# Source is an image dataset with the depth marker on the depth camera.
|
||||
assert len(source_dataset.meta.video_keys) == 0
|
||||
assert "observation.images.depth" in source_dataset.meta.depth_keys
|
||||
|
||||
output_dir = tmp_path / "video_ds"
|
||||
with (
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
# Use non-default quantization params so the persisted metadata must
|
||||
# come from the depth encoder (not RGB encoder defaults).
|
||||
depth_encoder = DepthEncoderConfig(
|
||||
vcodec="hevc",
|
||||
pix_fmt="gray12le",
|
||||
g=2,
|
||||
crf=30,
|
||||
depth_min=0.05,
|
||||
depth_max=8.0,
|
||||
shift=2.0,
|
||||
use_log=False,
|
||||
)
|
||||
video_dataset = convert_image_to_video_dataset(
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="dummy/depth_video",
|
||||
camera_encoder=VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30),
|
||||
depth_encoder=depth_encoder,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
# Both cameras are now videos, and the depth marker survived the conversion.
|
||||
assert "observation.images.cam" in video_dataset.meta.video_keys
|
||||
assert "observation.images.depth" in video_dataset.meta.video_keys
|
||||
assert "observation.images.depth" in video_dataset.meta.depth_keys
|
||||
assert "observation.images.cam" not in video_dataset.meta.depth_keys
|
||||
|
||||
depth_path = video_dataset.root / video_dataset.meta.get_video_file_path(0, "observation.images.depth")
|
||||
assert depth_path.exists(), f"Depth video file should exist: {depth_path}"
|
||||
|
||||
# The persisted depth-video metadata must carry the depth quantization params
|
||||
# from the depth encoder (so frames dequantize correctly on read), and the RGB
|
||||
# camera must not be marked as a depth map.
|
||||
persisted_info = load_info(video_dataset.root)
|
||||
depth_info = persisted_info.features["observation.images.depth"]["info"]
|
||||
assert depth_info["is_depth_map"] is True
|
||||
assert DepthEncoderConfig.from_video_info(depth_info) == depth_encoder
|
||||
|
||||
cam_info = persisted_info.features["observation.images.cam"]["info"]
|
||||
assert cam_info.get("is_depth_map") is False
|
||||
assert "video.codec" in cam_info
|
||||
|
||||
|
||||
# ─── reencode_dataset ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@require_hevc
|
||||
def test_reencode_dataset_depth_uses_depth_encoder(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Depth videos are re-encoded with the depth encoder and keep their depth metadata.
|
||||
|
||||
Depth-focused companion to :func:`test_reencode_dataset_multi_key_multiprocessing`.
|
||||
"""
|
||||
initial_cfg = DepthEncoderConfig(vcodec="hevc", pix_fmt="gray12le", g=2, crf=30)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds",
|
||||
features=DUMMY_DEPTH_FEATURES,
|
||||
use_videos=True,
|
||||
depth_encoder=initial_cfg,
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert DUMMY_DEPTH_KEY in dataset.meta.depth_keys
|
||||
|
||||
target_cfg = DepthEncoderConfig(vcodec="hevc", pix_fmt="gray12le", g=6, crf=23)
|
||||
result = reencode_dataset(dataset, depth_encoder=target_cfg, num_workers=0)
|
||||
|
||||
assert result is dataset
|
||||
|
||||
persisted_info = load_info(dataset.root)
|
||||
depth_info = persisted_info.features[DUMMY_DEPTH_KEY].get("info", {})
|
||||
# Re-encode applied the new codec parameters to the depth video ...
|
||||
assert DepthEncoderConfig.from_video_info(depth_info) == target_cfg
|
||||
# ... while preserving the depth marker.
|
||||
assert depth_info["is_depth_map"] is True
|
||||
|
||||
|
||||
@require_libsvtav1
|
||||
@require_h264
|
||||
def test_reencode_dataset_multi_key_multiprocessing(
|
||||
@@ -1350,9 +1474,9 @@ def test_reencode_dataset_multi_key_multiprocessing(
|
||||
camera_encoder=initial_cfg,
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
@@ -53,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
|
||||
# ── Existing encode_video_worker tests ───────────────────────────────
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_camera_encoder(tmp_path):
|
||||
"""_encode_video_worker forwards camera_encoder to encode_video_frames."""
|
||||
def test_encode_video_worker_forwards_video_encoder(tmp_path):
|
||||
"""_encode_video_worker forwards video_encoder to encode_video_frames."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -74,16 +74,16 @@ def test_encode_video_worker_forwards_camera_encoder(tmp_path):
|
||||
0,
|
||||
tmp_path,
|
||||
fps=30,
|
||||
camera_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||
video_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||
encoder_threads=4,
|
||||
)
|
||||
|
||||
assert captured_kwargs["camera_encoder"].vcodec == "h264"
|
||||
assert captured_kwargs["video_encoder"].vcodec == "h264"
|
||||
assert captured_kwargs["encoder_threads"] == 4
|
||||
|
||||
|
||||
def test_encode_video_worker_default_camera_encoder(tmp_path):
|
||||
"""_encode_video_worker passes None camera_encoder which encode_video_frames defaults."""
|
||||
def test_encode_video_worker_default_video_encoder(tmp_path):
|
||||
"""_encode_video_worker passes None video_encoder which encode_video_frames defaults."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -100,7 +100,7 @@ def test_encode_video_worker_default_camera_encoder(tmp_path):
|
||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||
_encode_video_worker(video_key, 0, tmp_path, fps=30)
|
||||
|
||||
assert captured_kwargs["camera_encoder"] is None
|
||||
assert captured_kwargs["video_encoder"] is None
|
||||
assert captured_kwargs["encoder_threads"] is None
|
||||
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
|
||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error():
|
||||
)
|
||||
|
||||
|
||||
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
|
||||
# ``create`` must deep-copy features so a dataset built from another's features stays independent.
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
|
||||
)
|
||||
dataset_copy = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
|
||||
)
|
||||
|
||||
original_shape = dataset.meta.info.features["state"]["shape"]
|
||||
dataset_copy.meta.info.features["state"]["shape"] = (999,)
|
||||
|
||||
assert dataset.meta.info.features["state"]["shape"] == original_shape
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
@@ -1516,10 +1531,15 @@ def test_valid_video_codecs_constant():
|
||||
assert "h264" in VALID_VIDEO_CODECS
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert "ffv1" in VALID_VIDEO_CODECS
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 10
|
||||
assert "h264_vaapi" in VALID_VIDEO_CODECS
|
||||
assert "h264_qsv" in VALID_VIDEO_CODECS
|
||||
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "hevc_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 11
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
"""Tests for the depth-integration feature.
|
||||
|
||||
Covers:
|
||||
- ``depth_utils`` quantize/dequantize round-trips and backend agreement.
|
||||
- Image-writer support for single-channel depth.
|
||||
- Hardware-feature → depth flag routing.
|
||||
- Feature-to-file-format routing through the dataset writer.
|
||||
|
||||
Depth metadata detection on ``LeRobotDatasetMetadata.depth_keys`` lives in
|
||||
``test_dataset_metadata.py``. Depth video encoding/decoding lives in
|
||||
``test_video_encoding.py``.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.configs import DepthEncoderConfig
|
||||
from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX
|
||||
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
DUMMY_CHW,
|
||||
DUMMY_DEPTH_CAMERA_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
from tests.fixtures.dataset_factories import add_frames
|
||||
|
||||
_, H, W = DUMMY_CHW
|
||||
|
||||
|
||||
def _depth_metres_ramp() -> np.ndarray:
|
||||
"""Linearly-spaced float32 depth in metres covering the default range."""
|
||||
return np.linspace(DEFAULT_DEPTH_MIN, DEFAULT_DEPTH_MAX, H * W, dtype=np.float32).reshape(H, W)
|
||||
|
||||
|
||||
# ── 1. Quantize / dequantize round-trips ──────────────────────────────
|
||||
|
||||
|
||||
class TestQuantizeDequantize:
|
||||
"""Numerical contract of ``quantize_depth`` / ``dequantize_depth``."""
|
||||
|
||||
@pytest.mark.parametrize("use_log", [False, True])
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
@pytest.mark.parametrize("output_channel_last", [False, True])
|
||||
def test_roundtrip(self, use_log, output_unit, output_channel_last):
|
||||
"""quantize → dequantize recovers depth; layout and unit are honored."""
|
||||
depth = _depth_metres_ramp()
|
||||
quantized = quantize_depth(depth, use_log=use_log, video_backend=None)
|
||||
recovered = dequantize_depth(
|
||||
quantized,
|
||||
use_log=use_log,
|
||||
output_unit=output_unit,
|
||||
output_tensor=False,
|
||||
output_channel_last=output_channel_last,
|
||||
)
|
||||
|
||||
expected_shape = (H, W, 1) if output_channel_last else (1, H, W)
|
||||
assert recovered.shape == expected_shape
|
||||
|
||||
recovered_m = recovered.astype(np.float32)
|
||||
if output_unit == "mm":
|
||||
recovered_m = recovered_m / 1000.0
|
||||
recovered_2d = recovered_m[..., 0] if output_channel_last else recovered_m[0]
|
||||
|
||||
if use_log:
|
||||
# Log mode: tighter near-range error than far-range (the whole point).
|
||||
near = depth < 1.0
|
||||
far = depth > 8.0
|
||||
err_near = np.abs(recovered_2d[near] - depth[near])
|
||||
err_far = np.abs(recovered_2d[far] - depth[far])
|
||||
assert err_near.mean() < err_far.mean()
|
||||
else:
|
||||
# Linear mode: bounded by quant step + 1 mm of unit-conversion rounding.
|
||||
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
|
||||
np.testing.assert_allclose(recovered_2d, depth, atol=tol)
|
||||
|
||||
@pytest.mark.parametrize("use_log", [False, True])
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
def test_numpy_torch_agree(self, use_log, output_unit):
|
||||
"""Batched torch path produces the same values as the numpy path."""
|
||||
batch_size = 3
|
||||
per_frame = np.linspace(0, DEPTH_QMAX, H * W, dtype=np.uint16).reshape(H, W)
|
||||
batch_np = np.broadcast_to(per_frame[None, None, ...], (batch_size, 1, H, W)).copy()
|
||||
batch_t = torch.from_numpy(batch_np.astype(np.int32)) # torch.uint16 support is patchy.
|
||||
|
||||
ref = dequantize_depth(batch_np, use_log=use_log, output_unit=output_unit, output_tensor=False)
|
||||
out = dequantize_depth(batch_t, use_log=use_log, output_unit=output_unit, output_tensor=True)
|
||||
|
||||
assert isinstance(out, torch.Tensor)
|
||||
assert out.shape == (batch_size, 1, H, W)
|
||||
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
|
||||
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
|
||||
atol = 1e-5 if output_unit == "m" else 1.0
|
||||
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,output_shape",
|
||||
[
|
||||
((H, W), (1, H, W)),
|
||||
((1, H, W), (1, H, W)),
|
||||
((H, W, 1), (1, H, W)),
|
||||
((3, 1, H, W), (3, 1, H, W)),
|
||||
((3, H, W, 1), (3, 1, H, W)),
|
||||
],
|
||||
)
|
||||
def test_input_layouts_accepted(self, input_shape, output_shape):
|
||||
"""All documented input layouts decode to the channel-first default."""
|
||||
quantized = np.full(input_shape, DEPTH_QMAX // 2, dtype=np.uint16)
|
||||
out = dequantize_depth(quantized, output_unit="m", output_tensor=False)
|
||||
assert out.shape == output_shape
|
||||
|
||||
def test_pyav_frame_roundtrip(self):
|
||||
"""quantize → av.VideoFrame → dequantize works."""
|
||||
depth = _depth_metres_ramp()
|
||||
frame = quantize_depth(depth, use_log=False, video_backend="pyav")
|
||||
assert isinstance(frame, av.VideoFrame)
|
||||
|
||||
recovered = dequantize_depth(frame, use_log=False, output_unit="m", output_tensor=False)
|
||||
assert recovered.shape == (1, H, W)
|
||||
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
|
||||
np.testing.assert_allclose(recovered[0], depth, atol=tol)
|
||||
|
||||
def test_invalid_log_params_raises(self):
|
||||
with pytest.raises(ValueError, match=r"depth_min \+ shift must be positive"):
|
||||
quantize_depth(_depth_metres_ramp(), depth_min=1.0, shift=-2.0, use_log=True, video_backend=None)
|
||||
|
||||
|
||||
# ── 2. Image writer depth support ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestImageWriterDepth:
|
||||
"""``image_array_to_pil_image`` and ``write_image`` for depth maps."""
|
||||
|
||||
@pytest.mark.parametrize("dtype,expected_mode", [(np.uint16, "I;16"), (np.float32, "F")])
|
||||
@pytest.mark.parametrize("shape", [(H, W), (H, W, 1), (1, H, W)])
|
||||
def test_pil_depth_modes_and_squeeze(self, dtype, expected_mode, shape):
|
||||
"""Single-channel depth converts to PIL with the right mode and (W, H) size."""
|
||||
arr = np.zeros(shape, dtype=dtype)
|
||||
img = image_array_to_pil_image(arr)
|
||||
assert img.mode == expected_mode
|
||||
assert img.size == (W, H)
|
||||
|
||||
def test_write_image_tiff_roundtrip(self, tmp_path):
|
||||
"""uint16 depth round-trips through .tiff."""
|
||||
arr = np.arange(H * W, dtype=np.uint16).reshape(H, W)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(arr, fpath)
|
||||
with PIL.Image.open(fpath) as loaded:
|
||||
recovered = np.array(loaded)
|
||||
np.testing.assert_array_equal(recovered, arr)
|
||||
|
||||
|
||||
# ── 3. Hardware-feature → depth flag ──────────────────────────────────
|
||||
|
||||
|
||||
class TestHwToDatasetFeaturesDepth:
|
||||
"""``hw_to_dataset_features`` flags single-channel cameras as depth."""
|
||||
|
||||
@pytest.mark.parametrize("channels,is_depth", [(1, True), (3, False)])
|
||||
def test_depth_marker_by_channels(self, channels, is_depth):
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
|
||||
features = hw_to_dataset_features({"cam": (480, 640, channels)}, prefix="observation")
|
||||
assert features["observation.images.cam"]["info"]["is_depth_map"] is is_depth
|
||||
|
||||
def test_invalid_channel_count_raises(self):
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
|
||||
with pytest.raises(ValueError, match="Expected a 3-tuple"):
|
||||
hw_to_dataset_features({"cam": (480, 640, 2)}, prefix="observation")
|
||||
|
||||
|
||||
# ── 4. Feature-to-file-format routing ────────────────────────────────
|
||||
|
||||
|
||||
# Keys derived from DUMMY_CAMERA_FEATURES_WITH_DEPTH; pick one RGB and the depth camera.
|
||||
RGB_KEY = next(iter(DUMMY_CAMERA_FEATURES))
|
||||
DEPTH_KEY = next(iter(DUMMY_DEPTH_CAMERA_FEATURES))
|
||||
|
||||
|
||||
class TestFeatureFileRouting:
|
||||
"""Depth vs RGB features route to the correct file format."""
|
||||
|
||||
NUM_FRAMES = 5
|
||||
|
||||
def test_image_mode_depth_tiff_rgb_png(self, tmp_path, features_factory):
|
||||
"""Without video encoding: depth → .tiff, RGB → .png."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=False)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / "ds",
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=self.NUM_FRAMES)
|
||||
|
||||
buf = dataset.writer.episode_buffer
|
||||
assert all(Path(p).suffix == ".tiff" for p in buf[DEPTH_KEY])
|
||||
assert all(Path(p).suffix == ".png" for p in buf[RGB_KEY])
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
def test_video_mode_depth_uses_depth_encoder(self, tmp_path, features_factory):
|
||||
"""With streaming video encoding: depth → DepthEncoderConfig, RGB does not."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=True)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / "ds",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=self.NUM_FRAMES)
|
||||
|
||||
encoder = dataset.writer._streaming_encoder
|
||||
assert encoder is not None
|
||||
assert isinstance(encoder._threads[DEPTH_KEY].video_encoder, DepthEncoderConfig)
|
||||
assert not isinstance(encoder._threads[RGB_KEY].video_encoder, DepthEncoderConfig)
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
@@ -1,121 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.episode_video_streaming import assert_hf_hub_range_cache_branch
|
||||
from lerobot.datasets.mp4 import (
|
||||
_box,
|
||||
_co64,
|
||||
_dinf,
|
||||
_hdlr,
|
||||
_mdhd,
|
||||
_mvhd,
|
||||
_stco,
|
||||
_stsc_one_sample_per_chunk,
|
||||
_stss,
|
||||
_stsz,
|
||||
_stts,
|
||||
_tkhd,
|
||||
_vmhd,
|
||||
parse_mp4_index,
|
||||
synthesize_mp4,
|
||||
)
|
||||
|
||||
|
||||
def _minimal_mp4(sample_offsets: list[int], *, use_co64: bool = False) -> bytes:
|
||||
ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
|
||||
sizes = np.array([10, 10, 10], dtype=np.int64)
|
||||
durations = np.array([1000, 1000, 1000], dtype=np.int64)
|
||||
stsd_body = struct.pack(">II", 0, 1) + struct.pack(">I4s", 16, b"avc1") + b"\0" * 8
|
||||
offsets = _co64(sample_offsets) if use_co64 else _stco(sample_offsets)
|
||||
stbl = _box(
|
||||
b"stbl",
|
||||
_box(b"stsd", stsd_body)
|
||||
+ _stts(durations)
|
||||
+ _stsc_one_sample_per_chunk(len(sizes))
|
||||
+ _stsz(sizes)
|
||||
+ offsets
|
||||
+ _stss(np.array([1], dtype=np.int64)),
|
||||
)
|
||||
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
|
||||
mdia = _box(b"mdia", _mdhd(1000, 3000) + _hdlr() + minf)
|
||||
trak = _box(b"trak", _tkhd(1, 3000, 64, 48) + mdia)
|
||||
moov = _box(b"moov", _mvhd(1000, 3000, 2) + trak)
|
||||
mdat_payload_start = 10_000
|
||||
free_size = mdat_payload_start - 8 - len(ftyp) - len(moov)
|
||||
assert free_size >= 8
|
||||
free = _box(b"free", b"\0" * (free_size - 8))
|
||||
return ftyp + moov + free + _box(b"mdat", b"x" * 128)
|
||||
|
||||
|
||||
def test_episode_slice_uses_min_max_sample_offsets_for_reordered_chunks():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
|
||||
|
||||
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
|
||||
|
||||
assert sample_slice.byte_offset == 10_000
|
||||
assert sample_slice.byte_length == 60
|
||||
assert sample_slice.sample_lo == 0
|
||||
assert sample_slice.sample_hi == 2
|
||||
|
||||
|
||||
def test_synthesized_mp4_rebases_one_chunk_per_sample_offsets():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
|
||||
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
|
||||
|
||||
mini = synthesize_mp4(mp4, sample_slice, b"x" * sample_slice.byte_length)
|
||||
mini_index = parse_mp4_index("mini.mp4", mini)
|
||||
|
||||
expected = np.array([0, 50, 25], dtype=np.int64) + mini_index.mdat_payload_offset
|
||||
np.testing.assert_array_equal(mini_index.sample_offsets, expected)
|
||||
np.testing.assert_array_equal(mini_index.sample_sizes, np.array([10, 10, 10]))
|
||||
|
||||
|
||||
def test_parser_accepts_co64_chunk_offsets():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025], use_co64=True))
|
||||
|
||||
np.testing.assert_array_equal(mp4.sample_offsets, np.array([10_000, 10_050, 10_025]))
|
||||
|
||||
|
||||
def test_hf_hub_branch_assertion_accepts_requested_revision(monkeypatch):
|
||||
class FakeDist:
|
||||
def read_text(self, name):
|
||||
assert name == "direct_url.json"
|
||||
return json.dumps(
|
||||
{
|
||||
"url": "https://github.com/huggingface/huggingface_hub.git",
|
||||
"vcs_info": {"requested_revision": "feat/hffs-cache-cdn-range-reads"},
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
|
||||
)
|
||||
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
|
||||
def test_hf_hub_branch_assertion_rejects_plain_install(monkeypatch):
|
||||
class FakeDist:
|
||||
def read_text(self, name):
|
||||
assert name == "direct_url.json"
|
||||
return json.dumps({"url": "https://github.com/huggingface/huggingface_hub.git"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
assert_hf_hub_range_cache_branch()
|
||||
@@ -94,7 +94,7 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory):
|
||||
|
||||
def test_image_array_to_pil_image_single_channel(img_array_factory):
|
||||
img_array = img_array_factory(channels=1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
with pytest.raises(ValueError, match="Unsupported single-channel image dtype"):
|
||||
image_array_to_pil_image(img_array)
|
||||
|
||||
|
||||
|
||||
@@ -61,9 +61,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -112,9 +110,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -146,9 +142,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -391,7 +385,8 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
# Verify codec options include thread tuning for libsvtav1 (lp=…)
|
||||
thread = encoder._threads[f"{OBS_IMAGES}.cam"]
|
||||
assert "svtav1-params" in thread.codec_options or "threads" in thread.codec_options
|
||||
codec_opts = thread.video_encoder.get_codec_options(encoder_threads=thread.encoder_threads)
|
||||
assert "svtav1-params" in codec_opts or "threads" in codec_opts
|
||||
|
||||
# Feed some frames and finish to ensure it works end-to-end
|
||||
num_frames = 10
|
||||
|
||||
@@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig
|
||||
from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.datasets.image_writer import write_image
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pyav_utils import get_codec
|
||||
@@ -37,7 +37,15 @@ from lerobot.datasets.video_utils import (
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
from tests.fixtures.constants import DUMMY_VIDEO_INFO
|
||||
from tests.fixtures.constants import (
|
||||
DUMMY_DEPTH_FEATURES,
|
||||
DUMMY_DEPTH_KEY,
|
||||
DUMMY_DEPTH_VIDEO_INFO_FULL,
|
||||
DUMMY_VIDEO_FEATURES,
|
||||
DUMMY_VIDEO_INFO,
|
||||
DUMMY_VIDEO_KEY,
|
||||
)
|
||||
from tests.fixtures.dataset_factories import add_frames
|
||||
|
||||
|
||||
# Per-codec skip markers — validation tests only fire when the codec is available
|
||||
@@ -48,12 +56,67 @@ def _require_encoder(vcodec: str) -> pytest.MarkDecorator:
|
||||
|
||||
require_libsvtav1 = _require_encoder("libsvtav1")
|
||||
require_h264 = _require_encoder("h264")
|
||||
require_hevc = _require_encoder("hevc")
|
||||
require_videotoolbox = _require_encoder("h264_videotoolbox")
|
||||
require_nvenc = _require_encoder("h264_nvenc")
|
||||
require_vaapi = _require_encoder("h264_vaapi")
|
||||
require_qsv = _require_encoder("h264_qsv")
|
||||
|
||||
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "encoded_videos"
|
||||
|
||||
|
||||
def _write_color_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i in range(num_frames):
|
||||
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
write_image(arr, imgs_dir / f"frame-{i:06d}.png")
|
||||
|
||||
|
||||
def _write_depth_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
"""Write synthetic uint16 depth TIFFs (millimetres) for depth encoder tests.
|
||||
|
||||
Uses a smooth linear ramp + per-frame offset (not white noise) so HEVC Main 12
|
||||
on ``gray12le`` compresses well. Values span ~100 mm to 10 m, covering most
|
||||
of the default ``[DEPTH_MIN, DEPTH_MAX]`` metres range after
|
||||
``quantize_depth(input_unit="auto"="mm")``.
|
||||
"""
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
base = np.linspace(100.0, 10_000.0, height * width, dtype=np.float32).reshape(height, width)
|
||||
for i in range(num_frames):
|
||||
arr = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
|
||||
write_image(arr, imgs_dir / f"frame-{i:06d}.tiff")
|
||||
|
||||
|
||||
def _encode_video(
|
||||
path: Path,
|
||||
num_frames: int = 4,
|
||||
fps: int = 30,
|
||||
cfg: VideoEncoderConfig | None = None,
|
||||
depth: bool = False,
|
||||
) -> Path:
|
||||
"""Write synthetic frames to a temp dir and encode them to ``path``.
|
||||
|
||||
``depth=False`` writes uint8 RGB PNG noise and encodes with ``cfg``
|
||||
(defaulting to the library default). ``depth=True`` writes synthetic uint16
|
||||
depth TIFFs and encodes with ``cfg`` or a default :class:`DepthEncoderConfig`
|
||||
(HEVC Main 12 / ``gray12le``).
|
||||
"""
|
||||
imgs_dir = path.parent / f"imgs_{path.stem}"
|
||||
if depth:
|
||||
_write_depth_frames(imgs_dir, num_frames=num_frames)
|
||||
cfg = cfg or DepthEncoderConfig()
|
||||
else:
|
||||
_write_color_frames(imgs_dir, num_frames=num_frames)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, video_encoder=cfg, overwrite=True)
|
||||
return path
|
||||
|
||||
|
||||
def _read_feature_info(dataset: LeRobotDataset, key: str = DUMMY_VIDEO_KEY) -> dict:
|
||||
info = json.loads((dataset.root / INFO_PATH).read_text())
|
||||
return info["features"][key]["info"]
|
||||
|
||||
|
||||
# ─── VideoEncoderConfig / codec options ──────────────────────────────
|
||||
|
||||
|
||||
@@ -87,7 +150,7 @@ class TestCodecOptions:
|
||||
assert opts["q:v"] == 40
|
||||
assert "crf" not in opts
|
||||
|
||||
@_require_encoder("h264_nvenc")
|
||||
@require_nvenc
|
||||
def test_nvenc_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_nvenc", g=2, crf=25, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
@@ -96,12 +159,12 @@ class TestCodecOptions:
|
||||
assert "crf" not in opts
|
||||
assert opts["g"] == 2
|
||||
|
||||
@_require_encoder("h264_vaapi")
|
||||
@require_vaapi
|
||||
def test_vaapi_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_vaapi", crf=28, preset=None)
|
||||
assert cfg.get_codec_options()["qp"] == 28
|
||||
|
||||
@_require_encoder("h264_qsv")
|
||||
@require_qsv
|
||||
def test_qsv_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_qsv", crf=25, preset=None)
|
||||
assert cfg.get_codec_options()["global_quality"] == 25
|
||||
@@ -313,59 +376,6 @@ class TestEncoderDetection:
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "encoded_videos"
|
||||
|
||||
# Default video feature set used by persistence tests.
|
||||
VIDEO_FEATURES = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
}
|
||||
VIDEO_KEY = "observation.images.cam"
|
||||
|
||||
|
||||
def _write_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i in range(num_frames):
|
||||
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
write_image(arr, imgs_dir / f"frame-{i:06d}.png")
|
||||
|
||||
|
||||
def _encode_video(
|
||||
path: Path, num_frames: int = 4, fps: int = 30, cfg: VideoEncoderConfig | None = None
|
||||
) -> Path:
|
||||
imgs_dir = path.parent / f"imgs_{path.stem}"
|
||||
_write_frames(imgs_dir, num_frames=num_frames)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, camera_encoder=cfg, overwrite=True)
|
||||
return path
|
||||
|
||||
|
||||
def _read_feature_info(dataset: LeRobotDataset) -> dict:
|
||||
info = json.loads((dataset.root / INFO_PATH).read_text())
|
||||
return info["features"][VIDEO_KEY]["info"]
|
||||
|
||||
|
||||
def _add_frames(dataset: LeRobotDataset, num_frames: int, video_keys: list[str] | None = None) -> None:
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
|
||||
if video_keys is None:
|
||||
video_keys = dataset.meta.video_keys
|
||||
for _ in range(num_frames):
|
||||
frame: dict = {"task": "test"}
|
||||
for key, ft in dataset.meta.features.items():
|
||||
if key in DEFAULT_FEATURES:
|
||||
continue
|
||||
shape = ft["shape"]
|
||||
if key in video_keys:
|
||||
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
|
||||
else:
|
||||
frame[key] = np.zeros(shape, dtype=np.float32)
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
class TestGetVideoInfo:
|
||||
def test_returns_all_stream_fields(self):
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4")
|
||||
@@ -375,7 +385,7 @@ class TestGetVideoInfo:
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.channels"] == 3
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
assert "video.g" not in info
|
||||
assert "video.crf" not in info
|
||||
@@ -385,7 +395,7 @@ class TestGetVideoInfo:
|
||||
def test_merges_encoder_config_as_video_prefixed_entries(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg)
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
|
||||
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
@@ -398,11 +408,16 @@ class TestGetVideoInfo:
|
||||
def test_stream_derived_keys_take_precedence_over_config(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg)
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
|
||||
|
||||
assert info["video.codec"] # populated from stream, not from config's vcodec
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
|
||||
def test_depth_encoder_config_sets_is_depth_map_true(self):
|
||||
"""A ``DepthEncoderConfig`` causes ``get_video_info`` to mark the stream as depth."""
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=DepthEncoderConfig())
|
||||
assert info["is_depth_map"] is True
|
||||
|
||||
|
||||
class TestEncodeVideoFrames:
|
||||
@require_libsvtav1
|
||||
@@ -434,7 +449,7 @@ class TestEncodeVideoFrames:
|
||||
|
||||
def test_overwrite_false_skips_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_frames(imgs_dir)
|
||||
_write_color_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
sentinel = b"pre-existing content"
|
||||
video_path.write_bytes(sentinel)
|
||||
@@ -446,7 +461,7 @@ class TestEncodeVideoFrames:
|
||||
@require_libsvtav1
|
||||
def test_overwrite_true_replaces_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_frames(imgs_dir)
|
||||
_write_color_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
video_path.write_bytes(b"stale content")
|
||||
|
||||
@@ -461,7 +476,7 @@ class TestEncodeVideoFrames:
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=4, crf=25, preset=10)
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg)
|
||||
|
||||
info = get_video_info(video_path, camera_encoder=cfg)
|
||||
info = get_video_info(video_path, video_encoder=cfg)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
@@ -470,7 +485,7 @@ class TestEncodeVideoFrames:
|
||||
assert info["video.codec"] == "av1"
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
# Encoder config
|
||||
assert info["video.g"] == 4
|
||||
@@ -488,14 +503,14 @@ class TestReencodeVideo:
|
||||
src = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
|
||||
out = tmp_path / "reencoded.mp4"
|
||||
cfg = VideoEncoderConfig(vcodec="h264", g=6, crf=23, pix_fmt="yuv444p")
|
||||
reencode_video(src, out, camera_encoder=cfg, overwrite=True)
|
||||
reencode_video(src, out, video_encoder=cfg, overwrite=True)
|
||||
|
||||
assert out.exists()
|
||||
with av.open(str(out)) as container:
|
||||
n_frames = sum(1 for _ in container.decode(video=0))
|
||||
assert n_frames == 4
|
||||
|
||||
info = get_video_info(out, camera_encoder=cfg)
|
||||
info = get_video_info(out, video_encoder=cfg)
|
||||
assert info["video.codec"] == "h264"
|
||||
assert info["video.pix_fmt"] == "yuv444p"
|
||||
assert info["video.height"] == 64
|
||||
@@ -509,7 +524,7 @@ class TestReencodeVideo:
|
||||
src = TEST_ARTIFACTS_DIR / "clip_6frames.mp4"
|
||||
out = tmp_path / "trim_window.mp4"
|
||||
cfg = VideoEncoderConfig(vcodec="h264")
|
||||
reencode_video(src, out, camera_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True)
|
||||
reencode_video(src, out, video_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True)
|
||||
|
||||
with av.open(str(out)) as container:
|
||||
frames = list(container.decode(video=0))
|
||||
@@ -580,10 +595,10 @@ class TestEncoderConfigPersistence:
|
||||
def test_first_episode_save_persists_encoder_config(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
|
||||
root=tmp_path / "ds", features=DUMMY_VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
@@ -603,14 +618,14 @@ class TestEncoderConfigPersistence:
|
||||
def test_second_episode_does_not_overwrite_encoder_fields(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
|
||||
root=tmp_path / "ds", features=DUMMY_VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
first_info = dict(_read_feature_info(dataset))
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
@@ -637,3 +652,217 @@ class TestFromVideoInfo:
|
||||
# ``{}`` placeholder (typical after a merge with disagreeing sources)
|
||||
# must not leak into the reconstructed config.
|
||||
assert cfg.extra_options == VideoEncoderConfig().extra_options
|
||||
|
||||
|
||||
# ─── Depth-specific encoding tests ────────────────────────────────────
|
||||
|
||||
|
||||
class TestEncodeDepthVideoFrames:
|
||||
"""Depth mirror of :class:`TestEncodeVideoFrames`.
|
||||
|
||||
Exercises ``encode_video_frames`` end-to-end through
|
||||
:class:`DepthEncoderConfig` (HEVC Main 12 / ``gray12le``) on synthetic
|
||||
uint16 depth TIFFs.
|
||||
"""
|
||||
|
||||
@require_hevc
|
||||
def test_produces_readable_file(self, tmp_path):
|
||||
video_path = _encode_video(tmp_path / "out.mp4", depth=True)
|
||||
|
||||
assert video_path.exists()
|
||||
info = get_video_info(video_path, video_encoder=DepthEncoderConfig())
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.codec"] == "hevc"
|
||||
assert info["video.pix_fmt"] == "gray12le"
|
||||
assert info["video.channels"] == 1
|
||||
assert info["is_depth_map"] is True
|
||||
|
||||
@require_hevc
|
||||
def test_frame_count_and_duration_match_input(self, tmp_path):
|
||||
num_frames = 10
|
||||
fps = 30
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=num_frames, fps=fps, depth=True)
|
||||
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
actual_frames = sum(1 for _ in container.decode(stream))
|
||||
duration = (
|
||||
float(stream.duration * stream.time_base)
|
||||
if stream.duration is not None
|
||||
else float(container.duration / av.time_base)
|
||||
)
|
||||
|
||||
assert actual_frames == num_frames
|
||||
assert abs(duration - num_frames / fps) < 0.1
|
||||
|
||||
def test_overwrite_false_skips_existing_file(self, tmp_path):
|
||||
"""Codec-agnostic: file-system semantics must hold even without an HEVC encoder."""
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_depth_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
sentinel = b"pre-existing depth content"
|
||||
video_path.write_bytes(sentinel)
|
||||
|
||||
encode_video_frames(imgs_dir, video_path, fps=30, video_encoder=DepthEncoderConfig(), overwrite=False)
|
||||
|
||||
assert video_path.read_bytes() == sentinel
|
||||
|
||||
@require_hevc
|
||||
def test_overwrite_true_replaces_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_depth_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
video_path.write_bytes(b"stale content")
|
||||
|
||||
encode_video_frames(imgs_dir, video_path, fps=30, video_encoder=DepthEncoderConfig(), overwrite=True)
|
||||
|
||||
info = get_video_info(video_path, video_encoder=DepthEncoderConfig())
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.pix_fmt"] == "gray12le"
|
||||
assert info["is_depth_map"] is True
|
||||
|
||||
@require_hevc
|
||||
def test_custom_encoder_config_fields_stored_in_info(self, tmp_path):
|
||||
"""All stream-derived and depth-encoder config fields are present after encoding."""
|
||||
cfg = DepthEncoderConfig(
|
||||
vcodec="hevc",
|
||||
pix_fmt="gray12le",
|
||||
g=4,
|
||||
crf=25,
|
||||
depth_min=0.05,
|
||||
depth_max=8.0,
|
||||
shift=2.5,
|
||||
use_log=False,
|
||||
)
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg, depth=True)
|
||||
|
||||
info = get_video_info(video_path, video_encoder=cfg)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.channels"] == 1
|
||||
assert info["video.codec"] == "hevc"
|
||||
assert info["video.pix_fmt"] == "gray12le"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["is_depth_map"] is True
|
||||
assert info["has_audio"] is False
|
||||
# Base encoder config
|
||||
assert info["video.g"] == 4
|
||||
assert info["video.crf"] == 25
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
# Depth-specific tuning
|
||||
assert info["video.depth_min"] == 0.05
|
||||
assert info["video.depth_max"] == 8.0
|
||||
assert info["video.shift"] == 2.5
|
||||
assert info["video.use_log"] is False
|
||||
|
||||
|
||||
class TestDepthEncoderConfigPersistence:
|
||||
"""Depth mirror of :class:`TestEncoderConfigPersistence`.
|
||||
|
||||
``DepthEncoderConfig`` must be stored as ``video.<field>`` entries
|
||||
(including the depth-specific ``depth_min`` / ``depth_max`` / ``shift`` /
|
||||
``use_log``) under ``info["features"][<depth_key>]["info"]`` when the
|
||||
first episode is saved.
|
||||
"""
|
||||
|
||||
@require_hevc
|
||||
def test_first_episode_save_persists_depth_encoder_config(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = DepthEncoderConfig(
|
||||
vcodec="hevc",
|
||||
pix_fmt="gray12le",
|
||||
g=2,
|
||||
crf=30,
|
||||
depth_min=0.05,
|
||||
depth_max=8.0,
|
||||
shift=2.5,
|
||||
use_log=False,
|
||||
)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=DUMMY_DEPTH_FEATURES, use_videos=True, depth_encoder=cfg
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
info = _read_feature_info(dataset, key=DUMMY_DEPTH_KEY)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.codec"] == "hevc"
|
||||
assert info["video.pix_fmt"] == "gray12le"
|
||||
assert info["is_depth_map"] is True
|
||||
# Base encoder config
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
# Depth-specific tuning
|
||||
assert info["video.depth_min"] == 0.05
|
||||
assert info["video.depth_max"] == 8.0
|
||||
assert info["video.shift"] == 2.5
|
||||
assert info["video.use_log"] is False
|
||||
|
||||
@require_hevc
|
||||
def test_second_episode_does_not_overwrite_depth_encoder_fields(
|
||||
self, tmp_path, empty_lerobot_dataset_factory
|
||||
):
|
||||
cfg = DepthEncoderConfig(
|
||||
vcodec="hevc",
|
||||
pix_fmt="gray12le",
|
||||
g=2,
|
||||
crf=30,
|
||||
depth_min=0.05,
|
||||
depth_max=8.0,
|
||||
shift=2.5,
|
||||
use_log=False,
|
||||
)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=DUMMY_DEPTH_FEATURES, use_videos=True, depth_encoder=cfg
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
first_info = dict(_read_feature_info(dataset, key=DUMMY_DEPTH_KEY))
|
||||
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert _read_feature_info(dataset, key=DUMMY_DEPTH_KEY) == first_info
|
||||
|
||||
|
||||
class TestDepthFromVideoInfo:
|
||||
"""``DepthEncoderConfig.from_video_info`` reconstructs a depth encoder
|
||||
config from the ``video.*`` keys persisted in a dataset's ``info.json``.
|
||||
|
||||
Depth mirror of :class:`TestFromVideoInfo`.
|
||||
"""
|
||||
|
||||
@require_hevc
|
||||
def test_reconstructs_from_dummy_depth_video_info(self):
|
||||
cfg = DepthEncoderConfig.from_video_info(DUMMY_DEPTH_VIDEO_INFO_FULL)
|
||||
|
||||
# No alias for ``"hevc"``; the canonical stream codec is reused as-is.
|
||||
assert cfg.vcodec == "hevc"
|
||||
assert cfg.pix_fmt == DUMMY_DEPTH_VIDEO_INFO_FULL["video.pix_fmt"]
|
||||
assert cfg.g == DUMMY_DEPTH_VIDEO_INFO_FULL["video.g"]
|
||||
assert cfg.crf == DUMMY_DEPTH_VIDEO_INFO_FULL["video.crf"]
|
||||
assert cfg.fast_decode == DUMMY_DEPTH_VIDEO_INFO_FULL["video.fast_decode"]
|
||||
assert cfg.video_backend == DUMMY_DEPTH_VIDEO_INFO_FULL["video.video_backend"]
|
||||
# ``{}`` placeholder (typical after a merge with disagreeing sources)
|
||||
# must not leak into the reconstructed config.
|
||||
assert cfg.extra_options == DepthEncoderConfig().extra_options
|
||||
# Depth-specific tuning round-trips through ``info.json``.
|
||||
assert cfg.depth_min == DUMMY_DEPTH_VIDEO_INFO_FULL["video.depth_min"]
|
||||
assert cfg.depth_max == DUMMY_DEPTH_VIDEO_INFO_FULL["video.depth_max"]
|
||||
assert cfg.shift == DUMMY_DEPTH_VIDEO_INFO_FULL["video.shift"]
|
||||
assert cfg.use_log == DUMMY_DEPTH_VIDEO_INFO_FULL["video.use_log"]
|
||||
|
||||
Vendored
+45
-1
@@ -39,12 +39,56 @@ DUMMY_VIDEO_INFO = {
|
||||
"video.crf": 30,
|
||||
"video.preset": 12,
|
||||
"video.fast_decode": 0,
|
||||
"video.is_depth_map": False,
|
||||
"is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
|
||||
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
|
||||
}
|
||||
DUMMY_DEPTH_VIDEO_INFO = {
|
||||
**DUMMY_VIDEO_INFO,
|
||||
"is_depth_map": True,
|
||||
}
|
||||
DUMMY_DEPTH_VIDEO_INFO_FULL = {
|
||||
**{k: v for k, v in DUMMY_VIDEO_INFO.items() if k != "video.preset"},
|
||||
"video.codec": "hevc",
|
||||
"video.pix_fmt": "gray12le",
|
||||
"is_depth_map": True,
|
||||
"video.depth_min": 0.05,
|
||||
"video.depth_max": 8.0,
|
||||
"video.shift": 2.5,
|
||||
"video.use_log": True,
|
||||
}
|
||||
DUMMY_DEPTH_CAMERA_FEATURES = {
|
||||
"laptop_depth": {
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": DUMMY_DEPTH_VIDEO_INFO,
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH = {**DUMMY_CAMERA_FEATURES, **DUMMY_DEPTH_CAMERA_FEATURES}
|
||||
DUMMY_CHW = (3, 96, 128)
|
||||
DUMMY_HWC = (96, 128, 3)
|
||||
|
||||
# Default video feature set used by video-encoding persistence tests.
|
||||
DUMMY_VIDEO_FEATURES = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
}
|
||||
DUMMY_VIDEO_KEY = "observation.images.cam"
|
||||
|
||||
DUMMY_DEPTH_FEATURES = {
|
||||
"observation.images.depth": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"is_depth_map": True},
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
}
|
||||
DUMMY_DEPTH_KEY = "observation.images.depth"
|
||||
|
||||
Vendored
+38
@@ -49,6 +49,39 @@ from tests.fixtures.constants import (
|
||||
)
|
||||
|
||||
|
||||
def add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
|
||||
"""Append ``num_frames`` synthetic frames to ``dataset``.
|
||||
|
||||
Generates per-feature payloads from ``dataset.meta``: uint16 depth ramps for
|
||||
keys in ``dataset.meta.depth_keys``, uint8 random noise for video/image keys,
|
||||
and float32 zeros for everything else. ``DEFAULT_FEATURES`` (timestamp,
|
||||
frame_index, ...) are auto-populated by ``add_frame`` and skipped here.
|
||||
"""
|
||||
video_keys = dataset.meta.video_keys
|
||||
depth_keys = dataset.meta.depth_keys
|
||||
# Smooth gradient base reused per (H, W) to keep depth frames cheap to
|
||||
# encode (HEVC Main 12 hates white noise).
|
||||
_depth_base_cache: dict[tuple[int, int], np.ndarray] = {}
|
||||
for i in range(num_frames):
|
||||
frame: dict = {"task": "test"}
|
||||
for key, ft in dataset.meta.features.items():
|
||||
if key in DEFAULT_FEATURES:
|
||||
continue
|
||||
shape = ft["shape"]
|
||||
if key in depth_keys:
|
||||
h, w, _ = shape
|
||||
base = _depth_base_cache.setdefault(
|
||||
(h, w),
|
||||
np.linspace(100.0, 10_000.0, h * w, dtype=np.float32).reshape(h, w, 1),
|
||||
)
|
||||
frame[key] = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
|
||||
elif key in video_keys:
|
||||
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
|
||||
else:
|
||||
frame[key] = np.zeros(shape, dtype=np.float32)
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
class LeRobotDatasetFactory(Protocol):
|
||||
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
|
||||
|
||||
@@ -485,10 +518,14 @@ def lerobot_dataset_factory(
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
camera_features: dict | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
# Instantiate objects
|
||||
if info is None:
|
||||
info_kwargs = {}
|
||||
if camera_features is not None:
|
||||
info_kwargs["camera_features"] = camera_features
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
@@ -496,6 +533,7 @@ def lerobot_dataset_factory(
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
chunks_size=chunks_size,
|
||||
**info_kwargs,
|
||||
)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info.features)
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
MergeConfig,
|
||||
ModifyTasksConfig,
|
||||
OperationConfig,
|
||||
ReencodeVideosConfig,
|
||||
RemoveFeatureConfig,
|
||||
SplitConfig,
|
||||
_validate_config,
|
||||
@@ -103,3 +104,47 @@ class TestOperationTypeParsing:
|
||||
)
|
||||
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
|
||||
assert resolved_name == type_name
|
||||
|
||||
|
||||
class TestDepthEncoderParsing:
|
||||
"""Test that the depth encoder is exposed and parsed for video operations."""
|
||||
|
||||
def test_reencode_has_default_depth_encoder(self):
|
||||
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", "reencode_videos"])
|
||||
assert isinstance(cfg.operation, ReencodeVideosConfig)
|
||||
# A depth encoder is configured by default so depth videos are re-encoded too.
|
||||
assert cfg.operation.depth_encoder is not None
|
||||
assert hasattr(cfg.operation.depth_encoder, "depth_min")
|
||||
|
||||
def test_reencode_parses_depth_encoder_overrides(self):
|
||||
cfg = parse_cfg(
|
||||
[
|
||||
"--repo_id",
|
||||
"test/repo",
|
||||
"--operation.type",
|
||||
"reencode_videos",
|
||||
"--operation.depth_encoder.vcodec",
|
||||
"ffv1",
|
||||
"--operation.depth_encoder.depth_max",
|
||||
"12.0",
|
||||
"--operation.depth_encoder.use_log",
|
||||
"false",
|
||||
]
|
||||
)
|
||||
assert cfg.operation.depth_encoder.vcodec == "ffv1"
|
||||
assert cfg.operation.depth_encoder.depth_max == 12.0
|
||||
assert cfg.operation.depth_encoder.use_log is False
|
||||
|
||||
def test_convert_image_to_video_parses_depth_encoder_overrides(self):
|
||||
cfg = parse_cfg(
|
||||
[
|
||||
"--repo_id",
|
||||
"test/repo",
|
||||
"--operation.type",
|
||||
"convert_image_to_video",
|
||||
"--operation.depth_encoder.depth_min",
|
||||
"0.05",
|
||||
]
|
||||
)
|
||||
assert isinstance(cfg.operation, ConvertImageToVideoConfig)
|
||||
assert cfg.operation.depth_encoder.depth_min == 0.05
|
||||
|
||||
@@ -43,6 +43,11 @@ def mock_rerun(monkeypatch):
|
||||
def __init__(self, arr):
|
||||
self.arr = arr
|
||||
|
||||
class DummyDepthImage:
|
||||
def __init__(self, arr, colormap=None):
|
||||
self.arr = arr
|
||||
self.colormap = colormap
|
||||
|
||||
def dummy_log(key, obj=None, **kwargs):
|
||||
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
|
||||
if obj is None and "entity" in kwargs:
|
||||
@@ -55,6 +60,8 @@ def mock_rerun(monkeypatch):
|
||||
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
|
||||
Scalars=DummyScalar,
|
||||
Image=DummyImage,
|
||||
DepthImage=DummyDepthImage,
|
||||
components=SimpleNamespace(Colormap=SimpleNamespace(Viridis="viridis")),
|
||||
log=dummy_log,
|
||||
init=lambda *a, **k: None,
|
||||
spawn=lambda *a, **k: None,
|
||||
@@ -225,7 +232,7 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
assert temp.value == pytest.approx(10.0)
|
||||
|
||||
img = _obj_for(calls, "observation.gray")
|
||||
assert type(img).__name__ == "DummyImage"
|
||||
assert type(img).__name__ == "DummyDepthImage" # single-channel -> DepthImage
|
||||
assert img.arr.shape == (8, 8, 1) # remains HWC
|
||||
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"(python_full_version >= '3.15' and platform_machine == 'AMD64' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'linux')",
|
||||
@@ -1089,8 +1089,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "5.0.1.dev0"
|
||||
source = { git = "https://github.com/huggingface/datasets.git?branch=main#06fcc085fcdd22fc5cc741954f6187dd879543b6" }
|
||||
version = "4.8.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
{ name = "filelock" },
|
||||
@@ -1107,6 +1107,10 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
@@ -1143,7 +1147,7 @@ name = "decord"
|
||||
version = "0.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy", marker = "(platform_machine != 'arm64' and platform_machine != 's390x' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "numpy", marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" },
|
||||
@@ -2046,8 +2050,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "1.20.0.dev0"
|
||||
source = { git = "https://github.com/huggingface/huggingface_hub.git?branch=feat%2Fhffs-cache-cdn-range-reads#5319b287faa73239bb40df16d69c39e5d6daf0f7" }
|
||||
version = "1.19.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "filelock" },
|
||||
@@ -2060,6 +2064,10 @@ dependencies = [
|
||||
{ name = "typer" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/88/27/629cfe58c582f92ded066c4a07d1a057ff617118ab7973200f770bd853cb/huggingface_hub-1.19.0.tar.gz", hash = "sha256:fd771622182d40977272a923953ee3b1b13538f9f8a7f5d78398f10af0f1c0bd", size = 824721, upload-time = "2026-06-11T12:33:18.665Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/a5/558da89f66464d8d0229ff497e8b8666977de2d8cf48c28a2862ecf1250f/huggingface_hub-1.19.0-py3-none-any.whl", hash = "sha256:1dc72e1f6b4d6df6b30eb72e57d00514ef453d660f04af2b87f0e67267f31ee0", size = 693398, upload-time = "2026-06-11T12:33:16.695Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hydra-core"
|
||||
@@ -3179,7 +3187,7 @@ requires-dist = [
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?branch=main" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
|
||||
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
|
||||
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
|
||||
@@ -3202,7 +3210,7 @@ requires-dist = [
|
||||
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", git = "https://github.com/huggingface/huggingface_hub.git?branch=feat%2Fhffs-cache-cdn-range-reads" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
|
||||
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user