mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
60 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6da993ef1a | |||
| eba2a696ab | |||
| b90f7e3590 | |||
| 73d60854b4 | |||
| 9dd7aee176 | |||
| addbf8d7e4 | |||
| 23f5459ba1 | |||
| f27582ca24 | |||
| 362233bf8f | |||
| eec9edf72d | |||
| dc6469d69a | |||
| ca89a8d3fe | |||
| 9b72f454e9 | |||
| f2a6838909 | |||
| 435465b00c | |||
| 05e1aad501 | |||
| 243adf6695 | |||
| 06e9442149 | |||
| 05c9a9c095 | |||
| 6181d6b694 | |||
| 971d7ea85b | |||
| 56231b17d1 | |||
| 12a66324fb | |||
| 0cb8970283 | |||
| 5498bac1b0 | |||
| 05d2a6062d | |||
| a1ec48d3a9 | |||
| c0d19ef35b | |||
| dd20029f4b | |||
| e961f8fec0 | |||
| ba7f23adf9 | |||
| a062ffb45c | |||
| dfdeac1339 | |||
| beaaaa3d99 | |||
| 8afd367c6a | |||
| 3c2a990ac3 | |||
| 4610d78c8c | |||
| 30dbe0a71b | |||
| 7460d2a796 | |||
| 3cbde39767 | |||
| d818a68177 | |||
| e227adb64f | |||
| 90f6f4c1d7 | |||
| 597f7b063c | |||
| ea7bb153e0 | |||
| a930fa8ca5 | |||
| e7191fc3ad | |||
| 712912d946 | |||
| 3826531a95 | |||
| 325b351ff2 | |||
| 44461eaadc | |||
| 330f63bf87 | |||
| 5d56804d81 | |||
| a6e95c4d26 | |||
| 10bb300e8a | |||
| e5e241e2cb | |||
| b8ddd64120 | |||
| 38327fdc84 | |||
| 9555efc02c | |||
| d576c59afb |
@@ -157,6 +157,14 @@ finally:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Working with depth
|
||||
|
||||
The Intel RealSense and Reachy 2 cameras can capture both color and depth in lockstep. Calling `read()` returns the **color** frame as `(H, W, 3)` `uint8`. Calling `read_depth()` returns the **depth map** as `(H, W, 1)` `uint16`, where each pixel value is the distance from the sensor expressed in **millimetres**. A pixel value of `0` typically means "no measurement available" (out-of-range, occluded, or low-confidence).
|
||||
|
||||
During recording, the control loop peeks the freshest buffered frames non-blockingly via `read_latest()` (color) and `read_latest_depth()` (depth), adding the depth map as a sibling feature (e.g. `front_depth` next to `front`).
|
||||
|
||||
For how depth streams are stored and encoded when recording a dataset, see the [Depth streams](./video_encoding_parameters#depth-streams) section of the video encoding guide.
|
||||
|
||||
## Use your phone's camera
|
||||
|
||||
<hfoptions id="use phone">
|
||||
|
||||
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
|
||||
|
||||
**Compatible teleoperators:**
|
||||
|
||||
- `openarm_mini` - OpenArm Mini
|
||||
- `bi_openarm_mini` - Bimanual OpenArm Mini
|
||||
- `so_leader` - SO100 / SO101 leader arm
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
|
||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||
|
||||
---
|
||||
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
|
||||
--robot.right_arm_config.port=can0 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||
--teleop.type=openarm_mini \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
|
||||
--robot.right_arm_config.port=can0 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||
--teleop.type=openarm_mini \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
|
||||
@@ -117,7 +117,7 @@ lerobot-rollout \
|
||||
--strategy.num_episodes=20 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--teleop.type=openarm_mini \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||
--dataset.single_task="Fold the T-shirt"
|
||||
```
|
||||
|
||||
@@ -11,8 +11,9 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage (RGB and depth cameras are encoded with separate encoders)
|
||||
7. **Re-encode Videos** - Re-encode an existing video dataset's RGB and/or depth streams with new encoder settings
|
||||
8. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -122,6 +123,15 @@ lerobot-edit-dataset \
|
||||
--operation.camera_encoder.g 2 \
|
||||
--operation.camera_encoder.crf 30
|
||||
|
||||
# Convert a dataset that includes depth maps, customizing the depth encoder
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.depth_encoder.depth_min 0.01 \
|
||||
--operation.depth_encoder.depth_max 10.0 \
|
||||
--operation.depth_encoder.use_log true
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -147,11 +157,42 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `camera_encoder`: Video encoder settings applied to RGB cameras — all sub-fields accessible via `--operation.camera_encoder.<field>`. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `depth_encoder`: Video encoder settings applied to depth-map cameras (e.g. from an Intel RealSense). In addition to the standard encoder fields it exposes the depth quantization knobs (`depth_min`, `depth_max`, `shift`, `use_log`), accessible via `--operation.depth_encoder.<field>`. These quantization settings are persisted to the dataset metadata so depth can be dequantized back to physical units on load. See the [Depth streams](./video_encoding_parameters#depth-streams) section for details.
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). Depth-map cameras are detected automatically and routed to the `depth_encoder`, while RGB cameras use the `camera_encoder`. All episodes, stats, and tasks are preserved.
|
||||
|
||||
#### Re-encode Videos
|
||||
|
||||
Re-encode the videos of an existing video dataset with different encoder settings, without going back to raw frames. RGB videos use the `camera_encoder` and depth videos use the `depth_encoder`. Provide only the encoder(s) you want to re-encode; the other stream type is left untouched.
|
||||
|
||||
```bash
|
||||
# Re-encode all RGB videos with new settings (saves to lerobot/pusht_reencoded by default)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type reencode_videos \
|
||||
--operation.camera_encoder.vcodec h264 \
|
||||
--operation.camera_encoder.pix_fmt yuv420p \
|
||||
--operation.camera_encoder.crf 23
|
||||
|
||||
# Re-encode both RGB and depth videos in a dataset with depth maps
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_depth \
|
||||
--operation.type reencode_videos \
|
||||
--operation.camera_encoder.vcodec libx264 \
|
||||
--operation.depth_encoder.vcodec ffv1
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `camera_encoder`: Encoder settings applied to every RGB video. Omit to skip re-encoding RGB videos.
|
||||
- `depth_encoder`: Encoder settings applied to every depth video. Omit to skip re-encoding depth videos.
|
||||
- `num_workers`: Number of parallel workers for processing.
|
||||
|
||||
> [!NOTE]
|
||||
> When re-encoding depth videos, the existing depth quantization parameters (`depth_min`, `depth_max`, `shift`, `use_log`) and the `is_depth_map` flag are **preserved** — re-encoding only changes the codec/quality of the stored stream, not how depth is dequantized on load.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
|
||||
@@ -65,6 +65,76 @@ All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
|
||||
|
||||
---
|
||||
|
||||
## Depth streams
|
||||
|
||||
Depth maps (Intel RealSense, Reachy 2) are stored as their **own video streams** alongside the RGB streams. Raw depth (`uint16` millimetres or `float32` metres) can't survive an 8-bit codec, so LeRobot **quantizes** each map to a 12-bit code (`[0, 4095]`) — logarithmically by default, to match the `1/depth` error profile of depth sensors — then packs it into a high-bit-depth pixel format (`gray12le`) and encodes it with a 12-bit codec.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
A["Raw depth (uint16 mm / float32 m)"] --> B["Clip to depth_min, depth_max"]
|
||||
B --> C["Quantize to 12-bit code 0–4095 (log or linear)"]
|
||||
C --> D["Pack into gray12le"]
|
||||
D --> E["Encode video (hevc Main 12)"]
|
||||
E --> F[("MP4 + metadata: depth_min/max, shift, use_log")]
|
||||
F -. "load time (depth_output_unit)" .-> G["Dequantize to mm or m"]
|
||||
|
||||
classDef input fill:#e3f2fd,stroke:#1565c0,color:#0d47a1;
|
||||
classDef encode fill:#ede7f6,stroke:#5e35b1,color:#311b92;
|
||||
classDef store fill:#fff8e1,stroke:#f9a825,color:#e65100;
|
||||
classDef load fill:#e8f5e9,stroke:#2e7d32,color:#1b5e20;
|
||||
|
||||
class A input;
|
||||
class B,C,D,E encode;
|
||||
class F store;
|
||||
class G load;
|
||||
```
|
||||
|
||||
Configure the depth pipeline through a parallel **`depth_encoder`** block (`DepthEncoderConfig`). It inherits every `VideoEncoderConfig` field (`vcodec`, `pix_fmt`, `crf`, …) and adds four quantizer knobs, set via `--dataset.depth_encoder.<field>`:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
... \
|
||||
--dataset.depth_encoder.vcodec=hevc \
|
||||
--dataset.depth_encoder.depth_min=0.05 \
|
||||
--dataset.depth_encoder.depth_max=5.0 \
|
||||
--dataset.depth_encoder.use_log=true
|
||||
```
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| ----------- | ------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vcodec` | `str` | `"hevc"` | Defaults to HEVC Main 12 (a 12-bit-capable codec). `ffv1` is a lossless alternative. |
|
||||
| `pix_fmt` | `str` | `"gray12le"` | Single-channel 12-bit pixel format used to carry the quantized codes. |
|
||||
| `depth_min` | `float` | `0.01` | Depth in metres mapped to quantum `0`. Values below are clipped on decode. |
|
||||
| `depth_max` | `float` | `10.0` | Depth in metres mapped to quantum `4095`. Values above are clipped on decode. |
|
||||
| `shift` | `float` | `3.5` | Pre-log offset (metres) used in logarithmic quantization for numerical stability near zero. Must satisfy `depth_min + shift > 0`. |
|
||||
| `use_log` | `bool` | `True` | If `true`, quantize in log-space (recommended for typical depth sensors). Set to `false` for uniform/linear quantization. |
|
||||
|
||||
> [!TIP]
|
||||
> `depth_min`, `depth_max`, and `shift` are always interpreted in **metres**, regardless of the input depth's unit. Inputs are auto-detected: integer arrays (e.g. `uint16` millimetres straight from a RealSense) are treated as millimetres, floating arrays as metres.
|
||||
> Pick `depth_min` / `depth_max` to bracket the actual working range of your sensor — quanta outside that range saturate, which can crush detail at the boundaries.
|
||||
|
||||
Depth features are flagged with `"is_depth_map": true` in `meta/info.json`, and their quantizer settings (`video.depth_min`, `video.depth_max`, `video.shift`, `video.use_log`) are persisted — which is what lets depth be **dequantized back to physical units** on load.
|
||||
|
||||
### Output unit at load time
|
||||
|
||||
`depth_encoder` is a **record-time** concern. The unit that depth maps are dequantized to on _load_ (e.g. during training) is set separately by the read-time flag `--dataset.depth_output_unit`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.depth_output_unit=m \
|
||||
--policy.type=act
|
||||
```
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| ------------------- | ----- | ------- | -------------------------------------------------------------------------------------------- |
|
||||
| `depth_output_unit` | `str` | `"mm"` | Physical unit depth maps are dequantized to on load: `"mm"` (millimetres) or `"m"` (metres). |
|
||||
|
||||
> [!TIP]
|
||||
> This is purely a decode-time presentation choice — it does **not** alter the stored video or its metadata, so the same dataset can be read as `mm` or `m` without re-encoding. It has no effect on datasets without depth cameras.
|
||||
|
||||
---
|
||||
|
||||
## Persistence in dataset metadata
|
||||
|
||||
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
|
||||
@@ -82,7 +152,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"video.is_depth_map": false,
|
||||
"is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
@@ -97,7 +167,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -281,7 +281,7 @@ class VideoFrameProvider:
|
||||
reencode_video(
|
||||
src,
|
||||
out_path,
|
||||
camera_encoder=encoder,
|
||||
video_encoder=encoder,
|
||||
overwrite=True,
|
||||
start_time_s=from_timestamp,
|
||||
end_time_s=to_timestamp,
|
||||
|
||||
@@ -105,8 +105,9 @@ def raw_observation_to_observation(
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
"""Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
if image.dtype == torch.uint8:
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
@@ -436,7 +436,7 @@ class OpenCVCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
1. Reads a color frame (blocking call)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
@@ -445,8 +445,9 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
@@ -484,6 +485,8 @@ class OpenCVCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -268,13 +268,13 @@ class RealSenseCamera(Camera):
|
||||
)
|
||||
|
||||
if len(found_devices) > 1:
|
||||
serial_numbers = [dev["serial_number"] for dev in found_devices]
|
||||
serial_numbers = [dev["id"] for dev in found_devices]
|
||||
raise ValueError(
|
||||
f"Multiple RealSense cameras found with name '{name}'. "
|
||||
f"Please use a unique serial number instead. Found SNs: {serial_numbers}"
|
||||
)
|
||||
|
||||
serial_number = str(found_devices[0]["serial_number"])
|
||||
serial_number = str(found_devices[0]["id"])
|
||||
return serial_number
|
||||
|
||||
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
|
||||
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
np.ndarray: The depth map as a NumPy array (height, width, 1)
|
||||
of type `np.uint16` (raw depth values in millimeters).
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
1. Reads a color/depth frame (blocking call with 10s timeout)
|
||||
2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -474,8 +474,9 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
@@ -486,6 +487,8 @@ class RealSenseCamera(Camera):
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
|
||||
processed_depth_frame = processed_depth_frame[..., np.newaxis]
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -522,6 +525,8 @@ class RealSenseCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive(): # pragma: no cover
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
@@ -532,7 +537,6 @@ class RealSenseCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -575,7 +579,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
@@ -611,6 +614,73 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[np.uint16]:
|
||||
"""Read the latest depth frame asynchronously, in millimeters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``, where each
|
||||
pixel is the distance from the sensor in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
the background read thread is not running.
|
||||
TimeoutError: If no frame becomes available within ``timeout_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if depth_frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
|
||||
|
||||
return depth_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent depth frame in millimeters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.uint16`` of shape ``(H, W, 1)``, where each pixel is the
|
||||
distance from the sensor in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
no depth frame has been captured yet.
|
||||
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if depth_frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any depth frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return depth_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
@@ -249,8 +249,9 @@ class ZMQCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
@@ -292,6 +293,8 @@ class ZMQCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -180,32 +180,24 @@ class WandBLogger:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
batch_data = {}
|
||||
for k, v in d.items():
|
||||
# Skip the custom step key here, it's added to the batch below.
|
||||
if custom_step_key is not None and k == custom_step_key:
|
||||
continue
|
||||
|
||||
if isinstance(v, list):
|
||||
for i, elem in enumerate(v):
|
||||
if isinstance(elem, (int | float)):
|
||||
batch_data[f"{mode}/{k}_{i}"] = elem
|
||||
continue
|
||||
|
||||
if not isinstance(v, (int | float | str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
batch_data[f"{mode}/{k}"] = v
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
|
||||
if batch_data:
|
||||
if custom_step_key is not None:
|
||||
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
|
||||
self._wandb.log(batch_data)
|
||||
else:
|
||||
self._wandb.log(data=batch_data, step=step)
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -35,8 +35,11 @@ from .types import (
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
encoder_config_from_video_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -57,8 +60,12 @@ __all__ = [
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
"DepthEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
"depth_encoder_defaults",
|
||||
# Factories
|
||||
"encoder_config_from_video_info",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -60,6 +60,8 @@ class DatasetRecordConfig:
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
|
||||
@@ -35,12 +35,17 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# When True, RGB video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
# Physical unit depth maps are dequantized to at load time: "mm" (millimetres) or "m" (metres).
|
||||
# Has no effect on datasets without depth cameras.
|
||||
depth_output_unit: str = "mm"
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.depth_output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
|
||||
if self.episodes is not None:
|
||||
if any(ep < 0 for ep in self.episodes):
|
||||
raise ValueError(
|
||||
|
||||
@@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar, Self
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
@@ -36,11 +36,12 @@ HW_VIDEO_CODECS = [
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
|
||||
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
|
||||
)
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
@@ -52,6 +53,19 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
# Default depth quantization and encoding parameters.
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
|
||||
DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
|
||||
|
||||
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
|
||||
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
@@ -86,6 +100,10 @@ class VideoEncoderConfig:
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Source-data channel count this encoder is expected to handle (3 for RGB,
|
||||
# 1 for depth, etc.)
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 3
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
@@ -94,9 +112,9 @@ class VideoEncoderConfig:
|
||||
self.validate()
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Parse the ``video.*`` keys of a feature ``info`` block into
|
||||
constructor kwargs.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
@@ -115,7 +133,15 @@ class VideoEncoderConfig:
|
||||
continue
|
||||
kwargs[field_name] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> Self:
|
||||
"""Reconstruct an encoder config from a video feature's ``info`` block.
|
||||
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
return cls(**cls._kwargs_from_video_info(video_info))
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of available encoders based on the specified video backend.
|
||||
@@ -138,7 +164,9 @@ class VideoEncoderConfig:
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
check_video_encoder_parameters_pyav(
|
||||
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
|
||||
)
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
@@ -218,6 +246,10 @@ class VideoEncoderConfig:
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "ffv1":
|
||||
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
|
||||
# are not meaningful.
|
||||
set_if("threads", encoder_threads)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
@@ -233,3 +265,75 @@ class VideoEncoderConfig:
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""Encoder configuration for depth-map streams.
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantizer.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"gray12le"``.
|
||||
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
by quantum ``0``.
|
||||
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
|
||||
shift: Pre-log offset for numerical stability near zero.
|
||||
use_log: ``True`` for logarithmic quantization (default; matches
|
||||
sensor error profile), ``False`` for linear.
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "gray12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 1
|
||||
|
||||
@classmethod
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Layer the depth-specific tuning (``depth_min`` / ``depth_max`` /
|
||||
``shift`` / ``use_log``) on top of the base parser. Missing keys
|
||||
fall back to the class defaults.
|
||||
"""
|
||||
kwargs = super()._kwargs_from_video_info(video_info)
|
||||
video_info = video_info or {}
|
||||
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{name}")
|
||||
if value is not None:
|
||||
kwargs[name] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
|
||||
return DepthEncoderConfig()
|
||||
|
||||
|
||||
def encoder_config_from_video_info(video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Build the appropriate encoder config from a feature's ``info`` block.
|
||||
|
||||
Dispatches to :class:`DepthEncoderConfig` when the dict marks the feature
|
||||
as a depth map and to :class:`VideoEncoderConfig`
|
||||
otherwise.
|
||||
|
||||
Args:
|
||||
video_info: A feature's ``info`` dict as persisted in ``info.json``,
|
||||
or ``None`` (treated as an empty dict).
|
||||
|
||||
Returns:
|
||||
A :class:`DepthEncoderConfig` for depth features, otherwise a
|
||||
:class:`VideoEncoderConfig`.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
is_depth = bool(video_info.get("is_depth_map") or video_info.get("video.is_depth_map"))
|
||||
cls: type[VideoEncoderConfig] = DepthEncoderConfig if is_depth else VideoEncoderConfig
|
||||
return cls.from_video_info(video_info)
|
||||
|
||||
@@ -506,8 +506,10 @@ def compute_episode_stats(
|
||||
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
|
||||
|
||||
Note:
|
||||
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
|
||||
per-channel values when dtype is 'image' or 'video'.
|
||||
For 'image'/'video' features, stats are computed per channel and kept with a
|
||||
leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by
|
||||
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip
|
||||
this rescaling and remain in their stored units.
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
@@ -531,8 +533,12 @@ def compute_episode_stats(
|
||||
)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
normalization_factor = (
|
||||
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
|
||||
)
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
@@ -552,8 +558,10 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
|
||||
raise ValueError(
|
||||
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
|
||||
)
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterable
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -337,6 +337,25 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos or images.
|
||||
|
||||
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
|
||||
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
|
||||
"""
|
||||
|
||||
def _is_depth(ft: dict) -> bool:
|
||||
info = ft.get("info") or {}
|
||||
video_info = ft.get("video_info") or {}
|
||||
return (
|
||||
info.get("is_depth_map", False)
|
||||
or info.get("video.is_depth_map", False)
|
||||
or video_info.get("video.is_depth_map", False)
|
||||
)
|
||||
|
||||
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
@@ -580,29 +599,51 @@ class LeRobotDatasetMetadata:
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
preserve_keys: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
"""Populate or refresh per-feature video info in ``info.json``.
|
||||
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Two modes, selected by ``preserve_keys``:
|
||||
|
||||
- **Populate** (``None``, default): write info for video keys that lack it,
|
||||
skip the rest. Used when first encoding a dataset.
|
||||
- **Refresh** (any iterable): re-probe and overwrite existing info, keeping
|
||||
the listed keys. Used after re-encoding to preserve data-intrinsic entries
|
||||
(``is_depth_map``, depth quantization params) while codec params change.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder: Encoder configuration used to produce the
|
||||
video_encoder: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
preserve_keys: ``None`` (default) for populate-once mode. An iterable
|
||||
(possibly empty) switches to refresh mode, keeping these keys'
|
||||
existing values while recomputing the rest.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
refresh = preserve_keys is not None
|
||||
preserve_set = set(preserve_keys or ())
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
existing = self.features[key].get("info") or {}
|
||||
# ``is_depth_map`` is written at feature creation and does not count as real video info here.
|
||||
already_populated = bool(set(existing.keys()) - {"is_depth_map"})
|
||||
# Populate-once: never clobber info that has already been written unless a refresh is requested.
|
||||
if already_populated and not refresh:
|
||||
continue
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
new_info = get_video_info(video_path, video_encoder=video_encoder)
|
||||
# Drop preserved keys so the existing values win on merge.
|
||||
new_info = {k: v for k, v in new_info.items() if k not in preserve_set}
|
||||
self.info.features[key]["info"] = {**existing, **new_info}
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
|
||||
@@ -22,7 +22,10 @@ from pathlib import Path
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import DepthEncoderConfig
|
||||
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .depth_utils import dequantize_depth
|
||||
from .feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
@@ -51,6 +54,7 @@ class DatasetReader:
|
||||
delta_timestamps: dict[str, list[float]] | None,
|
||||
image_transforms: Callable | None,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = "mm",
|
||||
):
|
||||
"""Initialize the reader with metadata, filtering, and transform config.
|
||||
|
||||
@@ -68,6 +72,10 @@ class DatasetReader:
|
||||
relative timestamp offsets for temporal context windows.
|
||||
image_transforms: Optional torchvision v2 transform applied to
|
||||
visual features.
|
||||
return_uint8: If True, return RGB video frames as raw uint8 tensors
|
||||
instead of normalized float32.
|
||||
depth_output_unit: Physical unit depth maps are dequantized to
|
||||
(``"m"`` or ``"mm"``). Defaults to ``"mm"``.
|
||||
"""
|
||||
self._meta = meta
|
||||
self.root = root
|
||||
@@ -76,6 +84,7 @@ class DatasetReader:
|
||||
self._video_backend = video_backend
|
||||
self._image_transforms = image_transforms
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
|
||||
self.hf_dataset: datasets.Dataset | None = None
|
||||
self._absolute_to_relative_idx: dict[int, int] | None = None
|
||||
@@ -86,6 +95,12 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
|
||||
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
|
||||
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
|
||||
for vid_key in self._meta.depth_keys
|
||||
}
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
@@ -247,7 +262,18 @@ class DatasetReader:
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
is_depth=vid_key in self._meta.depth_keys,
|
||||
)
|
||||
if vid_key in self._meta.depth_keys:
|
||||
depth_encoder = self._depth_encoder_configs[vid_key]
|
||||
frames = dequantize_depth(
|
||||
frames,
|
||||
depth_min=depth_encoder.depth_min,
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
output_unit=self._depth_output_unit,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -36,7 +36,14 @@ import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
encoder_config_from_video_info,
|
||||
)
|
||||
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
|
||||
@@ -47,6 +54,7 @@ from .compute_stats import (
|
||||
compute_relative_action_stats,
|
||||
)
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .image_writer import write_image
|
||||
from .io_utils import (
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
@@ -61,12 +69,13 @@ from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEPTH_FILE_PATTERN,
|
||||
IMAGE_FILE_PATTERN,
|
||||
VIDEO_DIR,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import (
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
|
||||
@@ -600,7 +609,7 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -614,7 +623,7 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
camera_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
video_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
"""
|
||||
from fractions import Fraction
|
||||
|
||||
@@ -639,13 +648,13 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
codec_options = camera_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
codec_options = video_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(video_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = camera_encoder.pix_fmt
|
||||
v_out.pix_fmt = video_encoder.pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -732,7 +741,7 @@ def _copy_and_reindex_videos(
|
||||
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
logging.info(f"Processing videos for {video_key}")
|
||||
camera_encoder = VideoEncoderConfig.from_video_info(
|
||||
video_encoder = encoder_config_from_video_info(
|
||||
src_dataset.meta.info.features.get(video_key, {}).get("info")
|
||||
)
|
||||
|
||||
@@ -816,7 +825,7 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
camera_encoder,
|
||||
video_encoder,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1150,15 +1159,15 @@ def _save_episode_images_for_video(
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
write_image(item[img_key], imgs_dir / frame_pattern.format(frame_index=i))
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
@@ -1190,13 +1199,14 @@ def _save_batch_episodes_images(
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
|
||||
# Define function to save a single image with global frame index
|
||||
# Defined once outside the loop to avoid repeated closure creation
|
||||
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key_param]
|
||||
# Use global frame index for naming
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
write_image(item[img_key_param], imgs_dir / frame_pattern.format(frame_index=base_frame_idx + i))
|
||||
return i
|
||||
|
||||
episode_durations = []
|
||||
@@ -1287,7 +1297,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1301,7 +1311,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
camera_encoder: Video encoder settings used for calibration encoding.
|
||||
video_encoder: Video encoder settings used for calibration encoding.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1326,10 +1336,11 @@ def _estimate_frame_size_via_calibration(
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
sample_indices = range(from_idx, from_idx + num_frames)
|
||||
|
||||
# Save calibration frames
|
||||
# Save calibration frames using the suffix/format the encoder expects.
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
for i, idx in enumerate(sample_indices):
|
||||
img = hf_dataset[idx][img_key]
|
||||
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
write_image(hf_dataset[idx][img_key], calibration_dir / frame_pattern.format(frame_index=i))
|
||||
|
||||
# Encode calibration video
|
||||
calibration_video_path = calibration_dir / "calibration.mp4"
|
||||
@@ -1337,7 +1348,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1610,6 +1621,7 @@ def recompute_stats(
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
# TODO: enable image and video stats re-computation
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
@@ -1656,6 +1668,7 @@ def convert_image_to_video_dataset(
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1667,21 +1680,32 @@ def convert_image_to_video_dataset(
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
camera_encoder: Video encoder settings
|
||||
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
|
||||
dataset: The source LeRobot dataset with images.
|
||||
output_dir: Root directory where the converted dataset will be stored. When
|
||||
``None``, defaults to ``$HF_LEROBOT_HOME/repo_id``. Equivalent to
|
||||
``new_root`` in ``EditDatasetConfig``.
|
||||
repo_id: Converted dataset identifier. Equivalent to ``new_repo_id`` in
|
||||
``EditDatasetConfig``.
|
||||
camera_encoder: Video encoder settings applied to RGB cameras. When ``None``,
|
||||
:func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings applied to depth-map cameras, including
|
||||
the quantization parameters persisted to the dataset metadata. When
|
||||
``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
episode_indices: Episode indices to convert. When ``None``, all episodes are
|
||||
converted.
|
||||
num_workers: Number of threads for parallel processing.
|
||||
max_episodes_per_batch: Maximum episodes per video batch, to bound memory use.
|
||||
``None`` means no limit.
|
||||
max_frames_per_batch: Maximum frames per video batch, to bound memory use.
|
||||
``None`` means no limit.
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
A new :class:`LeRobotDataset` with images encoded as videos.
|
||||
"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
if depth_encoder is None:
|
||||
depth_encoder = depth_encoder_defaults()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
@@ -1706,10 +1730,7 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
|
||||
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
|
||||
)
|
||||
logging.info(f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1771,6 +1792,8 @@ def convert_image_to_video_dataset(
|
||||
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
|
||||
|
||||
for img_key in tqdm(img_keys, desc="Processing cameras"):
|
||||
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
|
||||
|
||||
# Estimate size per frame by encoding a small calibration sample
|
||||
# This provides accurate compression ratio for the specific codec parameters
|
||||
size_per_frame_mb = _estimate_frame_size_via_calibration(
|
||||
@@ -1779,7 +1802,7 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=target_encoder,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1821,7 +1844,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=target_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1860,16 +1883,11 @@ def convert_image_to_video_dataset(
|
||||
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
# Update video info for all image keys (now videos). They are registered as
|
||||
# video features above, so update_video_info populates their (still-empty) info.
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder=camera_encoder
|
||||
)
|
||||
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
|
||||
new_meta.update_video_info(video_key=img_key, video_encoder=target_encoder)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
@@ -1896,11 +1914,11 @@ def convert_image_to_video_dataset(
|
||||
|
||||
def _reencode_video_worker(args: tuple) -> Path:
|
||||
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
|
||||
video_path, camera_encoder, encoder_threads = args
|
||||
video_path, video_encoder, encoder_threads = args
|
||||
reencode_video(
|
||||
input_video_path=video_path,
|
||||
output_video_path=video_path,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -1909,7 +1927,8 @@ def _reencode_video_worker(args: tuple) -> Path:
|
||||
|
||||
def reencode_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
@@ -1920,8 +1939,11 @@ def reencode_dataset(
|
||||
Args:
|
||||
dataset: An existing :class:`LeRobotDataset` whose videos will be
|
||||
re-encoded.
|
||||
camera_encoder: Target encoder configuration applied to every video
|
||||
file.
|
||||
camera_encoder: Target encoder configuration applied to every RGB video
|
||||
file. If ``None``, re-encoding is skipped for RGB videos.
|
||||
depth_encoder: Target encoder configuration applied to every depth video
|
||||
file. If ``None``, re-encoding is skipped for depth videos.
|
||||
Quantization parameters will not override the ones in the current dataset.
|
||||
encoder_threads: Per-encoder thread count forwarded to
|
||||
:func:`reencode_video`. ``None`` lets the codec decide.
|
||||
num_workers: Number of parallel processes. ``None`` or ``0`` means
|
||||
@@ -1933,23 +1955,35 @@ def reencode_dataset(
|
||||
on disk.
|
||||
"""
|
||||
meta = dataset.meta
|
||||
video_paths_list = []
|
||||
video_keys_encoders_dict = {}
|
||||
video_keys_paths_dict = {}
|
||||
|
||||
if camera_encoder is None and depth_encoder is None:
|
||||
raise ValueError("Either camera_encoder or depth_encoder must be provided")
|
||||
|
||||
# Only re-encode if the videos are not already encoded with the given video encoding parameters
|
||||
for video_key in meta.video_keys:
|
||||
current_info = meta.info.features[video_key].get("info", {})
|
||||
current_encoder = VideoEncoderConfig.from_video_info(current_info)
|
||||
if current_encoder != camera_encoder:
|
||||
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
current_encoder = encoder_config_from_video_info(current_info)
|
||||
target_encoder = depth_encoder if video_key in meta.depth_keys else camera_encoder
|
||||
if target_encoder is None:
|
||||
logging.info(f"No encoder provided for {video_key} video. Skipping re-encoding.")
|
||||
elif current_encoder != target_encoder:
|
||||
video_keys_paths_dict[video_key] = list((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
video_keys_encoders_dict[video_key] = target_encoder
|
||||
else:
|
||||
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
|
||||
logging.info(f"{video_key} videos are already encoded with {target_encoder}. Nothing to do.")
|
||||
|
||||
if len(video_paths_list) == 0:
|
||||
if len(video_keys_paths_dict) == 0:
|
||||
logging.warning("Dataset has no videos to re-encode.")
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
|
||||
logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
|
||||
|
||||
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
|
||||
worker_args = [
|
||||
(path, encoder, encoder_threads)
|
||||
for video_key, encoder in video_keys_encoders_dict.items()
|
||||
for path in video_keys_paths_dict[video_key]
|
||||
]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
@@ -1963,10 +1997,15 @@ def reencode_dataset(
|
||||
for args in tqdm(worker_args, desc="Re-encoding videos"):
|
||||
_reencode_video_worker(args)
|
||||
|
||||
# Refresh video info in metadata for every video key.
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(0, vid_key)
|
||||
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
# Refresh video info in metadata for every re-encoded key. Re-encoding only
|
||||
# changes codec/container params, so for depth videos we preserve ``is_depth_map``
|
||||
# and the depth quantization params (``video.depth_min`` / ``video.depth_max`` /
|
||||
# ...), which describe the data rather than the codec and must survive a transcode.
|
||||
# RGB videos pass an empty set: still a refresh, but nothing to preserve.
|
||||
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
|
||||
for video_key, encoder in video_keys_encoders_dict.items():
|
||||
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else set()
|
||||
meta.update_video_info(video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
@@ -31,7 +31,12 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
@@ -48,6 +53,7 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
@@ -67,17 +73,22 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
path_template = (
|
||||
DEFAULT_DEPTH_PATH
|
||||
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
|
||||
else DEFAULT_IMAGE_PATH
|
||||
)
|
||||
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -97,6 +108,7 @@ class DatasetWriter:
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
@@ -108,8 +120,11 @@ class DatasetWriter:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
|
||||
camera_encoder: Video encoder settings applied to RGB cameras. When
|
||||
``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings applied to depth cameras, including
|
||||
the quantization parameters. When ``None``,
|
||||
:func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
@@ -121,6 +136,7 @@ class DatasetWriter:
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -145,7 +161,8 @@ class DatasetWriter:
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
@@ -195,6 +212,7 @@ class DatasetWriter:
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self._meta.video_keys),
|
||||
depth_video_keys=list(self._meta.depth_keys),
|
||||
temp_dir=self._root,
|
||||
)
|
||||
|
||||
@@ -282,10 +300,13 @@ class DatasetWriter:
|
||||
if use_streaming:
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self._meta.video_keys:
|
||||
normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
k: v
|
||||
if k == "count"
|
||||
else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
@@ -300,7 +321,9 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -511,7 +534,12 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
self._meta.update_video_info(
|
||||
video_key,
|
||||
video_encoder=self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -578,13 +606,14 @@ class DatasetWriter:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
"""Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
|
||||
is_depth = video_key in self._meta.depth_keys
|
||||
return _encode_video_worker(
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._depth_encoder if is_depth else self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_PIX_FMT,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
|
||||
from .pyav_utils import write_u16_plane
|
||||
|
||||
_MM_PER_METRE = 1000.0
|
||||
_UINT16_MAX = 65535
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
if depth_min + shift <= 0:
|
||||
raise ValueError(
|
||||
f"depth_min + shift must be positive for logarithmic quantization, "
|
||||
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
|
||||
)
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.integer] | NDArray[np.floating],
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
|
||||
resolved_unit = (
|
||||
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
|
||||
)
|
||||
return depth.astype(np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
video_backend: str | None = "pyav",
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16] | av.VideoFrame:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
|
||||
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
|
||||
Logarithmic quantization is the default because it allocates more quanta
|
||||
to near-range depth, which matches the (1/depth) error profile of typical
|
||||
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
|
||||
|
||||
**Input units**:
|
||||
|
||||
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
|
||||
- ``input_unit="mm"``: interpret input values as millimetres.
|
||||
- ``input_unit="m"``: interpret input values as metres.
|
||||
|
||||
Quantization math runs in the **resolved input unit**.
|
||||
|
||||
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
|
||||
|
||||
Args:
|
||||
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
|
||||
depth_min: Depth (metres) at quantum ``0``.
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
video_backend: Video backend to use for encoding. Defaults to "pyav".
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
|
||||
``[0, DEPTH_QMAX]``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.detach().cpu().numpy()
|
||||
|
||||
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
|
||||
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
|
||||
depth = depth.squeeze()
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
# Normalization and quantization is performed in the resolved input unit.
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
log_max = math.log(float(depth_max_u + shift_u))
|
||||
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
|
||||
|
||||
if video_backend == "pyav":
|
||||
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
|
||||
write_u16_plane(frame.planes[0], quantized)
|
||||
return frame
|
||||
else:
|
||||
return quantized
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | av.VideoFrame | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
output_tensor: bool = True,
|
||||
output_channel_last: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit. Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Accepted input layouts :
|
||||
|
||||
- ``(H, W, 1)`` or ``(H, W)`` — single frame with channel-last.
|
||||
- ``(..., 1, H, W)`` — batched frames with channel-first.
|
||||
- ``(..., H, W, 1)`` — batched frames with channel-last.
|
||||
Output layout is determined by ``output_channel_last``.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes in ``[0, DEPTH_QMAX]``. ``np.ndarray``,
|
||||
``av.VideoFrame``, or ``torch.Tensor`` (any integer or float dtype).
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
pix_fmt: Pixel format used to extract the plane from an ``av.VideoFrame``.
|
||||
output_unit: ``"mm"`` returns ``uint16`` millimetres (rint, clip
|
||||
``[0, 65535]``) when returning a numpy array, or ``float32`` mm when
|
||||
``output_tensor=True``. ``"m"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
output_tensor: If True, return a ``torch.Tensor`` instead of a numpy array.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
|
||||
if isinstance(quantized, av.VideoFrame):
|
||||
quantized = quantized.to_ndarray(format=pix_fmt)
|
||||
|
||||
# Compute the scale and offset first.
|
||||
depth_min_m = float(depth_min)
|
||||
depth_max_m = float(depth_max)
|
||||
shift_m = float(shift)
|
||||
if use_log:
|
||||
log_min = math.log(depth_min_m + shift_m)
|
||||
log_max = math.log(depth_max_m + shift_m)
|
||||
scale = (log_max - log_min) / DEPTH_QMAX
|
||||
offset = log_min
|
||||
else:
|
||||
scale = (depth_max_m - depth_min_m) / DEPTH_QMAX
|
||||
offset = depth_min_m
|
||||
|
||||
# ── Torch path: stay on the input device, single fp32 allocation. ────────
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
if quantized.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
|
||||
|
||||
# Single allocation we own; everything else is in-place.
|
||||
buf = quantized.to(dtype=torch.float32, copy=True)
|
||||
buf.mul_(scale).add_(offset)
|
||||
if use_log:
|
||||
buf.exp_().sub_(shift_m)
|
||||
buf.clamp_(depth_min_m, depth_max_m)
|
||||
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return buf if output_tensor else buf.cpu().numpy()
|
||||
|
||||
# mm path: round + clamp in float32, skipping the uint16 round-trip
|
||||
# when returning a tensor (torch.uint16 is poorly supported).
|
||||
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
|
||||
if output_tensor:
|
||||
return buf
|
||||
return buf.cpu().numpy().astype(np.uint16, copy=False)
|
||||
|
||||
# ── NumPy path: single fp32 allocation, ``out=`` for in-place math. ─────
|
||||
arr = np.asarray(quantized)
|
||||
if arr.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
arr = np.squeeze(arr, axis=-3) if arr.shape[-3] == 1 else np.squeeze(arr, axis=-1)
|
||||
|
||||
buf = np.empty(arr.shape, dtype=np.float32)
|
||||
np.multiply(arr, scale, out=buf)
|
||||
np.add(buf, offset, out=buf)
|
||||
if use_log:
|
||||
np.exp(buf, out=buf)
|
||||
np.subtract(buf, shift_m, out=buf)
|
||||
np.clip(buf, depth_min_m, depth_max_m, out=buf)
|
||||
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return torch.from_numpy(buf) if output_tensor else buf
|
||||
|
||||
np.multiply(buf, _MM_PER_METRE, out=buf)
|
||||
np.rint(buf, out=buf)
|
||||
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
|
||||
if output_tensor:
|
||||
# torch.uint16 support is very limited; return float32 millimetres.
|
||||
return torch.from_numpy(buf)
|
||||
return buf.astype(np.uint16, copy=False)
|
||||
@@ -96,6 +96,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
return_uint8=True,
|
||||
depth_output_unit=cfg.dataset.depth_output_unit,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -336,7 +336,7 @@ def validate_feature_image_or_video(
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -42,10 +42,41 @@ def safe_stop_image_writer(func):
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
# TODO(CarolinePascal): 4 dimensions RGB-D images
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in [np.uint16, np.float32]:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
|
||||
)
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
@@ -71,13 +102,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation.
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -101,7 +147,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
@@ -58,8 +58,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = "mm",
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -186,6 +188,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
|
||||
is used by the writer.
|
||||
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults`
|
||||
is used by the writer.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
@@ -208,6 +213,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
@@ -248,6 +254,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
depth_output_unit=self._depth_output_unit,
|
||||
)
|
||||
|
||||
# Load actual data
|
||||
@@ -273,6 +280,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps,
|
||||
camera_encoder,
|
||||
depth_encoder,
|
||||
encoder_queue_maxsize,
|
||||
encoder_threads,
|
||||
)
|
||||
@@ -280,6 +288,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -315,6 +324,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=self.delta_timestamps,
|
||||
image_transforms=self.image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
depth_output_unit=self._depth_output_unit,
|
||||
)
|
||||
return self.reader
|
||||
|
||||
@@ -322,12 +332,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
@@ -646,6 +658,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -678,6 +691,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
@@ -712,6 +727,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._depth_output_unit = "mm"
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
@@ -721,12 +737,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -750,6 +767,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
@@ -779,6 +797,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
@@ -806,6 +826,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._depth_output_unit = "mm"
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
|
||||
if obj._requested_root is not None:
|
||||
@@ -825,12 +846,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -70,19 +70,21 @@ def aggregate_pipeline_dataset_features(
|
||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
exclude_images: bool = False,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||
|
||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
||||
(image or state), filters them based on `exclude_images` and `patterns`, and finally
|
||||
formats them for use with a Hugging Face LeRobot Dataset.
|
||||
|
||||
Args:
|
||||
pipeline: The DataProcessorPipeline to apply.
|
||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||
use_videos: If False, image features are excluded.
|
||||
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
|
||||
exclude_images: If True, image features are dropped entirely from the output.
|
||||
patterns: A sequence of regex patterns to filter action and state features.
|
||||
Image features are not affected by this filter.
|
||||
|
||||
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
|
||||
)
|
||||
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and not use_videos:
|
||||
if is_image and exclude_images:
|
||||
continue
|
||||
if not is_image and not should_keep(key, compiled_patterns):
|
||||
continue
|
||||
|
||||
@@ -24,6 +24,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,6 +32,22 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_pix_fmt_channels(pix_fmt: str) -> int:
|
||||
"""Return the number of components (channels) for *pix_fmt*."""
|
||||
return len(av.VideoFormat(pix_fmt).components)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
@@ -92,7 +109,7 @@ def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Opti
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = value
|
||||
num_val = float(value)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
@@ -142,6 +159,16 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
|
||||
"""Ensure *pix_fmt* can carry at least *channels* components."""
|
||||
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
|
||||
if pix_fmt_channels < channels:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
|
||||
f"but the source data has {channels} channel(s)."
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
@@ -156,12 +183,18 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
_check_option_value(vcodec, key, value, supported_options[key])
|
||||
|
||||
|
||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
||||
def check_video_encoder_parameters_pyav(
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, Any],
|
||||
channels: int | None = None,
|
||||
) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
@@ -171,4 +204,6 @@ def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options
|
||||
if not options:
|
||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||
_check_pixel_format(vcodec, pix_fmt)
|
||||
if channels is not None:
|
||||
_check_pix_fmt_channels(pix_fmt, channels)
|
||||
_check_codec_options(vcodec, codec_options)
|
||||
|
||||
@@ -87,11 +87,14 @@ DATA_DIR = "data"
|
||||
VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
IMAGE_FILE_PATTERN = "frame-{frame_index:06d}.png"
|
||||
DEPTH_FILE_PATTERN = "frame-{frame_index:06d}.tiff"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/" + IMAGE_FILE_PATTERN
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/" + DEPTH_FILE_PATTERN
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
|
||||
@@ -39,11 +39,16 @@ from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import quantize_depth
|
||||
from .pyav_utils import get_pix_fmt_channels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -53,6 +58,7 @@ def decode_video_frames(
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
@@ -64,23 +70,35 @@ def decode_video_frames(
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
|
||||
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
|
||||
accepted for one release as an alias for "pyav" and will be removed in a future version.
|
||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||
return_uint8 (bool): For RGB videos, if True return raw uint8 frames without float32 normalization.
|
||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||
is_depth (bool): Set to True if the video is a depth map (1 channel, uint12).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
|
||||
torch.Tensor: Decoded frames (RGB: float32 in [0,1] by default, or uint8 if return_uint8=True, Depth: uint12).
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend != "pyav" and is_depth:
|
||||
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
|
||||
# We do not actually return uint8 here, but we avoid the 255 normalization step.
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True
|
||||
)
|
||||
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
|
||||
)
|
||||
elif backend == "video_reader":
|
||||
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
@@ -91,6 +109,7 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: float,
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
@@ -109,8 +128,9 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||
decoded frame.
|
||||
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
|
||||
[0, 1] range.
|
||||
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
|
||||
Otherwise, return float32 in [0, 1] range.
|
||||
is_depth: Set to True if the video is a depth map (1 channel, uint12).
|
||||
|
||||
Returns:
|
||||
torch.Tensor of shape (len(timestamps), C, H, W).
|
||||
@@ -140,9 +160,13 @@ def decode_video_frames_pyav(
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
if is_depth:
|
||||
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
|
||||
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
|
||||
else:
|
||||
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
@@ -185,7 +209,7 @@ def decode_video_frames_pyav(
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
|
||||
if return_uint8:
|
||||
if return_uint8 or is_depth:
|
||||
return closest_frames
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
@@ -406,17 +430,38 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
"""Encode a directory of image frames into an MP4 video.
|
||||
|
||||
When ``video_encoder`` is a :class:`~lerobot.configs.video.DepthEncoderConfig`,
|
||||
frames are read from ``.tiff`` files and quantized to 12-bit depth codes using the
|
||||
encoder's ``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``; otherwise ``.png``
|
||||
RGB frames are encoded directly.
|
||||
|
||||
Args:
|
||||
imgs_dir: Directory containing the frames to encode, named ``frame-000000``
|
||||
onwards (``.png`` for RGB, ``.tiff`` for depth).
|
||||
video_path: Output path for the encoded ``.mp4`` file.
|
||||
fps: Frame rate of the output video.
|
||||
video_encoder: Encoder settings (codec, pixel format, quality, ...). When
|
||||
``None``, :func:`camera_encoder_defaults` is used. Pass a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig` to encode depth frames.
|
||||
encoder_threads: Per-encoder thread count forwarded to the codec. ``None``
|
||||
lets the codec decide.
|
||||
log_level: libav log level to set while encoding, or ``None`` to leave the
|
||||
current logging configuration unchanged.
|
||||
overwrite: When ``False`` and ``video_path`` already exists, skip encoding and
|
||||
log a warning. When ``True``, re-encode and replace the existing file.
|
||||
"""
|
||||
if video_encoder is None:
|
||||
video_encoder = camera_encoder_defaults()
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -428,17 +473,19 @@ def encode_video_frames(
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
suffix = ".png" if not is_depth else ".tiff"
|
||||
template = "frame-" + ("[0-9]" * 6) + suffix
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
raise FileNotFoundError(f"No images with suffix {suffix} found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -455,8 +502,19 @@ def encode_video_frames(
|
||||
# Loop through input frames and encode them
|
||||
for input_data in input_list:
|
||||
with Image.open(input_data) as input_image:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
if is_depth:
|
||||
input_frame = quantize_depth(
|
||||
np.array(input_image),
|
||||
depth_min=video_encoder.depth_min,
|
||||
depth_max=video_encoder.depth_max,
|
||||
shift=video_encoder.shift,
|
||||
use_log=video_encoder.use_log,
|
||||
pix_fmt=video_encoder.pix_fmt,
|
||||
video_backend="pyav",
|
||||
)
|
||||
else:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
@@ -477,7 +535,7 @@ def encode_video_frames(
|
||||
def reencode_video(
|
||||
input_video_path: Path | str,
|
||||
output_video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
@@ -489,7 +547,7 @@ def reencode_video(
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
output_video_path: Path for the re-encoded file.
|
||||
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
video_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
@@ -497,7 +555,7 @@ def reencode_video(
|
||||
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
video_encoder = video_encoder or camera_encoder_defaults()
|
||||
|
||||
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
|
||||
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
|
||||
@@ -512,9 +570,9 @@ def reencode_video(
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
@@ -696,22 +754,21 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, str],
|
||||
video_encoder: VideoEncoderConfig,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec_options = codec_options
|
||||
self.video_encoder = video_encoder
|
||||
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -736,12 +793,12 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
if not self.is_depth and frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
@@ -749,15 +806,29 @@ class _CameraEncoderThread(threading.Thread):
|
||||
height, width = frame_data.shape[:2]
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream = container.add_stream(
|
||||
self.video_encoder.vcodec,
|
||||
self.fps,
|
||||
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
|
||||
)
|
||||
output_stream.pix_fmt = self.video_encoder.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
if not self.is_depth:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
else:
|
||||
video_frame = quantize_depth(
|
||||
frame_data,
|
||||
depth_min=self.video_encoder.depth_min,
|
||||
depth_max=self.video_encoder.depth_max,
|
||||
shift=self.video_encoder.shift,
|
||||
use_log=self.video_encoder.use_log,
|
||||
video_backend=self.video_encoder.video_backend,
|
||||
)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
@@ -816,21 +887,26 @@ class StreamingVideoEncoder:
|
||||
self,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
camera_encoder: Video encoder settings applied to all RGB cameras.
|
||||
When ``None``, :func:`camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
depth_encoder: Video encoder settings applied to all depth cameras,
|
||||
including the depth quantization parameters. When ``None``,
|
||||
:func:`depth_encoder_defaults` is used.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
"""
|
||||
self.fps = fps
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
|
||||
@@ -843,18 +919,25 @@ class StreamingVideoEncoder:
|
||||
self._episode_active = False
|
||||
self._closed = False
|
||||
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
def start_episode(
|
||||
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
|
||||
) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
depth_video_keys: List of video or image feature keys that carry depth maps (e.g.
|
||||
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
|
||||
"""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
self._dropped_frames.clear()
|
||||
|
||||
if depth_video_keys is None:
|
||||
depth_video_keys = []
|
||||
|
||||
for video_key in video_keys:
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
@@ -863,17 +946,15 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
vcodec = self._camera_encoder.vcodec
|
||||
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
video_encoder=encoder,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self._encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -1080,15 +1161,23 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
not already populated from the video file itself. When a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig` is passed, the depth
|
||||
quantization parameters (``depth_min`` / ``depth_max`` / ``shift`` /
|
||||
``use_log``) are recorded so frames can be dequantized on read.
|
||||
|
||||
Returns:
|
||||
The ``video.*`` / ``audio.*`` info dict, including ``is_depth_map`` which is
|
||||
``True`` only when ``video_encoder`` is a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig`.
|
||||
"""
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
@@ -1106,13 +1195,10 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
|
||||
video_info["video.channels"] = pixel_channels
|
||||
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
@@ -1121,27 +1207,18 @@ def get_video_info(
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
if video_encoder is not None:
|
||||
for field_name, field_value in asdict(video_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
@@ -1202,10 +1279,13 @@ class VideoEncodingManager:
|
||||
img_dir = self.dataset.root / "images"
|
||||
if img_dir.exists():
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
tiff_files = list(img_dir.rglob("*.tiff"))
|
||||
if len(png_files) == 0 and len(tiff_files) == 0:
|
||||
shutil.rmtree(img_dir)
|
||||
logger.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
logger.debug(
|
||||
f"Images directory is not empty, containing {len(png_files)} PNG and {len(tiff_files)} TIFF files"
|
||||
)
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -126,7 +126,8 @@ def prepare_observation_for_inference(
|
||||
for name in observation:
|
||||
observation[name] = torch.from_numpy(observation[name])
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
if observation[name].dtype == torch.uint8:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
@@ -18,7 +18,8 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
from ..robot import Robot
|
||||
@@ -27,7 +28,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmFollower(Robot):
|
||||
class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
"""
|
||||
Bimanual OpenArm Follower Arms
|
||||
"""
|
||||
@@ -39,15 +40,17 @@ class BiOpenArmFollower(Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
|
||||
# will only open the cameras assigned to it. Per-arm cameras are used
|
||||
# as fallback when top-level cameras are empty.
|
||||
if config.cameras:
|
||||
left_cameras = config.cameras
|
||||
right_cameras = {}
|
||||
else:
|
||||
left_cameras = config.left_arm_config.cameras
|
||||
right_cameras = config.right_arm_config.cameras
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
|
||||
left_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
@@ -56,7 +59,7 @@ class BiOpenArmFollower(Robot):
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=left_cameras,
|
||||
cameras=left_arm_cameras,
|
||||
side=config.left_arm_config.side,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
@@ -75,7 +78,7 @@ class BiOpenArmFollower(Robot):
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=right_cameras,
|
||||
cameras=config.right_arm_config.cameras,
|
||||
side=config.right_arm_config.side,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
@@ -95,22 +98,19 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
left_arm_motors_ft = self.left_arm._motors_ft
|
||||
right_arm_motors_ft = self.right_arm._motors_ft
|
||||
|
||||
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
return {
|
||||
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
||||
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
||||
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
|
||||
}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
|
||||
# "right_wrist"), so we merge them directly — unlike motors which need the
|
||||
# left_/right_ prefix to disambiguate identical per-arm joint names.
|
||||
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -120,27 +120,6 @@ class BiOpenArmFollower(Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
@@ -148,21 +127,15 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
obs_dict: RobotObservation = {}
|
||||
|
||||
# Camera keys that should NOT get the arm prefix (they already have unique names)
|
||||
left_cam_keys = set(self.left_arm.cameras.keys())
|
||||
right_cam_keys = set(self.right_arm.cameras.keys())
|
||||
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
||||
for key, value in self.left_arm.get_observation().items():
|
||||
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
||||
|
||||
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
right_obs = self.right_arm.get_observation()
|
||||
for key, value in right_obs.items():
|
||||
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
|
||||
|
||||
left_obs = self.left_arm.get_observation()
|
||||
for key, value in left_obs.items():
|
||||
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
|
||||
# Add "right_" prefix
|
||||
for key, value in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{key}"] = value
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -189,9 +162,4 @@ class BiOpenArmFollower(Robot):
|
||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@@ -32,5 +32,7 @@ class BiOpenArmFollowerConfig(RobotConfig):
|
||||
left_arm_config: OpenArmFollowerConfigBase
|
||||
right_arm_config: OpenArmFollowerConfigBase
|
||||
|
||||
# Top-level cameras shared across both arms.
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -18,7 +18,8 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
|
||||
from ..robot import Robot
|
||||
@@ -27,7 +28,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiRebotB601Follower(Robot):
|
||||
class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
"""Bimanual Seeed Studio reBot B601-DM follower.
|
||||
|
||||
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
|
||||
@@ -41,6 +42,18 @@ class BiRebotB601Follower(Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
|
||||
left_arm_config = RebotB601FollowerRobotConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
@@ -49,7 +62,7 @@ class BiRebotB601Follower(Robot):
|
||||
dm_serial_baud=config.left_arm_config.dm_serial_baud,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
cameras=left_arm_cameras,
|
||||
motor_can_ids=config.left_arm_config.motor_can_ids,
|
||||
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
||||
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
||||
@@ -86,10 +99,12 @@ class BiRebotB601Follower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
|
||||
}
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -99,32 +114,13 @@ class BiRebotB601Follower(Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
|
||||
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
|
||||
obs_dict: RobotObservation = {}
|
||||
for k, v in self.left_arm.get_observation().items():
|
||||
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{k}"] = v
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
@@ -143,8 +139,3 @@ class BiRebotB601Follower(Robot):
|
||||
**{f"left_{k}": v for k, v in sent_action_left.items()},
|
||||
**{f"right_{k}": v for k, v in sent_action_right.items()},
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
from ..rebot_b601_follower import RebotB601FollowerConfig
|
||||
@@ -27,3 +29,8 @@ class BiRebotB601FollowerConfig(RobotConfig):
|
||||
|
||||
left_arm_config: RebotB601FollowerConfig
|
||||
right_arm_config: RebotB601FollowerConfig
|
||||
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -18,7 +18,8 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..so_follower import SOFollower, SOFollowerRobotConfig
|
||||
@@ -27,7 +28,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiSOFollower(Robot):
|
||||
class BiSOFollower(BimanualMixin, Robot):
|
||||
"""
|
||||
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||
"""
|
||||
@@ -39,6 +40,18 @@ class BiSOFollower(Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
|
||||
left_arm_config = SOFollowerRobotConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
@@ -46,7 +59,7 @@ class BiSOFollower(Robot):
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
use_degrees=config.left_arm_config.use_degrees,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
cameras=left_arm_cameras,
|
||||
)
|
||||
|
||||
right_arm_config = SOFollowerRobotConfig(
|
||||
@@ -77,13 +90,12 @@ class BiSOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||
}
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -93,42 +105,21 @@ class BiSOFollower(Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
obs_dict: RobotObservation = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
||||
for key, value in self.left_arm.get_observation().items():
|
||||
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
||||
|
||||
# Add "right_" prefix
|
||||
right_obs = self.right_arm.get_observation()
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
for key, value in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{key}"] = value
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -151,8 +142,3 @@ class BiSOFollower(Robot):
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
from ..so_follower import SOFollowerConfig
|
||||
@@ -27,3 +29,8 @@ class BiSOFollowerConfig(RobotConfig):
|
||||
|
||||
left_arm_config: SOFollowerConfig
|
||||
right_arm_config: SOFollowerConfig
|
||||
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -68,9 +68,12 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
features: dict[str, tuple] = {}
|
||||
for cam in self.cameras:
|
||||
features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3)
|
||||
if getattr(self.cameras[cam], "use_depth", False):
|
||||
features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width, 1)
|
||||
return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -190,6 +193,12 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
if getattr(cam, "use_depth", False):
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -333,6 +333,7 @@ def build_rollout_context(
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
@@ -368,6 +369,7 @@ def build_rollout_context(
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
|
||||
@@ -54,6 +54,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
|
||||
@@ -133,6 +133,15 @@ Convert image dataset to video format and save locally:
|
||||
--new_root /path/to/output/pusht_video \
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert image dataset (with depth maps) to video format, customizing the depth encoder:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_root /path/to/output/pusht_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.depth_encoder.depth_min 0.01 \
|
||||
--operation.depth_encoder.depth_max 10.0 \
|
||||
--operation.depth_encoder.use_log true
|
||||
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -211,6 +220,13 @@ Re-encode videos in-place (overwrites original dataset):
|
||||
--operation.camera_encoder.vcodec h264 \
|
||||
--operation.overwrite true
|
||||
|
||||
Re-encode both RGB and depth videos in a dataset (depth quantization params are preserved):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_depth \
|
||||
--operation.type reencode_videos \
|
||||
--operation.camera_encoder.vcodec libx264 \
|
||||
--operation.depth_encoder.vcodec ffv1
|
||||
|
||||
Using JSON config file:
|
||||
lerobot-edit-dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
@@ -225,7 +241,13 @@ from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, parser
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
parser,
|
||||
)
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
convert_image_to_video_dataset,
|
||||
@@ -288,6 +310,7 @@ class ModifyTasksConfig(OperationConfig):
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
output_dir: str | None = None
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
max_episodes_per_batch: int | None = None
|
||||
@@ -309,6 +332,7 @@ class RecomputeStatsConfig(OperationConfig):
|
||||
@dataclass
|
||||
class ReencodeVideosConfig(OperationConfig):
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
num_workers: int = 0
|
||||
encoder_threads: int | None = None
|
||||
overwrite: bool = False
|
||||
@@ -602,6 +626,7 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
camera_encoder=getattr(cfg.operation, "camera_encoder", None) or camera_encoder_defaults(),
|
||||
depth_encoder=getattr(cfg.operation, "depth_encoder", None) or depth_encoder_defaults(),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
|
||||
@@ -719,10 +744,14 @@ def handle_reencode_videos(cfg: EditDatasetConfig) -> None:
|
||||
shutil.copytree(input_root, output_root)
|
||||
dataset = LeRobotDataset(output_repo_id, root=output_root)
|
||||
|
||||
logging.info(f"Re-encoding videos in {output_repo_id} with {cfg.operation.camera_encoder}")
|
||||
logging.info(
|
||||
f"Re-encoding videos in {output_repo_id} with RGB encoder {cfg.operation.camera_encoder} "
|
||||
f"and depth encoder {cfg.operation.depth_encoder}"
|
||||
)
|
||||
reencode_dataset(
|
||||
dataset,
|
||||
camera_encoder=cfg.operation.camera_encoder,
|
||||
depth_encoder=cfg.operation.depth_encoder,
|
||||
encoder_threads=cfg.operation.encoder_threads,
|
||||
num_workers=cfg.operation.num_workers,
|
||||
)
|
||||
|
||||
@@ -57,6 +57,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
|
||||
@@ -137,6 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
@@ -403,6 +404,7 @@ def record(
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
@@ -432,6 +434,7 @@ def record(
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
|
||||
@@ -174,6 +174,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
|
||||
@@ -41,6 +41,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
koch_leader,
|
||||
|
||||
@@ -89,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
|
||||
@@ -18,7 +18,8 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -27,7 +28,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmLeader(Teleoperator):
|
||||
class BiOpenArmLeader(BimanualMixin, Teleoperator):
|
||||
"""
|
||||
Bimanual OpenArm Leader Arms
|
||||
"""
|
||||
@@ -86,27 +87,6 @@ class BiOpenArmLeader(Teleoperator):
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
@@ -129,8 +109,3 @@ class BiOpenArmLeader(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
|
||||
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
||||
@dataclass
|
||||
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
"""Configuration class for Bi OpenArm Leader teleoperators."""
|
||||
|
||||
left_arm_config: OpenArmLeaderConfigBase
|
||||
right_arm_config: OpenArmLeaderConfigBase
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_openarm_mini import BiOpenArmMini
|
||||
from .config_bi_openarm_mini import BiOpenArmMiniConfig
|
||||
|
||||
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
|
||||
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_bi_openarm_mini import BiOpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmMini(BimanualMixin, Teleoperator):
|
||||
"""Bimanual OpenArm Mini teleoperator.
|
||||
|
||||
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
|
||||
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
|
||||
bimanual leader can teleoperate a bimanual OpenArm follower.
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmMiniConfig
|
||||
name = "bi_openarm_mini"
|
||||
|
||||
def __init__(self, config: BiOpenArmMiniConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# `side` is forced to match left/right regardless of what the user passed
|
||||
# on the per-arm base config — the bimanual wrapper owns the side semantics.
|
||||
left_arm_config = OpenArmMiniConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
side="left",
|
||||
use_degrees=config.left_arm_config.use_degrees,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmMiniConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
side="right",
|
||||
use_degrees=config.right_arm_config.use_degrees,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmMini(left_arm_config)
|
||||
self.right_arm = OpenArmMini(right_arm_config)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
|
||||
}
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
action: RobotAction = {}
|
||||
for k, v in self.left_arm.get_action().items():
|
||||
action[f"left_{k}"] = v
|
||||
for k, v in self.right_arm.get_action().items():
|
||||
action[f"right_{k}"] = v
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
|
||||
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
|
||||
if left_fb:
|
||||
self.left_arm.send_feedback(left_fb)
|
||||
if right_fb:
|
||||
self.right_arm.send_feedback(right_fb)
|
||||
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
from ..openarm_mini import OpenArmMiniConfigBase
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
|
||||
@dataclass
|
||||
class BiOpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration class for Bi OpenArm Mini teleoperators."""
|
||||
|
||||
left_arm_config: OpenArmMiniConfigBase
|
||||
right_arm_config: OpenArmMiniConfigBase
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_rebot_102_leader import BiRebotArm102Leader
|
||||
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
|
||||
from .bi_rebot_102_leader import BiRebot102Leader
|
||||
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
|
||||
|
||||
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
|
||||
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
|
||||
|
||||
@@ -18,16 +18,17 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
|
||||
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiRebotArm102Leader(Teleoperator):
|
||||
class BiRebot102Leader(BimanualMixin, Teleoperator):
|
||||
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
|
||||
|
||||
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
|
||||
@@ -35,10 +36,10 @@ class BiRebotArm102Leader(Teleoperator):
|
||||
leader can teleoperate a bimanual reBot B601 follower.
|
||||
"""
|
||||
|
||||
config_class = BiRebotArm102LeaderConfig
|
||||
config_class = BiRebot102LeaderConfig
|
||||
name = "bi_rebot_102_leader"
|
||||
|
||||
def __init__(self, config: BiRebotArm102LeaderConfig):
|
||||
def __init__(self, config: BiRebot102LeaderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
@@ -76,27 +77,6 @@ class BiRebotArm102Leader(Teleoperator):
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
@@ -106,8 +86,3 @@ class BiRebotArm102Leader(Teleoperator):
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
|
||||
|
||||
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
|
||||
@dataclass
|
||||
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
|
||||
class BiRebot102LeaderConfig(TeleoperatorConfig):
|
||||
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
|
||||
|
||||
left_arm_config: RebotArm102LeaderConfig
|
||||
|
||||
@@ -17,7 +17,9 @@
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..so_leader import SOLeader, SOLeaderTeleopConfig
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -26,7 +28,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiSOLeader(Teleoperator):
|
||||
class BiSOLeader(BimanualMixin, Teleoperator):
|
||||
"""
|
||||
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||
"""
|
||||
@@ -67,33 +69,12 @@ class BiSOLeader(Teleoperator):
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -109,8 +90,3 @@ class BiSOLeader(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
|
||||
|
||||
@@ -19,12 +19,21 @@ from dataclasses import dataclass
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||
class OpenArmMiniConfigBase:
|
||||
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
|
||||
|
||||
port_right: str = "/dev/ttyUSB0"
|
||||
port_left: str = "/dev/ttyUSB1"
|
||||
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
|
||||
port: str
|
||||
|
||||
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
|
||||
# during readout. If `None`, no flipping is applied.
|
||||
side: str | None = None
|
||||
|
||||
use_degrees: bool = True
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
|
||||
pass
|
||||
|
||||
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors whose direction is inverted during readout
|
||||
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
|
||||
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||
# Per-side motor direction flips applied during readout.
|
||||
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
|
||||
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
|
||||
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
|
||||
}
|
||||
|
||||
# Leader joint 6 maps to follower joint 7 and vice versa
|
||||
# Leader joint 6 ↔ follower joint 7 (symmetric — its own inverse).
|
||||
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
|
||||
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
|
||||
|
||||
GRIPPER_TELEOP_TO_DEGREES = -0.65
|
||||
|
||||
|
||||
class OpenArmMini(Teleoperator):
|
||||
"""
|
||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
|
||||
|
||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
|
||||
"""
|
||||
|
||||
config_class = OpenArmMiniConfig
|
||||
@@ -56,9 +56,12 @@ class OpenArmMini(Teleoperator):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
|
||||
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
|
||||
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
|
||||
|
||||
norm_mode_body = MotorNormMode.DEGREES
|
||||
|
||||
motors_right = {
|
||||
motors = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
@@ -69,46 +72,15 @@ class OpenArmMini(Teleoperator):
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
motors_left = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||
}
|
||||
|
||||
self.bus_right = FeetechMotorsBus(
|
||||
port=self.config.port_right,
|
||||
motors=motors_right,
|
||||
calibration=cal_right,
|
||||
)
|
||||
|
||||
self.bus_left = FeetechMotorsBus(
|
||||
port=self.config.port_left,
|
||||
motors=motors_left,
|
||||
calibration=cal_left,
|
||||
self.bus = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors=motors,
|
||||
calibration=self.calibration,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus_right.motors:
|
||||
features[f"right_{motor}.pos"] = float
|
||||
for motor in self.bus_left.motors:
|
||||
features[f"left_{motor}.pos"] = float
|
||||
return features
|
||||
return {f"{motor}.pos": float for motor in self.bus.motors}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
@@ -116,14 +88,12 @@ class OpenArmMini(Teleoperator):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||
self.bus_right.connect()
|
||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||
self.bus_left.connect()
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
self.bus.connect()
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
@@ -133,14 +103,14 @@ class OpenArmMini(Teleoperator):
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArm Mini.
|
||||
Run calibration procedure for a single OpenArm Mini arm.
|
||||
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
2. Ask user to position arm in hanging position with gripper closed
|
||||
3. Set this as zero position via half-turn homing
|
||||
4. Interactive gripper calibration (open/close positions)
|
||||
5. Save calibration
|
||||
@@ -152,70 +122,51 @@ class OpenArmMini(Teleoperator):
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Using existing calibration for {self.id}")
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||
}
|
||||
self.bus_right.write_calibration(cal_right)
|
||||
self.bus_left.write_calibration(cal_left)
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
self._calibrate_arm("right", self.bus_right)
|
||||
self._calibrate_arm("left", self.bus_left)
|
||||
self.bus.disable_torque()
|
||||
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
logger.info("Setting Phase to 12 for all motors...")
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Phase", motor, 12)
|
||||
|
||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||
"""Calibrate a single arm with Feetech motors."""
|
||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||
|
||||
bus.disable_torque()
|
||||
|
||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||
for motor in bus.motors:
|
||||
bus.write("Phase", motor, 12)
|
||||
|
||||
for motor in bus.motors:
|
||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input(
|
||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||
"\nCalibration: Zero Position\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
homing_offsets = bus.set_half_turn_homings()
|
||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||
homing_offsets = self.bus.set_half_turn_homings()
|
||||
logger.info("Arm zero position set.")
|
||||
|
||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||
print("\nSetting motor ranges\n")
|
||||
|
||||
if self.calibration is None:
|
||||
self.calibration = {}
|
||||
|
||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
|
||||
max_res = motor_resolution - 1
|
||||
|
||||
for motor_name, motor in bus.motors.items():
|
||||
prefixed_name = f"{arm_name}_{motor_name}"
|
||||
|
||||
for motor_name, motor in self.bus.motors.items():
|
||||
if motor_name == "gripper":
|
||||
input(
|
||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||
f"Step 1: CLOSE the gripper fully\n"
|
||||
f"Press ENTER when gripper is closed..."
|
||||
"\nGripper Calibration\n"
|
||||
"Step 1: CLOSE the gripper fully\n"
|
||||
"Press ENTER when gripper is closed..."
|
||||
)
|
||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||
|
||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||
|
||||
if closed_pos < open_pos:
|
||||
@@ -228,16 +179,16 @@ class OpenArmMini(Teleoperator):
|
||||
drive_mode = 1
|
||||
|
||||
logger.info(
|
||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||
f" {motor_name}: range set to [{range_min}, {range_max}] "
|
||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||
)
|
||||
else:
|
||||
range_min = 0
|
||||
range_max = max_res
|
||||
drive_mode = 0
|
||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
|
||||
|
||||
self.calibration[prefixed_name] = MotorCalibration(
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_mode,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
@@ -245,108 +196,68 @@ class OpenArmMini(Teleoperator):
|
||||
range_max=range_max,
|
||||
)
|
||||
|
||||
cal_for_bus = {
|
||||
k.replace(f"{arm_name}_", ""): v
|
||||
for k, v in self.calibration.items()
|
||||
if k.startswith(f"{arm_name}_")
|
||||
}
|
||||
bus.write_calibration(cal_for_bus)
|
||||
self.bus.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_right.configure_motors()
|
||||
for motor in self.bus_right.motors:
|
||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
self.bus_left.disable_torque()
|
||||
self.bus_left.configure_motors()
|
||||
for motor in self.bus_left.motors:
|
||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
self.bus.disable_torque()
|
||||
self.bus.configure_motors()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
print("\nSetting up RIGHT arm motors...")
|
||||
for motor in reversed(self.bus_right.motors):
|
||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||
self.bus_right.setup_motor(motor)
|
||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||
|
||||
print("\nSetting up LEFT arm motors...")
|
||||
for motor in reversed(self.bus_left.motors):
|
||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||
self.bus_left.setup_motor(motor)
|
||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""Get current action from both arms (read positions from all motors)."""
|
||||
"""Get current action (read positions from all motors)."""
|
||||
start = time.perf_counter()
|
||||
|
||||
right_positions = self.bus_right.sync_read("Present_Position")
|
||||
left_positions = self.bus_left.sync_read("Present_Position")
|
||||
positions = self.bus.sync_read("Present_Position")
|
||||
|
||||
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
|
||||
# Per-side direction flip is applied based on the configured `side`.
|
||||
action: dict[str, Any] = {}
|
||||
for motor, val in right_positions.items():
|
||||
for motor, val in positions.items():
|
||||
target = JOINT_REMAP.get(motor, motor)
|
||||
if motor == "gripper":
|
||||
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
|
||||
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||
for motor, val in left_positions.items():
|
||||
target = JOINT_REMAP.get(motor, motor)
|
||||
if motor == "gripper":
|
||||
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def enable_torque(self) -> None:
|
||||
"""Enable torque on both arms for position control."""
|
||||
self.bus_right.enable_torque()
|
||||
self.bus_left.enable_torque()
|
||||
self.bus.enable_torque()
|
||||
|
||||
def disable_torque(self) -> None:
|
||||
"""Disable torque on both arms for free movement."""
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_left.disable_torque()
|
||||
self.bus.disable_torque()
|
||||
|
||||
def write_goal_positions(self, positions: dict[str, float]) -> None:
|
||||
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
|
||||
right_goals: dict[str, float] = {}
|
||||
left_goals: dict[str, float] = {}
|
||||
|
||||
goals: dict[str, float] = {}
|
||||
for key, val in positions.items():
|
||||
if not key.endswith(".pos"):
|
||||
continue
|
||||
motor_name = key.removesuffix(".pos")
|
||||
if motor_name.startswith("right_"):
|
||||
base = motor_name.removeprefix("right_")
|
||||
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
|
||||
target = JOINT_REMAP_REVERSE.get(base, base)
|
||||
if base == "gripper":
|
||||
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
|
||||
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
# Un-flip using the ORIGINAL motor name (target = leader motor)
|
||||
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
|
||||
elif motor_name.startswith("left_"):
|
||||
base = motor_name.removeprefix("left_")
|
||||
target = JOINT_REMAP_REVERSE.get(base, base)
|
||||
if base == "gripper":
|
||||
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
|
||||
base = key.removesuffix(".pos")
|
||||
# JOINT_REMAP is symmetric (its own inverse).
|
||||
target = JOINT_REMAP.get(base, base)
|
||||
if base == "gripper":
|
||||
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
|
||||
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
# Un-flip using the ORIGINAL motor name (target = leader motor)
|
||||
goals[target] = -val if target in self._motors_to_flip else val
|
||||
|
||||
if right_goals:
|
||||
self.bus_right.sync_write("Goal_Position", right_goals)
|
||||
if left_goals:
|
||||
self.bus_left.sync_write("Goal_Position", left_goals)
|
||||
if goals:
|
||||
self.bus.sync_write("Goal_Position", goals)
|
||||
|
||||
@check_if_not_connected
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
@@ -354,6 +265,5 @@ class OpenArmMini(Teleoperator):
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.bus_right.disconnect()
|
||||
self.bus_left.disconnect()
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -99,14 +99,18 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
return OpenArmMini(config)
|
||||
elif config.type == "bi_openarm_mini":
|
||||
from .bi_openarm_mini import BiOpenArmMini
|
||||
|
||||
return BiOpenArmMini(config)
|
||||
elif config.type == "rebot_102_leader":
|
||||
from .rebot_102_leader import RebotArm102Leader
|
||||
|
||||
return RebotArm102Leader(config)
|
||||
elif config.type == "bi_rebot_102_leader":
|
||||
from .bi_rebot_102_leader import BiRebotArm102Leader
|
||||
from .bi_rebot_102_leader import BiRebot102Leader
|
||||
|
||||
return BiRebotArm102Leader(config)
|
||||
return BiRebot102Leader(config)
|
||||
else:
|
||||
try:
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
|
||||
class BimanualMixin:
|
||||
"""Lifecycle delegation for bimanual robots and teleoperators.
|
||||
|
||||
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
|
||||
their own ``__init__``. They retain ownership of feature dicts and the
|
||||
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
|
||||
``send_feedback``), which vary per-embodiment.
|
||||
|
||||
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
|
||||
take precedence in the MRO::
|
||||
|
||||
class BiFooFollower(BimanualMixin, Robot): ...
|
||||
"""
|
||||
|
||||
left_arm: Any
|
||||
right_arm: Any
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -51,7 +51,9 @@ def hw_to_dataset_features(
|
||||
|
||||
This function takes a dictionary describing hardware outputs (like joint states
|
||||
or camera image shapes) and formats it into the standard LeRobot feature
|
||||
specification.
|
||||
specification. Single-channel cameras (shape ``(H, W, 1)``) are flagged as depth
|
||||
maps via ``info["is_depth_map"] = True``; three-channel cameras ``(H, W, 3)`` are
|
||||
treated as RGB.
|
||||
|
||||
Args:
|
||||
hw_features (dict): Dictionary mapping feature names to their type (float for
|
||||
@@ -61,7 +63,7 @@ def hw_to_dataset_features(
|
||||
use_video (bool): If True, image features are marked as "video", otherwise "image".
|
||||
|
||||
Returns:
|
||||
dict: A LeRobot features dictionary.
|
||||
dict: A LeRobot features dictionary. Depth cameras carry ``info["is_depth_map"] = True``.
|
||||
"""
|
||||
features = {}
|
||||
joint_fts = {
|
||||
@@ -69,6 +71,7 @@ def hw_to_dataset_features(
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
# TODO(CarolinePascal): we should not rely on the shape to determine if a feature is a camera !
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == ACTION:
|
||||
@@ -86,11 +89,19 @@ def hw_to_dataset_features(
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
dtype = "video" if use_video else "image"
|
||||
if len(shape) == 3 and shape[2] in (1, 3):
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": dtype,
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"is_depth_map": shape[2] == 1},
|
||||
}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Camera feature '{key}' has shape {shape}. "
|
||||
f"Expected a 3-tuple (H, W, C), e.g. (480, 640, 3) for RGB or (480, 640, 1) for depth."
|
||||
)
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
@@ -149,11 +160,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
else:
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
|
||||
@@ -107,8 +107,15 @@ def log_rerun_data(
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity, static=True)
|
||||
if arr.shape[-1] == 1:
|
||||
img_entity = (
|
||||
rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis).compress()
|
||||
if compress_images
|
||||
else rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
|
||||
)
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity)
|
||||
|
||||
if action:
|
||||
for k, v in action.items():
|
||||
|
||||
@@ -208,14 +208,14 @@ def test_episode_clip_path_trims_via_reencode_video(tmp_path: Path, monkeypatch)
|
||||
def fake_reencode(
|
||||
input_video_path,
|
||||
output_video_path,
|
||||
camera_encoder=None,
|
||||
video_encoder=None,
|
||||
overwrite=False,
|
||||
start_time_s=None,
|
||||
end_time_s=None,
|
||||
):
|
||||
captured.update(
|
||||
src=Path(input_video_path),
|
||||
encoder=camera_encoder,
|
||||
encoder=video_encoder,
|
||||
start_time_s=start_time_s,
|
||||
end_time_s=end_time_s,
|
||||
)
|
||||
|
||||
@@ -29,7 +29,10 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.feature_utils import features_equal_for_merge
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import (
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||
@@ -191,6 +194,26 @@ def assert_dataset_iteration_works(aggr_ds):
|
||||
pass
|
||||
|
||||
|
||||
def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
|
||||
"""Test that depth keys are correctly preserved after aggregation.
|
||||
|
||||
Ensures that the ``is_depth_map`` marker on visual features survives
|
||||
aggregation, so that downstream consumers (e.g. the dataset reader's
|
||||
depth decoding path) keep working on the merged dataset.
|
||||
"""
|
||||
expected_depth_keys = set(ds_0.meta.depth_keys)
|
||||
assert expected_depth_keys == set(ds_1.meta.depth_keys), (
|
||||
"Source datasets disagree on depth_keys; test setup is inconsistent"
|
||||
)
|
||||
actual_depth_keys = set(aggr_ds.meta.depth_keys)
|
||||
assert actual_depth_keys == expected_depth_keys, (
|
||||
f"Expected depth_keys {expected_depth_keys}, got {actual_depth_keys}"
|
||||
)
|
||||
for key in expected_depth_keys:
|
||||
info = aggr_ds.meta.info.features[key].get("info") or {}
|
||||
assert info.get("is_depth_map") is True, f"Depth marker lost on feature {key!r} after aggregation"
|
||||
|
||||
|
||||
def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
"""Test that all video timestamps are within valid bounds for their respective video files.
|
||||
|
||||
@@ -240,7 +263,11 @@ def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test basic aggregation functionality with standard parameters."""
|
||||
"""Test basic aggregation functionality with standard parameters.
|
||||
|
||||
Source datasets include both RGB and depth video features so the same
|
||||
aggregation flow is exercised on the ``is_depth_map`` branch.
|
||||
"""
|
||||
ds_0_num_frames = 400
|
||||
ds_1_num_frames = 800
|
||||
ds_0_num_episodes = 10
|
||||
@@ -252,14 +279,21 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
# Confirm depth was actually wired into the source datasets so the
|
||||
# rest of the assertions exercise the depth aggregation path.
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
@@ -286,6 +320,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
@@ -403,7 +438,11 @@ def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
|
||||
|
||||
|
||||
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding."""
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding.
|
||||
|
||||
Depth video features are included to verify that file rotation/concat
|
||||
correctly handles depth-marked features alongside regular RGB ones.
|
||||
"""
|
||||
ds_0_num_episodes = ds_1_num_episodes = 10
|
||||
ds_0_num_frames = ds_1_num_frames = 400
|
||||
|
||||
@@ -412,14 +451,19 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_small_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "small_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_small_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
# Use the new configurable parameters to force file rotation
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
@@ -450,6 +494,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
# Check that multiple files were actually created due to small size limits
|
||||
@@ -469,7 +514,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
"""Regression test for video timestamp bug when merging datasets.
|
||||
|
||||
This test specifically checks that video timestamps are correctly calculated
|
||||
and accumulated when merging multiple datasets.
|
||||
and accumulated when merging multiple datasets. Depth video features are
|
||||
included so depth timestamps are also covered by the regression.
|
||||
"""
|
||||
datasets = []
|
||||
for i in range(3):
|
||||
@@ -478,9 +524,13 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
|
||||
total_episodes=2,
|
||||
total_frames=100,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
datasets.append(ds)
|
||||
|
||||
for i, ds in enumerate(datasets):
|
||||
assert len(ds.meta.depth_keys) > 0, f"Dataset {i} should expose at least one depth key"
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds.repo_id for ds in datasets],
|
||||
roots=[ds.root for ds in datasets],
|
||||
@@ -497,12 +547,21 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
|
||||
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
# Depth keys must survive the merge for the regression to cover the
|
||||
# ``is_depth_map`` decoding branch.
|
||||
assert set(aggr_ds.meta.depth_keys) == set(datasets[0].meta.depth_keys)
|
||||
|
||||
depth_keys = set(aggr_ds.meta.depth_keys)
|
||||
for i in range(len(aggr_ds)):
|
||||
item = aggr_ds[i]
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert key in item, f"Video key {key} missing from item {i}"
|
||||
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||
# Depth frames are single-channel (1, H, W) after dequantization;
|
||||
# standard RGB frames keep the 3-channel layout.
|
||||
expected_channels = 1 if key in depth_keys else 3
|
||||
assert item[key].shape[0] == expected_channels, (
|
||||
f"Expected {expected_channels} channels for video key {key}, got {item[key].shape}"
|
||||
)
|
||||
|
||||
|
||||
def assert_image_schema_preserved(aggr_ds):
|
||||
@@ -584,25 +643,31 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
ds_0_num_episodes = 2
|
||||
ds_1_num_episodes = 3
|
||||
|
||||
# Create two image-based datasets (use_videos=False)
|
||||
# Create two image-based datasets (use_videos=False) with a mix of RGB
|
||||
# and depth-marked cameras so the depth path is exercised in image mode.
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
use_videos=False,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
use_videos=False,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
# Verify source datasets have image keys
|
||||
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
|
||||
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
|
||||
# And that the depth marker actually made it onto an image feature.
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
# Aggregate the datasets
|
||||
aggregate_datasets(
|
||||
@@ -637,6 +702,7 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
# Image-specific assertions
|
||||
assert_image_schema_preserved(aggr_ds)
|
||||
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Verify images can be accessed and have correct shape
|
||||
sample_item = aggr_ds[0]
|
||||
|
||||
@@ -59,11 +59,13 @@ def _make_dummy_stats(features: dict) -> dict:
|
||||
stats = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] in ("image", "video"):
|
||||
channels = ft["shape"][-1]
|
||||
stat_shape = (channels, 1, 1)
|
||||
stats[key] = {
|
||||
"max": np.ones((3, 1, 1), dtype=np.float32),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
|
||||
"min": np.zeros((3, 1, 1), dtype=np.float32),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
|
||||
"max": np.ones(stat_shape, dtype=np.float32),
|
||||
"mean": np.full(stat_shape, 0.5, dtype=np.float32),
|
||||
"min": np.zeros(stat_shape, dtype=np.float32),
|
||||
"std": np.full(stat_shape, 0.25, dtype=np.float32),
|
||||
"count": np.array([5]),
|
||||
}
|
||||
elif ft["dtype"] in ("float32", "float64", "int64"):
|
||||
@@ -142,6 +144,45 @@ def test_create_without_videos_has_no_video_path(tmp_path):
|
||||
assert meta.video_keys == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("marker_field", "marker_key"),
|
||||
[
|
||||
("info", "is_depth_map"),
|
||||
("info", "video.is_depth_map"),
|
||||
("video_info", "video.is_depth_map"),
|
||||
],
|
||||
ids=["info.is_depth_map", "info.video.is_depth_map_legacy", "video_info.video.is_depth_map_legacy"],
|
||||
)
|
||||
def test_depth_keys_property_filters_by_marker(tmp_path, marker_field, marker_key):
|
||||
"""``depth_keys`` recognises the canonical and the two legacy marker variants."""
|
||||
depth_feature = {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
marker_field: {marker_key: True},
|
||||
}
|
||||
features = {
|
||||
**VIDEO_FEATURES,
|
||||
"observation.images.laptop_depth": depth_feature,
|
||||
}
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/depth_keys",
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / f"depth_keys_{marker_field}_{marker_key.replace('.', '_')}",
|
||||
)
|
||||
|
||||
assert set(meta.video_keys) == {"observation.images.laptop", "observation.images.laptop_depth"}
|
||||
assert meta.depth_keys == ["observation.images.laptop_depth"]
|
||||
|
||||
|
||||
def test_depth_keys_empty_when_no_marker(tmp_path):
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/no_depth", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=tmp_path / "no_depth"
|
||||
)
|
||||
assert meta.depth_keys == []
|
||||
|
||||
|
||||
def test_create_raises_on_existing_directory(tmp_path):
|
||||
"""create() raises if root directory already exists."""
|
||||
root = tmp_path / "existing"
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_features,
|
||||
convert_image_to_video_dataset,
|
||||
@@ -37,7 +37,9 @@ from lerobot.datasets.dataset_tools import (
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.io_utils import load_info
|
||||
from tests.datasets.test_video_encoding import _add_frames, require_h264, require_libsvtav1
|
||||
from tests.datasets.test_video_encoding import require_h264, require_hevc, require_libsvtav1
|
||||
from tests.fixtures.constants import DUMMY_DEPTH_FEATURES, DUMMY_DEPTH_KEY
|
||||
from tests.fixtures.dataset_factories import add_frames
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1332,9 +1334,131 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
|
||||
@require_libsvtav1
|
||||
@require_hevc
|
||||
def test_convert_image_to_video_dataset_depth(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Depth image features convert to depth videos using the depth encoder.
|
||||
|
||||
Mirrors :func:`test_convert_image_to_video_dataset` but with a small local
|
||||
image dataset that mixes an RGB camera with a depth camera, so the
|
||||
``depth_keys`` → ``depth_encoder`` routing and ``is_depth_map`` preservation
|
||||
are exercised end-to-end.
|
||||
"""
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
"observation.images.cam": {
|
||||
"dtype": "image",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"observation.images.depth": {
|
||||
"dtype": "image",
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"is_depth_map": True},
|
||||
},
|
||||
}
|
||||
source_dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "img_ds",
|
||||
features=features,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
add_frames(source_dataset, num_frames=4)
|
||||
source_dataset.save_episode()
|
||||
source_dataset.finalize()
|
||||
|
||||
# Source is an image dataset with the depth marker on the depth camera.
|
||||
assert len(source_dataset.meta.video_keys) == 0
|
||||
assert "observation.images.depth" in source_dataset.meta.depth_keys
|
||||
|
||||
output_dir = tmp_path / "video_ds"
|
||||
with (
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
# Use non-default quantization params so the persisted metadata must
|
||||
# come from the depth encoder (not RGB encoder defaults).
|
||||
depth_encoder = DepthEncoderConfig(
|
||||
vcodec="hevc",
|
||||
pix_fmt="gray12le",
|
||||
g=2,
|
||||
crf=30,
|
||||
depth_min=0.05,
|
||||
depth_max=8.0,
|
||||
shift=2.0,
|
||||
use_log=False,
|
||||
)
|
||||
video_dataset = convert_image_to_video_dataset(
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="dummy/depth_video",
|
||||
camera_encoder=VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30),
|
||||
depth_encoder=depth_encoder,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
# Both cameras are now videos, and the depth marker survived the conversion.
|
||||
assert "observation.images.cam" in video_dataset.meta.video_keys
|
||||
assert "observation.images.depth" in video_dataset.meta.video_keys
|
||||
assert "observation.images.depth" in video_dataset.meta.depth_keys
|
||||
assert "observation.images.cam" not in video_dataset.meta.depth_keys
|
||||
|
||||
depth_path = video_dataset.root / video_dataset.meta.get_video_file_path(0, "observation.images.depth")
|
||||
assert depth_path.exists(), f"Depth video file should exist: {depth_path}"
|
||||
|
||||
# The persisted depth-video metadata must carry the depth quantization params
|
||||
# from the depth encoder (so frames dequantize correctly on read), and the RGB
|
||||
# camera must not be marked as a depth map.
|
||||
persisted_info = load_info(video_dataset.root)
|
||||
depth_info = persisted_info.features["observation.images.depth"]["info"]
|
||||
assert depth_info["is_depth_map"] is True
|
||||
assert DepthEncoderConfig.from_video_info(depth_info) == depth_encoder
|
||||
|
||||
cam_info = persisted_info.features["observation.images.cam"]["info"]
|
||||
assert cam_info.get("is_depth_map") is False
|
||||
assert "video.codec" in cam_info
|
||||
|
||||
|
||||
# ─── reencode_dataset ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@require_hevc
|
||||
def test_reencode_dataset_depth_uses_depth_encoder(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Depth videos are re-encoded with the depth encoder and keep their depth metadata.
|
||||
|
||||
Depth-focused companion to :func:`test_reencode_dataset_multi_key_multiprocessing`.
|
||||
"""
|
||||
initial_cfg = DepthEncoderConfig(vcodec="hevc", pix_fmt="gray12le", g=2, crf=30)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds",
|
||||
features=DUMMY_DEPTH_FEATURES,
|
||||
use_videos=True,
|
||||
depth_encoder=initial_cfg,
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert DUMMY_DEPTH_KEY in dataset.meta.depth_keys
|
||||
|
||||
target_cfg = DepthEncoderConfig(vcodec="hevc", pix_fmt="gray12le", g=6, crf=23)
|
||||
result = reencode_dataset(dataset, depth_encoder=target_cfg, num_workers=0)
|
||||
|
||||
assert result is dataset
|
||||
|
||||
persisted_info = load_info(dataset.root)
|
||||
depth_info = persisted_info.features[DUMMY_DEPTH_KEY].get("info", {})
|
||||
# Re-encode applied the new codec parameters to the depth video ...
|
||||
assert DepthEncoderConfig.from_video_info(depth_info) == target_cfg
|
||||
# ... while preserving the depth marker.
|
||||
assert depth_info["is_depth_map"] is True
|
||||
|
||||
|
||||
@require_libsvtav1
|
||||
@require_h264
|
||||
def test_reencode_dataset_multi_key_multiprocessing(
|
||||
@@ -1350,9 +1474,9 @@ def test_reencode_dataset_multi_key_multiprocessing(
|
||||
camera_encoder=initial_cfg,
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
@@ -53,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
|
||||
# ── Existing encode_video_worker tests ───────────────────────────────
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_camera_encoder(tmp_path):
|
||||
"""_encode_video_worker forwards camera_encoder to encode_video_frames."""
|
||||
def test_encode_video_worker_forwards_video_encoder(tmp_path):
|
||||
"""_encode_video_worker forwards video_encoder to encode_video_frames."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -74,16 +74,16 @@ def test_encode_video_worker_forwards_camera_encoder(tmp_path):
|
||||
0,
|
||||
tmp_path,
|
||||
fps=30,
|
||||
camera_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||
video_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||
encoder_threads=4,
|
||||
)
|
||||
|
||||
assert captured_kwargs["camera_encoder"].vcodec == "h264"
|
||||
assert captured_kwargs["video_encoder"].vcodec == "h264"
|
||||
assert captured_kwargs["encoder_threads"] == 4
|
||||
|
||||
|
||||
def test_encode_video_worker_default_camera_encoder(tmp_path):
|
||||
"""_encode_video_worker passes None camera_encoder which encode_video_frames defaults."""
|
||||
def test_encode_video_worker_default_video_encoder(tmp_path):
|
||||
"""_encode_video_worker passes None video_encoder which encode_video_frames defaults."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -100,7 +100,7 @@ def test_encode_video_worker_default_camera_encoder(tmp_path):
|
||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||
_encode_video_worker(video_key, 0, tmp_path, fps=30)
|
||||
|
||||
assert captured_kwargs["camera_encoder"] is None
|
||||
assert captured_kwargs["video_encoder"] is None
|
||||
assert captured_kwargs["encoder_threads"] is None
|
||||
|
||||
|
||||
|
||||
@@ -1516,10 +1516,15 @@ def test_valid_video_codecs_constant():
|
||||
assert "h264" in VALID_VIDEO_CODECS
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert "ffv1" in VALID_VIDEO_CODECS
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 10
|
||||
assert "h264_vaapi" in VALID_VIDEO_CODECS
|
||||
assert "h264_qsv" in VALID_VIDEO_CODECS
|
||||
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "hevc_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 11
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
"""Tests for the depth-integration feature.
|
||||
|
||||
Covers:
|
||||
- ``depth_utils`` quantize/dequantize round-trips and backend agreement.
|
||||
- Image-writer support for single-channel depth.
|
||||
- Hardware-feature → depth flag routing.
|
||||
- Feature-to-file-format routing through the dataset writer.
|
||||
|
||||
Depth metadata detection on ``LeRobotDatasetMetadata.depth_keys`` lives in
|
||||
``test_dataset_metadata.py``. Depth video encoding/decoding lives in
|
||||
``test_video_encoding.py``.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.configs import DepthEncoderConfig
|
||||
from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX
|
||||
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
DUMMY_CHW,
|
||||
DUMMY_DEPTH_CAMERA_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
from tests.fixtures.dataset_factories import add_frames
|
||||
|
||||
_, H, W = DUMMY_CHW
|
||||
|
||||
|
||||
def _depth_metres_ramp() -> np.ndarray:
|
||||
"""Linearly-spaced float32 depth in metres covering the default range."""
|
||||
return np.linspace(DEFAULT_DEPTH_MIN, DEFAULT_DEPTH_MAX, H * W, dtype=np.float32).reshape(H, W)
|
||||
|
||||
|
||||
# ── 1. Quantize / dequantize round-trips ──────────────────────────────
|
||||
|
||||
|
||||
class TestQuantizeDequantize:
|
||||
"""Numerical contract of ``quantize_depth`` / ``dequantize_depth``."""
|
||||
|
||||
@pytest.mark.parametrize("use_log", [False, True])
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
@pytest.mark.parametrize("output_channel_last", [False, True])
|
||||
def test_roundtrip(self, use_log, output_unit, output_channel_last):
|
||||
"""quantize → dequantize recovers depth; layout and unit are honored."""
|
||||
depth = _depth_metres_ramp()
|
||||
quantized = quantize_depth(depth, use_log=use_log, video_backend=None)
|
||||
recovered = dequantize_depth(
|
||||
quantized,
|
||||
use_log=use_log,
|
||||
output_unit=output_unit,
|
||||
output_tensor=False,
|
||||
output_channel_last=output_channel_last,
|
||||
)
|
||||
|
||||
expected_shape = (H, W, 1) if output_channel_last else (1, H, W)
|
||||
assert recovered.shape == expected_shape
|
||||
|
||||
recovered_m = recovered.astype(np.float32)
|
||||
if output_unit == "mm":
|
||||
recovered_m = recovered_m / 1000.0
|
||||
recovered_2d = recovered_m[..., 0] if output_channel_last else recovered_m[0]
|
||||
|
||||
if use_log:
|
||||
# Log mode: tighter near-range error than far-range (the whole point).
|
||||
near = depth < 1.0
|
||||
far = depth > 8.0
|
||||
err_near = np.abs(recovered_2d[near] - depth[near])
|
||||
err_far = np.abs(recovered_2d[far] - depth[far])
|
||||
assert err_near.mean() < err_far.mean()
|
||||
else:
|
||||
# Linear mode: bounded by quant step + 1 mm of unit-conversion rounding.
|
||||
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
|
||||
np.testing.assert_allclose(recovered_2d, depth, atol=tol)
|
||||
|
||||
@pytest.mark.parametrize("use_log", [False, True])
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
def test_numpy_torch_agree(self, use_log, output_unit):
|
||||
"""Batched torch path produces the same values as the numpy path."""
|
||||
batch_size = 3
|
||||
per_frame = np.linspace(0, DEPTH_QMAX, H * W, dtype=np.uint16).reshape(H, W)
|
||||
batch_np = np.broadcast_to(per_frame[None, None, ...], (batch_size, 1, H, W)).copy()
|
||||
batch_t = torch.from_numpy(batch_np.astype(np.int32)) # torch.uint16 support is patchy.
|
||||
|
||||
ref = dequantize_depth(batch_np, use_log=use_log, output_unit=output_unit, output_tensor=False)
|
||||
out = dequantize_depth(batch_t, use_log=use_log, output_unit=output_unit, output_tensor=True)
|
||||
|
||||
assert isinstance(out, torch.Tensor)
|
||||
assert out.shape == (batch_size, 1, H, W)
|
||||
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
|
||||
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
|
||||
atol = 1e-5 if output_unit == "m" else 1.0
|
||||
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,output_shape",
|
||||
[
|
||||
((H, W), (1, H, W)),
|
||||
((1, H, W), (1, H, W)),
|
||||
((H, W, 1), (1, H, W)),
|
||||
((3, 1, H, W), (3, 1, H, W)),
|
||||
((3, H, W, 1), (3, 1, H, W)),
|
||||
],
|
||||
)
|
||||
def test_input_layouts_accepted(self, input_shape, output_shape):
|
||||
"""All documented input layouts decode to the channel-first default."""
|
||||
quantized = np.full(input_shape, DEPTH_QMAX // 2, dtype=np.uint16)
|
||||
out = dequantize_depth(quantized, output_unit="m", output_tensor=False)
|
||||
assert out.shape == output_shape
|
||||
|
||||
def test_pyav_frame_roundtrip(self):
|
||||
"""quantize → av.VideoFrame → dequantize works."""
|
||||
depth = _depth_metres_ramp()
|
||||
frame = quantize_depth(depth, use_log=False, video_backend="pyav")
|
||||
assert isinstance(frame, av.VideoFrame)
|
||||
|
||||
recovered = dequantize_depth(frame, use_log=False, output_unit="m", output_tensor=False)
|
||||
assert recovered.shape == (1, H, W)
|
||||
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
|
||||
np.testing.assert_allclose(recovered[0], depth, atol=tol)
|
||||
|
||||
def test_invalid_log_params_raises(self):
|
||||
with pytest.raises(ValueError, match=r"depth_min \+ shift must be positive"):
|
||||
quantize_depth(_depth_metres_ramp(), depth_min=1.0, shift=-2.0, use_log=True, video_backend=None)
|
||||
|
||||
|
||||
# ── 2. Image writer depth support ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestImageWriterDepth:
|
||||
"""``image_array_to_pil_image`` and ``write_image`` for depth maps."""
|
||||
|
||||
@pytest.mark.parametrize("dtype,expected_mode", [(np.uint16, "I;16"), (np.float32, "F")])
|
||||
@pytest.mark.parametrize("shape", [(H, W), (H, W, 1), (1, H, W)])
|
||||
def test_pil_depth_modes_and_squeeze(self, dtype, expected_mode, shape):
|
||||
"""Single-channel depth converts to PIL with the right mode and (W, H) size."""
|
||||
arr = np.zeros(shape, dtype=dtype)
|
||||
img = image_array_to_pil_image(arr)
|
||||
assert img.mode == expected_mode
|
||||
assert img.size == (W, H)
|
||||
|
||||
def test_write_image_tiff_roundtrip(self, tmp_path):
|
||||
"""uint16 depth round-trips through .tiff."""
|
||||
arr = np.arange(H * W, dtype=np.uint16).reshape(H, W)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(arr, fpath)
|
||||
with PIL.Image.open(fpath) as loaded:
|
||||
recovered = np.array(loaded)
|
||||
np.testing.assert_array_equal(recovered, arr)
|
||||
|
||||
|
||||
# ── 3. Hardware-feature → depth flag ──────────────────────────────────
|
||||
|
||||
|
||||
class TestHwToDatasetFeaturesDepth:
|
||||
"""``hw_to_dataset_features`` flags single-channel cameras as depth."""
|
||||
|
||||
@pytest.mark.parametrize("channels,is_depth", [(1, True), (3, False)])
|
||||
def test_depth_marker_by_channels(self, channels, is_depth):
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
|
||||
features = hw_to_dataset_features({"cam": (480, 640, channels)}, prefix="observation")
|
||||
assert features["observation.images.cam"]["info"]["is_depth_map"] is is_depth
|
||||
|
||||
def test_invalid_channel_count_raises(self):
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
|
||||
with pytest.raises(ValueError, match="Expected a 3-tuple"):
|
||||
hw_to_dataset_features({"cam": (480, 640, 2)}, prefix="observation")
|
||||
|
||||
|
||||
# ── 4. Feature-to-file-format routing ────────────────────────────────
|
||||
|
||||
|
||||
# Keys derived from DUMMY_CAMERA_FEATURES_WITH_DEPTH; pick one RGB and the depth camera.
|
||||
RGB_KEY = next(iter(DUMMY_CAMERA_FEATURES))
|
||||
DEPTH_KEY = next(iter(DUMMY_DEPTH_CAMERA_FEATURES))
|
||||
|
||||
|
||||
class TestFeatureFileRouting:
|
||||
"""Depth vs RGB features route to the correct file format."""
|
||||
|
||||
NUM_FRAMES = 5
|
||||
|
||||
def test_image_mode_depth_tiff_rgb_png(self, tmp_path, features_factory):
|
||||
"""Without video encoding: depth → .tiff, RGB → .png."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=False)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / "ds",
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=self.NUM_FRAMES)
|
||||
|
||||
buf = dataset.writer.episode_buffer
|
||||
assert all(Path(p).suffix == ".tiff" for p in buf[DEPTH_KEY])
|
||||
assert all(Path(p).suffix == ".png" for p in buf[RGB_KEY])
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
def test_video_mode_depth_uses_depth_encoder(self, tmp_path, features_factory):
|
||||
"""With streaming video encoding: depth → DepthEncoderConfig, RGB does not."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=True)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / "ds",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=self.NUM_FRAMES)
|
||||
|
||||
encoder = dataset.writer._streaming_encoder
|
||||
assert encoder is not None
|
||||
assert isinstance(encoder._threads[DEPTH_KEY].video_encoder, DepthEncoderConfig)
|
||||
assert not isinstance(encoder._threads[RGB_KEY].video_encoder, DepthEncoderConfig)
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
@@ -94,7 +94,7 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory):
|
||||
|
||||
def test_image_array_to_pil_image_single_channel(img_array_factory):
|
||||
img_array = img_array_factory(channels=1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
with pytest.raises(ValueError, match="Unsupported single-channel image dtype"):
|
||||
image_array_to_pil_image(img_array)
|
||||
|
||||
|
||||
|
||||
@@ -61,9 +61,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -112,9 +110,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -146,9 +142,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -391,7 +385,8 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
# Verify codec options include thread tuning for libsvtav1 (lp=…)
|
||||
thread = encoder._threads[f"{OBS_IMAGES}.cam"]
|
||||
assert "svtav1-params" in thread.codec_options or "threads" in thread.codec_options
|
||||
codec_opts = thread.video_encoder.get_codec_options(encoder_threads=thread.encoder_threads)
|
||||
assert "svtav1-params" in codec_opts or "threads" in codec_opts
|
||||
|
||||
# Feed some frames and finish to ensure it works end-to-end
|
||||
num_frames = 10
|
||||
|
||||
@@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig
|
||||
from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.datasets.image_writer import write_image
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pyav_utils import get_codec
|
||||
@@ -37,7 +37,15 @@ from lerobot.datasets.video_utils import (
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
from tests.fixtures.constants import DUMMY_VIDEO_INFO
|
||||
from tests.fixtures.constants import (
|
||||
DUMMY_DEPTH_FEATURES,
|
||||
DUMMY_DEPTH_KEY,
|
||||
DUMMY_DEPTH_VIDEO_INFO_FULL,
|
||||
DUMMY_VIDEO_FEATURES,
|
||||
DUMMY_VIDEO_INFO,
|
||||
DUMMY_VIDEO_KEY,
|
||||
)
|
||||
from tests.fixtures.dataset_factories import add_frames
|
||||
|
||||
|
||||
# Per-codec skip markers — validation tests only fire when the codec is available
|
||||
@@ -48,12 +56,67 @@ def _require_encoder(vcodec: str) -> pytest.MarkDecorator:
|
||||
|
||||
require_libsvtav1 = _require_encoder("libsvtav1")
|
||||
require_h264 = _require_encoder("h264")
|
||||
require_hevc = _require_encoder("hevc")
|
||||
require_videotoolbox = _require_encoder("h264_videotoolbox")
|
||||
require_nvenc = _require_encoder("h264_nvenc")
|
||||
require_vaapi = _require_encoder("h264_vaapi")
|
||||
require_qsv = _require_encoder("h264_qsv")
|
||||
|
||||
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "encoded_videos"
|
||||
|
||||
|
||||
def _write_color_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i in range(num_frames):
|
||||
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
write_image(arr, imgs_dir / f"frame-{i:06d}.png")
|
||||
|
||||
|
||||
def _write_depth_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
"""Write synthetic uint16 depth TIFFs (millimetres) for depth encoder tests.
|
||||
|
||||
Uses a smooth linear ramp + per-frame offset (not white noise) so HEVC Main 12
|
||||
on ``gray12le`` compresses well. Values span ~100 mm to 10 m, covering most
|
||||
of the default ``[DEPTH_MIN, DEPTH_MAX]`` metres range after
|
||||
``quantize_depth(input_unit="auto"="mm")``.
|
||||
"""
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
base = np.linspace(100.0, 10_000.0, height * width, dtype=np.float32).reshape(height, width)
|
||||
for i in range(num_frames):
|
||||
arr = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
|
||||
write_image(arr, imgs_dir / f"frame-{i:06d}.tiff")
|
||||
|
||||
|
||||
def _encode_video(
|
||||
path: Path,
|
||||
num_frames: int = 4,
|
||||
fps: int = 30,
|
||||
cfg: VideoEncoderConfig | None = None,
|
||||
depth: bool = False,
|
||||
) -> Path:
|
||||
"""Write synthetic frames to a temp dir and encode them to ``path``.
|
||||
|
||||
``depth=False`` writes uint8 RGB PNG noise and encodes with ``cfg``
|
||||
(defaulting to the library default). ``depth=True`` writes synthetic uint16
|
||||
depth TIFFs and encodes with ``cfg`` or a default :class:`DepthEncoderConfig`
|
||||
(HEVC Main 12 / ``gray12le``).
|
||||
"""
|
||||
imgs_dir = path.parent / f"imgs_{path.stem}"
|
||||
if depth:
|
||||
_write_depth_frames(imgs_dir, num_frames=num_frames)
|
||||
cfg = cfg or DepthEncoderConfig()
|
||||
else:
|
||||
_write_color_frames(imgs_dir, num_frames=num_frames)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, video_encoder=cfg, overwrite=True)
|
||||
return path
|
||||
|
||||
|
||||
def _read_feature_info(dataset: LeRobotDataset, key: str = DUMMY_VIDEO_KEY) -> dict:
|
||||
info = json.loads((dataset.root / INFO_PATH).read_text())
|
||||
return info["features"][key]["info"]
|
||||
|
||||
|
||||
# ─── VideoEncoderConfig / codec options ──────────────────────────────
|
||||
|
||||
|
||||
@@ -87,7 +150,7 @@ class TestCodecOptions:
|
||||
assert opts["q:v"] == 40
|
||||
assert "crf" not in opts
|
||||
|
||||
@_require_encoder("h264_nvenc")
|
||||
@require_nvenc
|
||||
def test_nvenc_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_nvenc", g=2, crf=25, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
@@ -96,12 +159,12 @@ class TestCodecOptions:
|
||||
assert "crf" not in opts
|
||||
assert opts["g"] == 2
|
||||
|
||||
@_require_encoder("h264_vaapi")
|
||||
@require_vaapi
|
||||
def test_vaapi_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_vaapi", crf=28, preset=None)
|
||||
assert cfg.get_codec_options()["qp"] == 28
|
||||
|
||||
@_require_encoder("h264_qsv")
|
||||
@require_qsv
|
||||
def test_qsv_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_qsv", crf=25, preset=None)
|
||||
assert cfg.get_codec_options()["global_quality"] == 25
|
||||
@@ -313,59 +376,6 @@ class TestEncoderDetection:
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "encoded_videos"
|
||||
|
||||
# Default video feature set used by persistence tests.
|
||||
VIDEO_FEATURES = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
}
|
||||
VIDEO_KEY = "observation.images.cam"
|
||||
|
||||
|
||||
def _write_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i in range(num_frames):
|
||||
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
write_image(arr, imgs_dir / f"frame-{i:06d}.png")
|
||||
|
||||
|
||||
def _encode_video(
|
||||
path: Path, num_frames: int = 4, fps: int = 30, cfg: VideoEncoderConfig | None = None
|
||||
) -> Path:
|
||||
imgs_dir = path.parent / f"imgs_{path.stem}"
|
||||
_write_frames(imgs_dir, num_frames=num_frames)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, camera_encoder=cfg, overwrite=True)
|
||||
return path
|
||||
|
||||
|
||||
def _read_feature_info(dataset: LeRobotDataset) -> dict:
|
||||
info = json.loads((dataset.root / INFO_PATH).read_text())
|
||||
return info["features"][VIDEO_KEY]["info"]
|
||||
|
||||
|
||||
def _add_frames(dataset: LeRobotDataset, num_frames: int, video_keys: list[str] | None = None) -> None:
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
|
||||
if video_keys is None:
|
||||
video_keys = dataset.meta.video_keys
|
||||
for _ in range(num_frames):
|
||||
frame: dict = {"task": "test"}
|
||||
for key, ft in dataset.meta.features.items():
|
||||
if key in DEFAULT_FEATURES:
|
||||
continue
|
||||
shape = ft["shape"]
|
||||
if key in video_keys:
|
||||
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
|
||||
else:
|
||||
frame[key] = np.zeros(shape, dtype=np.float32)
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
class TestGetVideoInfo:
|
||||
def test_returns_all_stream_fields(self):
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4")
|
||||
@@ -375,7 +385,7 @@ class TestGetVideoInfo:
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.channels"] == 3
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
assert "video.g" not in info
|
||||
assert "video.crf" not in info
|
||||
@@ -385,7 +395,7 @@ class TestGetVideoInfo:
|
||||
def test_merges_encoder_config_as_video_prefixed_entries(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg)
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
|
||||
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
@@ -398,11 +408,16 @@ class TestGetVideoInfo:
|
||||
def test_stream_derived_keys_take_precedence_over_config(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg)
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
|
||||
|
||||
assert info["video.codec"] # populated from stream, not from config's vcodec
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
|
||||
def test_depth_encoder_config_sets_is_depth_map_true(self):
|
||||
"""A ``DepthEncoderConfig`` causes ``get_video_info`` to mark the stream as depth."""
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=DepthEncoderConfig())
|
||||
assert info["is_depth_map"] is True
|
||||
|
||||
|
||||
class TestEncodeVideoFrames:
|
||||
@require_libsvtav1
|
||||
@@ -434,7 +449,7 @@ class TestEncodeVideoFrames:
|
||||
|
||||
def test_overwrite_false_skips_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_frames(imgs_dir)
|
||||
_write_color_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
sentinel = b"pre-existing content"
|
||||
video_path.write_bytes(sentinel)
|
||||
@@ -446,7 +461,7 @@ class TestEncodeVideoFrames:
|
||||
@require_libsvtav1
|
||||
def test_overwrite_true_replaces_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_frames(imgs_dir)
|
||||
_write_color_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
video_path.write_bytes(b"stale content")
|
||||
|
||||
@@ -461,7 +476,7 @@ class TestEncodeVideoFrames:
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=4, crf=25, preset=10)
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg)
|
||||
|
||||
info = get_video_info(video_path, camera_encoder=cfg)
|
||||
info = get_video_info(video_path, video_encoder=cfg)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
@@ -470,7 +485,7 @@ class TestEncodeVideoFrames:
|
||||
assert info["video.codec"] == "av1"
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
# Encoder config
|
||||
assert info["video.g"] == 4
|
||||
@@ -488,14 +503,14 @@ class TestReencodeVideo:
|
||||
src = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
|
||||
out = tmp_path / "reencoded.mp4"
|
||||
cfg = VideoEncoderConfig(vcodec="h264", g=6, crf=23, pix_fmt="yuv444p")
|
||||
reencode_video(src, out, camera_encoder=cfg, overwrite=True)
|
||||
reencode_video(src, out, video_encoder=cfg, overwrite=True)
|
||||
|
||||
assert out.exists()
|
||||
with av.open(str(out)) as container:
|
||||
n_frames = sum(1 for _ in container.decode(video=0))
|
||||
assert n_frames == 4
|
||||
|
||||
info = get_video_info(out, camera_encoder=cfg)
|
||||
info = get_video_info(out, video_encoder=cfg)
|
||||
assert info["video.codec"] == "h264"
|
||||
assert info["video.pix_fmt"] == "yuv444p"
|
||||
assert info["video.height"] == 64
|
||||
@@ -509,7 +524,7 @@ class TestReencodeVideo:
|
||||
src = TEST_ARTIFACTS_DIR / "clip_6frames.mp4"
|
||||
out = tmp_path / "trim_window.mp4"
|
||||
cfg = VideoEncoderConfig(vcodec="h264")
|
||||
reencode_video(src, out, camera_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True)
|
||||
reencode_video(src, out, video_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True)
|
||||
|
||||
with av.open(str(out)) as container:
|
||||
frames = list(container.decode(video=0))
|
||||
@@ -580,10 +595,10 @@ class TestEncoderConfigPersistence:
|
||||
def test_first_episode_save_persists_encoder_config(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
|
||||
root=tmp_path / "ds", features=DUMMY_VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
@@ -603,14 +618,14 @@ class TestEncoderConfigPersistence:
|
||||
def test_second_episode_does_not_overwrite_encoder_fields(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
|
||||
root=tmp_path / "ds", features=DUMMY_VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
first_info = dict(_read_feature_info(dataset))
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
@@ -637,3 +652,217 @@ class TestFromVideoInfo:
|
||||
# ``{}`` placeholder (typical after a merge with disagreeing sources)
|
||||
# must not leak into the reconstructed config.
|
||||
assert cfg.extra_options == VideoEncoderConfig().extra_options
|
||||
|
||||
|
||||
# ─── Depth-specific encoding tests ────────────────────────────────────
|
||||
|
||||
|
||||
class TestEncodeDepthVideoFrames:
|
||||
"""Depth mirror of :class:`TestEncodeVideoFrames`.
|
||||
|
||||
Exercises ``encode_video_frames`` end-to-end through
|
||||
:class:`DepthEncoderConfig` (HEVC Main 12 / ``gray12le``) on synthetic
|
||||
uint16 depth TIFFs.
|
||||
"""
|
||||
|
||||
@require_hevc
|
||||
def test_produces_readable_file(self, tmp_path):
|
||||
video_path = _encode_video(tmp_path / "out.mp4", depth=True)
|
||||
|
||||
assert video_path.exists()
|
||||
info = get_video_info(video_path, video_encoder=DepthEncoderConfig())
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.codec"] == "hevc"
|
||||
assert info["video.pix_fmt"] == "gray12le"
|
||||
assert info["video.channels"] == 1
|
||||
assert info["is_depth_map"] is True
|
||||
|
||||
@require_hevc
|
||||
def test_frame_count_and_duration_match_input(self, tmp_path):
|
||||
num_frames = 10
|
||||
fps = 30
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=num_frames, fps=fps, depth=True)
|
||||
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
actual_frames = sum(1 for _ in container.decode(stream))
|
||||
duration = (
|
||||
float(stream.duration * stream.time_base)
|
||||
if stream.duration is not None
|
||||
else float(container.duration / av.time_base)
|
||||
)
|
||||
|
||||
assert actual_frames == num_frames
|
||||
assert abs(duration - num_frames / fps) < 0.1
|
||||
|
||||
def test_overwrite_false_skips_existing_file(self, tmp_path):
|
||||
"""Codec-agnostic: file-system semantics must hold even without an HEVC encoder."""
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_depth_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
sentinel = b"pre-existing depth content"
|
||||
video_path.write_bytes(sentinel)
|
||||
|
||||
encode_video_frames(imgs_dir, video_path, fps=30, video_encoder=DepthEncoderConfig(), overwrite=False)
|
||||
|
||||
assert video_path.read_bytes() == sentinel
|
||||
|
||||
@require_hevc
|
||||
def test_overwrite_true_replaces_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_depth_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
video_path.write_bytes(b"stale content")
|
||||
|
||||
encode_video_frames(imgs_dir, video_path, fps=30, video_encoder=DepthEncoderConfig(), overwrite=True)
|
||||
|
||||
info = get_video_info(video_path, video_encoder=DepthEncoderConfig())
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.pix_fmt"] == "gray12le"
|
||||
assert info["is_depth_map"] is True
|
||||
|
||||
@require_hevc
|
||||
def test_custom_encoder_config_fields_stored_in_info(self, tmp_path):
|
||||
"""All stream-derived and depth-encoder config fields are present after encoding."""
|
||||
cfg = DepthEncoderConfig(
|
||||
vcodec="hevc",
|
||||
pix_fmt="gray12le",
|
||||
g=4,
|
||||
crf=25,
|
||||
depth_min=0.05,
|
||||
depth_max=8.0,
|
||||
shift=2.5,
|
||||
use_log=False,
|
||||
)
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg, depth=True)
|
||||
|
||||
info = get_video_info(video_path, video_encoder=cfg)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.channels"] == 1
|
||||
assert info["video.codec"] == "hevc"
|
||||
assert info["video.pix_fmt"] == "gray12le"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["is_depth_map"] is True
|
||||
assert info["has_audio"] is False
|
||||
# Base encoder config
|
||||
assert info["video.g"] == 4
|
||||
assert info["video.crf"] == 25
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
# Depth-specific tuning
|
||||
assert info["video.depth_min"] == 0.05
|
||||
assert info["video.depth_max"] == 8.0
|
||||
assert info["video.shift"] == 2.5
|
||||
assert info["video.use_log"] is False
|
||||
|
||||
|
||||
class TestDepthEncoderConfigPersistence:
|
||||
"""Depth mirror of :class:`TestEncoderConfigPersistence`.
|
||||
|
||||
``DepthEncoderConfig`` must be stored as ``video.<field>`` entries
|
||||
(including the depth-specific ``depth_min`` / ``depth_max`` / ``shift`` /
|
||||
``use_log``) under ``info["features"][<depth_key>]["info"]`` when the
|
||||
first episode is saved.
|
||||
"""
|
||||
|
||||
@require_hevc
|
||||
def test_first_episode_save_persists_depth_encoder_config(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = DepthEncoderConfig(
|
||||
vcodec="hevc",
|
||||
pix_fmt="gray12le",
|
||||
g=2,
|
||||
crf=30,
|
||||
depth_min=0.05,
|
||||
depth_max=8.0,
|
||||
shift=2.5,
|
||||
use_log=False,
|
||||
)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=DUMMY_DEPTH_FEATURES, use_videos=True, depth_encoder=cfg
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
info = _read_feature_info(dataset, key=DUMMY_DEPTH_KEY)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.codec"] == "hevc"
|
||||
assert info["video.pix_fmt"] == "gray12le"
|
||||
assert info["is_depth_map"] is True
|
||||
# Base encoder config
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
# Depth-specific tuning
|
||||
assert info["video.depth_min"] == 0.05
|
||||
assert info["video.depth_max"] == 8.0
|
||||
assert info["video.shift"] == 2.5
|
||||
assert info["video.use_log"] is False
|
||||
|
||||
@require_hevc
|
||||
def test_second_episode_does_not_overwrite_depth_encoder_fields(
|
||||
self, tmp_path, empty_lerobot_dataset_factory
|
||||
):
|
||||
cfg = DepthEncoderConfig(
|
||||
vcodec="hevc",
|
||||
pix_fmt="gray12le",
|
||||
g=2,
|
||||
crf=30,
|
||||
depth_min=0.05,
|
||||
depth_max=8.0,
|
||||
shift=2.5,
|
||||
use_log=False,
|
||||
)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=DUMMY_DEPTH_FEATURES, use_videos=True, depth_encoder=cfg
|
||||
)
|
||||
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
first_info = dict(_read_feature_info(dataset, key=DUMMY_DEPTH_KEY))
|
||||
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert _read_feature_info(dataset, key=DUMMY_DEPTH_KEY) == first_info
|
||||
|
||||
|
||||
class TestDepthFromVideoInfo:
|
||||
"""``DepthEncoderConfig.from_video_info`` reconstructs a depth encoder
|
||||
config from the ``video.*`` keys persisted in a dataset's ``info.json``.
|
||||
|
||||
Depth mirror of :class:`TestFromVideoInfo`.
|
||||
"""
|
||||
|
||||
@require_hevc
|
||||
def test_reconstructs_from_dummy_depth_video_info(self):
|
||||
cfg = DepthEncoderConfig.from_video_info(DUMMY_DEPTH_VIDEO_INFO_FULL)
|
||||
|
||||
# No alias for ``"hevc"``; the canonical stream codec is reused as-is.
|
||||
assert cfg.vcodec == "hevc"
|
||||
assert cfg.pix_fmt == DUMMY_DEPTH_VIDEO_INFO_FULL["video.pix_fmt"]
|
||||
assert cfg.g == DUMMY_DEPTH_VIDEO_INFO_FULL["video.g"]
|
||||
assert cfg.crf == DUMMY_DEPTH_VIDEO_INFO_FULL["video.crf"]
|
||||
assert cfg.fast_decode == DUMMY_DEPTH_VIDEO_INFO_FULL["video.fast_decode"]
|
||||
assert cfg.video_backend == DUMMY_DEPTH_VIDEO_INFO_FULL["video.video_backend"]
|
||||
# ``{}`` placeholder (typical after a merge with disagreeing sources)
|
||||
# must not leak into the reconstructed config.
|
||||
assert cfg.extra_options == DepthEncoderConfig().extra_options
|
||||
# Depth-specific tuning round-trips through ``info.json``.
|
||||
assert cfg.depth_min == DUMMY_DEPTH_VIDEO_INFO_FULL["video.depth_min"]
|
||||
assert cfg.depth_max == DUMMY_DEPTH_VIDEO_INFO_FULL["video.depth_max"]
|
||||
assert cfg.shift == DUMMY_DEPTH_VIDEO_INFO_FULL["video.shift"]
|
||||
assert cfg.use_log == DUMMY_DEPTH_VIDEO_INFO_FULL["video.use_log"]
|
||||
|
||||
Vendored
+45
-1
@@ -39,12 +39,56 @@ DUMMY_VIDEO_INFO = {
|
||||
"video.crf": 30,
|
||||
"video.preset": 12,
|
||||
"video.fast_decode": 0,
|
||||
"video.is_depth_map": False,
|
||||
"is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
|
||||
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
|
||||
}
|
||||
DUMMY_DEPTH_VIDEO_INFO = {
|
||||
**DUMMY_VIDEO_INFO,
|
||||
"is_depth_map": True,
|
||||
}
|
||||
DUMMY_DEPTH_VIDEO_INFO_FULL = {
|
||||
**{k: v for k, v in DUMMY_VIDEO_INFO.items() if k != "video.preset"},
|
||||
"video.codec": "hevc",
|
||||
"video.pix_fmt": "gray12le",
|
||||
"is_depth_map": True,
|
||||
"video.depth_min": 0.05,
|
||||
"video.depth_max": 8.0,
|
||||
"video.shift": 2.5,
|
||||
"video.use_log": True,
|
||||
}
|
||||
DUMMY_DEPTH_CAMERA_FEATURES = {
|
||||
"laptop_depth": {
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": DUMMY_DEPTH_VIDEO_INFO,
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH = {**DUMMY_CAMERA_FEATURES, **DUMMY_DEPTH_CAMERA_FEATURES}
|
||||
DUMMY_CHW = (3, 96, 128)
|
||||
DUMMY_HWC = (96, 128, 3)
|
||||
|
||||
# Default video feature set used by video-encoding persistence tests.
|
||||
DUMMY_VIDEO_FEATURES = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
}
|
||||
DUMMY_VIDEO_KEY = "observation.images.cam"
|
||||
|
||||
DUMMY_DEPTH_FEATURES = {
|
||||
"observation.images.depth": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"is_depth_map": True},
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
}
|
||||
DUMMY_DEPTH_KEY = "observation.images.depth"
|
||||
|
||||
Vendored
+38
@@ -49,6 +49,39 @@ from tests.fixtures.constants import (
|
||||
)
|
||||
|
||||
|
||||
def add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
|
||||
"""Append ``num_frames`` synthetic frames to ``dataset``.
|
||||
|
||||
Generates per-feature payloads from ``dataset.meta``: uint16 depth ramps for
|
||||
keys in ``dataset.meta.depth_keys``, uint8 random noise for video/image keys,
|
||||
and float32 zeros for everything else. ``DEFAULT_FEATURES`` (timestamp,
|
||||
frame_index, ...) are auto-populated by ``add_frame`` and skipped here.
|
||||
"""
|
||||
video_keys = dataset.meta.video_keys
|
||||
depth_keys = dataset.meta.depth_keys
|
||||
# Smooth gradient base reused per (H, W) to keep depth frames cheap to
|
||||
# encode (HEVC Main 12 hates white noise).
|
||||
_depth_base_cache: dict[tuple[int, int], np.ndarray] = {}
|
||||
for i in range(num_frames):
|
||||
frame: dict = {"task": "test"}
|
||||
for key, ft in dataset.meta.features.items():
|
||||
if key in DEFAULT_FEATURES:
|
||||
continue
|
||||
shape = ft["shape"]
|
||||
if key in depth_keys:
|
||||
h, w, _ = shape
|
||||
base = _depth_base_cache.setdefault(
|
||||
(h, w),
|
||||
np.linspace(100.0, 10_000.0, h * w, dtype=np.float32).reshape(h, w, 1),
|
||||
)
|
||||
frame[key] = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
|
||||
elif key in video_keys:
|
||||
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
|
||||
else:
|
||||
frame[key] = np.zeros(shape, dtype=np.float32)
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
class LeRobotDatasetFactory(Protocol):
|
||||
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
|
||||
|
||||
@@ -485,10 +518,14 @@ def lerobot_dataset_factory(
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
camera_features: dict | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
# Instantiate objects
|
||||
if info is None:
|
||||
info_kwargs = {}
|
||||
if camera_features is not None:
|
||||
info_kwargs["camera_features"] = camera_features
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
@@ -496,6 +533,7 @@ def lerobot_dataset_factory(
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
chunks_size=chunks_size,
|
||||
**info_kwargs,
|
||||
)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info.features)
|
||||
|
||||
@@ -2370,14 +2370,32 @@ def test_aggregate_images_when_use_videos_false():
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
use_videos=False, # expect "image" dtype
|
||||
use_videos=False, # images kept, stored as "image" dtype
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
key = f"{OBS_IMAGES}.back"
|
||||
key_front = f"{OBS_IMAGES}.front"
|
||||
assert key not in out
|
||||
assert key_front not in out
|
||||
assert key in out
|
||||
assert key_front in out
|
||||
assert out[key]["dtype"] == "image"
|
||||
assert out[key_front]["dtype"] == "image"
|
||||
assert out[key]["shape"] == initial["back"]
|
||||
|
||||
|
||||
def test_aggregate_images_excluded():
|
||||
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
|
||||
initial = {"back": (480, 640, 3)}
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
exclude_images=True,
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
assert f"{OBS_IMAGES}.back" not in out
|
||||
assert f"{OBS_IMAGES}.front" not in out
|
||||
|
||||
|
||||
def test_aggregate_images_when_use_videos_true():
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
MergeConfig,
|
||||
ModifyTasksConfig,
|
||||
OperationConfig,
|
||||
ReencodeVideosConfig,
|
||||
RemoveFeatureConfig,
|
||||
SplitConfig,
|
||||
_validate_config,
|
||||
@@ -103,3 +104,47 @@ class TestOperationTypeParsing:
|
||||
)
|
||||
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
|
||||
assert resolved_name == type_name
|
||||
|
||||
|
||||
class TestDepthEncoderParsing:
|
||||
"""Test that the depth encoder is exposed and parsed for video operations."""
|
||||
|
||||
def test_reencode_has_default_depth_encoder(self):
|
||||
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", "reencode_videos"])
|
||||
assert isinstance(cfg.operation, ReencodeVideosConfig)
|
||||
# A depth encoder is configured by default so depth videos are re-encoded too.
|
||||
assert cfg.operation.depth_encoder is not None
|
||||
assert hasattr(cfg.operation.depth_encoder, "depth_min")
|
||||
|
||||
def test_reencode_parses_depth_encoder_overrides(self):
|
||||
cfg = parse_cfg(
|
||||
[
|
||||
"--repo_id",
|
||||
"test/repo",
|
||||
"--operation.type",
|
||||
"reencode_videos",
|
||||
"--operation.depth_encoder.vcodec",
|
||||
"ffv1",
|
||||
"--operation.depth_encoder.depth_max",
|
||||
"12.0",
|
||||
"--operation.depth_encoder.use_log",
|
||||
"false",
|
||||
]
|
||||
)
|
||||
assert cfg.operation.depth_encoder.vcodec == "ffv1"
|
||||
assert cfg.operation.depth_encoder.depth_max == 12.0
|
||||
assert cfg.operation.depth_encoder.use_log is False
|
||||
|
||||
def test_convert_image_to_video_parses_depth_encoder_overrides(self):
|
||||
cfg = parse_cfg(
|
||||
[
|
||||
"--repo_id",
|
||||
"test/repo",
|
||||
"--operation.type",
|
||||
"convert_image_to_video",
|
||||
"--operation.depth_encoder.depth_min",
|
||||
"0.05",
|
||||
]
|
||||
)
|
||||
assert isinstance(cfg.operation, ConvertImageToVideoConfig)
|
||||
assert cfg.operation.depth_encoder.depth_min == 0.05
|
||||
|
||||
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
|
||||
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
|
||||
from lerobot.teleoperators.rebot_102_leader import (
|
||||
RebotArm102Leader,
|
||||
RebotArm102LeaderConfig,
|
||||
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
|
||||
|
||||
def test_bimanual_prefixes_features():
|
||||
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
|
||||
cfg = BiRebotArm102LeaderConfig(
|
||||
cfg = BiRebot102LeaderConfig(
|
||||
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
|
||||
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
|
||||
)
|
||||
teleop = BiRebotArm102Leader(cfg)
|
||||
teleop = BiRebot102Leader(cfg)
|
||||
assert any(k.startswith("left_") for k in teleop.action_features)
|
||||
assert any(k.startswith("right_") for k in teleop.action_features)
|
||||
assert "left_gripper.pos" in teleop.action_features
|
||||
|
||||
Reference in New Issue
Block a user