Compare commits

..

65 Commits

Author SHA1 Message Date
CarolinePascal 45f49b6600 test(visualization): fixing visualization tests 2026-06-16 19:01:46 +02:00
CarolinePascal c56c6991d1 chore(rebase): fixing rebase merge conflicts 2026-06-16 18:04:33 +02:00
CarolinePascal e8f6c32623 fix(update video info): ditching the differentiated approahces for video info update - video info are always updated unless for preserved keys. 2026-06-16 18:02:59 +02:00
CarolinePascal cc17582f71 chore(format): format code 2026-06-16 18:02:59 +02:00
CarolinePascal 529f48b540 fix(review): add Claude review 2026-06-16 18:02:59 +02:00
CarolinePascal c128027415 test(reencode): fixing reencoding monkeypatch 2026-06-16 18:02:59 +02:00
CarolinePascal c0db93f4a0 fix(update video info): fixing update video info logic to match the recording and editing use cases 2026-06-16 18:02:59 +02:00
CarolinePascal 1f66e6f5e4 fix(save images): fixing image saving in dataset tools 2026-06-16 18:02:59 +02:00
CarolinePascal 7a67ce9e50 docs(dataset tools): updating docs 2026-06-16 18:02:59 +02:00
CarolinePascal 4f26878d8f docs(docstring): updating docstrings 2026-06-16 18:02:59 +02:00
CarolinePascal a694e32774 feat(dataset tools): adding missing docstrings and features for depth fill support in dataset edition tools 2026-06-16 18:02:59 +02:00
CarolinePascal 655338abf3 fix(rebase): rebase follow up corrections 2026-06-16 18:02:59 +02:00
CarolinePascal 364d4de96f docs(mermaid): fixing mermaid diagram 2026-06-16 18:02:59 +02:00
CarolinePascal 41a942658b fix(pyav check): fixing PyAV option validation for integer codec options by normalizing
numeric values before calling `is_integer()`

Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
2026-06-16 18:02:59 +02:00
CarolinePascal 030d9a279a chore(format): formatting code 2026-06-16 18:02:59 +02:00
CarolinePascal 30479cf277 test(dataset tools): adding missing tests for new dataset edition tools features 2026-06-16 18:02:58 +02:00
CarolinePascal cb6b2d77bd test(fix): fixing depth tests 2026-06-16 18:02:58 +02:00
CarolinePascal 76f79f3955 docs(depth): improving depth maps docs 2026-06-16 18:02:58 +02:00
CarolinePascal 9e994baa04 chore(format): formatting code 2026-06-16 18:02:56 +02:00
CarolinePascal 6fd911ebb9 test(depth encoding): updating and cleaning video/depth encoding tests 2026-06-16 18:02:31 +02:00
CarolinePascal f712698272 test(depth): cleaning up depth tests 2026-06-16 18:02:31 +02:00
CarolinePascal c2416ecbcb feat(output unit): adding support for output unit specification at dataset reading/training time
Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
2026-06-16 18:02:31 +02:00
CarolinePascal 6aa50cc1e5 fix(depth units): fixing depth units output for the realsense cameras 2026-06-16 18:02:31 +02:00
CarolinePascal e17adce3ba fix(is_depth): adding missing doctrings and is_depth arguments in video decoding functions
Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
2026-06-16 18:02:31 +02:00
CarolinePascal f7010ff66c fix(typo): fixing typo 2026-06-16 18:02:31 +02:00
CarolinePascal f7ee453de7 fix(from_video_info): fixing early validation issue in from_video_info 2026-06-16 18:02:31 +02:00
CarolinePascal ca7168f413 test(cleaning): cleaning up tests 2026-06-16 18:02:28 +02:00
CarolinePascal ec6264d768 test(aggregate): extending aggregation tests to depth frames 2026-06-16 18:01:55 +02:00
CarolinePascal d93a58a8b8 feat(tools): adding depth support in LeRobotDataset edition tools 2026-06-16 18:01:50 +02:00
CarolinePascal 92497dfcd8 feat(batched dequantization): optimizing dequantize_depth for torch based batched dequantization 2026-06-16 18:01:18 +02:00
CarolinePascal 263108d6c1 fix(TIFF): add missing quantization and cleanup for TIFF files 2026-06-16 18:01:18 +02:00
CarolinePascal a925d20ce4 fix(typo): fixing typo 2026-06-16 18:01:18 +02:00
CarolinePascal 1f024ea3bf fix(normalization): restricting 255 normalization to non depth/uint8 images only 2026-06-16 18:01:18 +02:00
CarolinePascal d5f67cc7fc fix(realsense): fixing typo in realsense serial number 2026-06-16 18:01:18 +02:00
CarolinePascal 9ab8c98494 tests(typos): fixing typos in tests 2026-06-16 18:01:18 +02:00
CarolinePascal a561183442 fix(info): fixing info metadata update when is_depth_map was set 2026-06-16 18:01:18 +02:00
CarolinePascal 305b8d64b2 fix(pre-commit): fixing mutable defautl value 2026-06-16 18:01:18 +02:00
CarolinePascal 0a624a5cf5 feat(refactor): refactor DepthEncoderConfig quantization pipeline, so that the methods do not live in the config class. Add pixel format - channels validation.Move the default pixel format for depth in the config file. 2026-06-16 18:01:18 +02:00
CarolinePascal d044ead377 feat(pix_fmt channels): use PyAv to check get pixel formats number of channels 2026-06-16 18:01:18 +02:00
CarolinePascal e425fcb61a tests(depth): adding new tests for depth integration validation 2026-06-16 18:01:17 +02:00
CarolinePascal f08a9aea71 test(fix): fixing exisiting tests to still work with latest features 2026-06-16 18:01:17 +02:00
CarolinePascal 7d97b55cc4 chore(typos): fixing typos 2026-06-16 18:01:17 +02:00
CarolinePascal edbd8c6f82 fix(plumbing): fixing missing parts in the depth maps pipeline 2026-06-16 18:01:17 +02:00
CarolinePascal 615954b80b fix(stop_event): fixing stop_event race condition in camera classes 2026-06-16 18:01:17 +02:00
CarolinePascal 1c0fdfdb4b feat(is_depth): simplifying is_depth nested name + legacy support 2026-06-16 18:01:17 +02:00
CarolinePascal 1c3ebd475f feat(depth shape): ensuring depth maps shape is always including the channel 2026-06-16 18:01:17 +02:00
CarolinePascal c655814788 chore(format): format code 2026-06-16 18:01:17 +02:00
CarolinePascal a72ab14f89 feat(depth maps writer): adding support for raw depth maps recording with image writer 2026-06-16 18:01:17 +02:00
CarolinePascal 882074d707 feat(viz): render depth observations as rr.DepthImage in Viridis 2026-06-16 18:01:17 +02:00
CarolinePascal 4ae2f9f375 feat(record): plumb DepthEncoderConfig through lerobot-record 2026-06-16 18:01:17 +02:00
CarolinePascal 26099b6e03 feat(robots/so_follower): emit + populate depth keys when use_depth 2026-06-16 18:01:16 +02:00
CarolinePascal 6b395dfb24 feat(features): route 2D camera shapes to observation.depth.<key> 2026-06-16 18:01:16 +02:00
CarolinePascal 1cbabfe9a4 feat(cameras/realsense): expose async depth in metric meters 2026-06-16 18:01:16 +02:00
CarolinePascal 4744f4b913 feat(depth): wire DatasetReader to decode_depth_frames 2026-06-16 18:01:16 +02:00
CarolinePascal 9568e68b28 feat(depth): wire StreamingVideoEncoder + writer to depth encoder 2026-06-16 18:01:16 +02:00
CarolinePascal 10941c31f6 feat(depth): plumb DepthEncoderConfig through LeRobotDataset and DatasetWriter 2026-06-16 18:01:16 +02:00
CarolinePascal a6882a048a feat(depth): extend quantization tools to better fit the encoding/decoding pipeline 2026-06-16 18:01:16 +02:00
CarolinePascal eb2b7d6dc3 feat(depth): persist depth metadata 2026-06-16 18:01:16 +02:00
CarolinePascal f7f7b8c7f8 feat(video): add ffv1 to supported codecs 2026-06-16 18:01:16 +02:00
CarolinePascal d58a324da4 feat(depth): add depth quantization helpers and tests 2026-06-16 18:01:16 +02:00
Caroline Pascal 287c823f13 fix(features copy): adding deepcopy on LeRobot dataset features to avoid shallow copy leaks (#3826)
* fix(features copy): adding deepcopy on LeRobot dataset features to avoid shallow copy leaks

* tests(test): adding new test
2026-06-16 17:58:59 +02:00
Pepijn 58ccc01508 fix(datasets): enforce one parquet row group per episode in v3 data writes (#3807)
* fix(datasets): enforce one parquet row group per episode in v3 data writes

LeRobot v3 data shards must hold exactly one row group per episode so a
reader can fetch episode i with pq.ParquetFile(path).read_row_group(i)
(a byte-range read) instead of loading the whole shard. The recording
writer already does this (one write_table per episode); the aggregate
and lerobot-annotate re-write paths instead concatenated many episodes
and wrote them in one shot, collapsing the file to a single row group.

- io_utils: add write_table_one_row_group_per_episode (one ParquetWriter,
  one write_table per episode — same pattern as the recording writer);
  to_parquet_with_hf_images embeds images then writes per-episode row
  groups; to_parquet_one_row_group_per_episode wraps it for plain frames
- aggregate: route non-image data writes through the per-episode writer;
  leave the episodes-metadata parquet untouched (already one row/episode)
- annotate: rewrite shards via the per-episode writer instead of a single
  bulk pq.write_table
- tests: invariant coverage through the aggregate (image + video) and
  annotate paths

No change to on-disk schema, paths, naming, rollover thresholds, or
compression. Readers stay backward-compatible (old collapsed files load).

* Update src/lerobot/datasets/io_utils.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update src/lerobot/datasets/io_utils.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(datasets): correct indentation and add strict= in row-group helper

The web-edited numpy version of write_table_one_row_group_per_episode had an
over-indented line (IndentationError, breaking pre-commit + test collection)
and a zip() without strict=. Fix both; behaviour unchanged.

---------

Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-06-16 12:15:48 +02:00
Caroline Pascal 38327fdc84 fix(images/videos): fixing aggregate_pipeline_dataset_features to avoid unwanted images features deletion (#3783)
* fix(images/videos): fixing aggregate_pipeline_dataset_features to avoid unwanted images features deletion when videos are not used

* fix(docstrings): improving docstrings

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>

---------

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-06-15 17:55:52 +02:00
Steven Palma 9555efc02c chore(dependencies): update uv.lock (#3595)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-06-15 16:29:44 +02:00
Steven Palma d576c59afb refactor(robots): homogenize bi-manual setups implementations (#3772)
* chore(robots): homogenize bi setups

* feat(robots): split openarm mini into single and bi

* refactor(robots): mixin for bi classes

* docs: update docs
2026-06-15 16:28:54 +02:00
82 changed files with 3762 additions and 1695 deletions
+8
View File
@@ -157,6 +157,14 @@ finally:
</hfoption> </hfoption>
</hfoptions> </hfoptions>
### Working with depth
The Intel RealSense and Reachy 2 cameras can capture both color and depth in lockstep. Calling `read()` returns the **color** frame as `(H, W, 3)` `uint8`. Calling `read_depth()` returns the **depth map** as `(H, W, 1)` `uint16`, where each pixel value is the distance from the sensor expressed in **millimetres**. A pixel value of `0` typically means "no measurement available" (out-of-range, occluded, or low-confidence).
During recording, the control loop peeks the freshest buffered frames non-blockingly via `read_latest()` (color) and `read_latest_depth()` (depth), adding the depth map as a sibling feature (e.g. `front_depth` next to `front`).
For how depth streams are stored and encoded when recording a dataset, see the [Depth streams](./video_encoding_parameters#depth-streams) section of the video encoding guide.
## Use your phone's camera ## Use your phone's camera
<hfoptions id="use phone"> <hfoptions id="use phone">
+8 -8
View File
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
**Compatible teleoperators:** **Compatible teleoperators:**
- `openarm_mini` - OpenArm Mini - `bi_openarm_mini` - Bimanual OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm - `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT] > [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `openarm_mini`. > The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags. > `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
--- ---
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \ --robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \ --robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \ --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \ --teleop.type=bi_openarm_mini \
--teleop.port_left=/dev/ttyACM0 \ --teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \ --teleop.right_arm_config.port=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_dataset \ --dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \ --dataset.single_task="Fold the T-shirt properly" \
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \ --robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \ --robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \ --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \ --teleop.type=bi_openarm_mini \
--teleop.port_left=/dev/ttyACM0 \ --teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \ --teleop.right_arm_config.port=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \ --dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \ --dataset.single_task="Fold the T-shirt properly" \
+1 -1
View File
@@ -117,7 +117,7 @@ lerobot-rollout \
--strategy.num_episodes=20 \ --strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \ --robot.type=bi_openarm_follower \
--teleop.type=openarm_mini \ --teleop.type=bi_openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \ --dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt" --dataset.single_task="Fold the T-shirt"
``` ```
+45 -4
View File
@@ -11,8 +11,9 @@ LeRobot provides several utilities for manipulating datasets:
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids` 3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset 4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset 5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage 6. **Convert to Video** - Convert image-based datasets to video format for efficient storage (RGB and depth cameras are encoded with separate encoders)
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc. 7. **Re-encode Videos** - Re-encode an existing video dataset's RGB and/or depth streams with new encoder settings
8. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
The core implementation is in `lerobot.datasets.dataset_tools`. The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
@@ -122,6 +123,15 @@ lerobot-edit-dataset \
--operation.camera_encoder.g 2 \ --operation.camera_encoder.g 2 \
--operation.camera_encoder.crf 30 --operation.camera_encoder.crf 30
# Convert a dataset that includes depth maps, customizing the depth encoder
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.output_dir outputs/pusht_video \
--operation.depth_encoder.depth_min 0.01 \
--operation.depth_encoder.depth_max 10.0 \
--operation.depth_encoder.use_log true
# Convert only specific episodes # Convert only specific episodes
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
@@ -147,11 +157,42 @@ lerobot-edit-dataset \
**Parameters:** **Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`) - `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details. - `camera_encoder`: Video encoder settings applied to RGB cameras — all sub-fields accessible via `--operation.camera_encoder.<field>`. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
- `depth_encoder`: Video encoder settings applied to depth-map cameras (e.g. from an Intel RealSense). In addition to the standard encoder fields it exposes the depth quantization knobs (`depth_min`, `depth_max`, `shift`, `use_log`), accessible via `--operation.depth_encoder.<field>`. These quantization settings are persisted to the dataset metadata so depth can be dequantized back to physical units on load. See the [Depth streams](./video_encoding_parameters#depth-streams) section for details.
- `episode_indices`: List of specific episodes to convert (default: all episodes) - `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4) - `num_workers`: Number of parallel workers for processing (default: 4)
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved. **Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). Depth-map cameras are detected automatically and routed to the `depth_encoder`, while RGB cameras use the `camera_encoder`. All episodes, stats, and tasks are preserved.
#### Re-encode Videos
Re-encode the videos of an existing video dataset with different encoder settings, without going back to raw frames. RGB videos use the `camera_encoder` and depth videos use the `depth_encoder`. Provide only the encoder(s) you want to re-encode; the other stream type is left untouched.
```bash
# Re-encode all RGB videos with new settings (saves to lerobot/pusht_reencoded by default)
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type reencode_videos \
--operation.camera_encoder.vcodec h264 \
--operation.camera_encoder.pix_fmt yuv420p \
--operation.camera_encoder.crf 23
# Re-encode both RGB and depth videos in a dataset with depth maps
lerobot-edit-dataset \
--repo_id lerobot/pusht_depth \
--operation.type reencode_videos \
--operation.camera_encoder.vcodec libx264 \
--operation.depth_encoder.vcodec ffv1
```
**Parameters:**
- `camera_encoder`: Encoder settings applied to every RGB video. Omit to skip re-encoding RGB videos.
- `depth_encoder`: Encoder settings applied to every depth video. Omit to skip re-encoding depth videos.
- `num_workers`: Number of parallel workers for processing.
> [!NOTE]
> When re-encoding depth videos, the existing depth quantization parameters (`depth_min`, `depth_max`, `shift`, `use_log`) and the `is_depth_map` flag are **preserved** — re-encoding only changes the codec/quality of the stored stream, not how depth is dequantized on load.
### Show the information of datasets ### Show the information of datasets
+72 -2
View File
@@ -65,6 +65,76 @@ All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
--- ---
## Depth streams
Depth maps (Intel RealSense, Reachy 2) are stored as their **own video streams** alongside the RGB streams. Raw depth (`uint16` millimetres or `float32` metres) can't survive an 8-bit codec, so LeRobot **quantizes** each map to a 12-bit code (`[0, 4095]`) — logarithmically by default, to match the `1/depth` error profile of depth sensors — then packs it into a high-bit-depth pixel format (`gray12le`) and encodes it with a 12-bit codec.
```mermaid
flowchart LR
A["Raw depth (uint16 mm / float32 m)"] --> B["Clip to depth_min, depth_max"]
B --> C["Quantize to 12-bit code 04095 (log or linear)"]
C --> D["Pack into gray12le"]
D --> E["Encode video (hevc Main 12)"]
E --> F[("MP4 + metadata: depth_min/max, shift, use_log")]
F -. "load time (depth_output_unit)" .-> G["Dequantize to mm or m"]
classDef input fill:#e3f2fd,stroke:#1565c0,color:#0d47a1;
classDef encode fill:#ede7f6,stroke:#5e35b1,color:#311b92;
classDef store fill:#fff8e1,stroke:#f9a825,color:#e65100;
classDef load fill:#e8f5e9,stroke:#2e7d32,color:#1b5e20;
class A input;
class B,C,D,E encode;
class F store;
class G load;
```
Configure the depth pipeline through a parallel **`depth_encoder`** block (`DepthEncoderConfig`). It inherits every `VideoEncoderConfig` field (`vcodec`, `pix_fmt`, `crf`, …) and adds four quantizer knobs, set via `--dataset.depth_encoder.<field>`:
```bash
lerobot-record \
... \
--dataset.depth_encoder.vcodec=hevc \
--dataset.depth_encoder.depth_min=0.05 \
--dataset.depth_encoder.depth_max=5.0 \
--dataset.depth_encoder.use_log=true
```
| Parameter | Type | Default | Description |
| ----------- | ------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- |
| `vcodec` | `str` | `"hevc"` | Defaults to HEVC Main 12 (a 12-bit-capable codec). `ffv1` is a lossless alternative. |
| `pix_fmt` | `str` | `"gray12le"` | Single-channel 12-bit pixel format used to carry the quantized codes. |
| `depth_min` | `float` | `0.01` | Depth in metres mapped to quantum `0`. Values below are clipped on decode. |
| `depth_max` | `float` | `10.0` | Depth in metres mapped to quantum `4095`. Values above are clipped on decode. |
| `shift` | `float` | `3.5` | Pre-log offset (metres) used in logarithmic quantization for numerical stability near zero. Must satisfy `depth_min + shift > 0`. |
| `use_log` | `bool` | `True` | If `true`, quantize in log-space (recommended for typical depth sensors). Set to `false` for uniform/linear quantization. |
> [!TIP]
> `depth_min`, `depth_max`, and `shift` are always interpreted in **metres**, regardless of the input depth's unit. Inputs are auto-detected: integer arrays (e.g. `uint16` millimetres straight from a RealSense) are treated as millimetres, floating arrays as metres.
> Pick `depth_min` / `depth_max` to bracket the actual working range of your sensor — quanta outside that range saturate, which can crush detail at the boundaries.
Depth features are flagged with `"is_depth_map": true` in `meta/info.json`, and their quantizer settings (`video.depth_min`, `video.depth_max`, `video.shift`, `video.use_log`) are persisted — which is what lets depth be **dequantized back to physical units** on load.
### Output unit at load time
`depth_encoder` is a **record-time** concern. The unit that depth maps are dequantized to on _load_ (e.g. during training) is set separately by the read-time flag `--dataset.depth_output_unit`:
```bash
lerobot-train \
--dataset.repo_id=<my_username>/<my_dataset_name> \
--dataset.depth_output_unit=m \
--policy.type=act
```
| Parameter | Type | Default | Description |
| ------------------- | ----- | ------- | -------------------------------------------------------------------------------------------- |
| `depth_output_unit` | `str` | `"mm"` | Physical unit depth maps are dequantized to on load: `"mm"` (millimetres) or `"m"` (metres). |
> [!TIP]
> This is purely a decode-time presentation choice — it does **not** alter the stored video or its metadata, so the same dataset can be read as `mm` or `m` without re-encoding. It has no effect on datasets without depth cameras.
---
## Persistence in dataset metadata ## Persistence in dataset metadata
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is: After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
@@ -82,7 +152,7 @@ After the first episode of a video stream is encoded, the encoder configuration
"video.pix_fmt": "yuv420p", "video.pix_fmt": "yuv420p",
"video.fps": 30, "video.fps": 30,
"video.channels": 3, "video.channels": 3,
"video.is_depth_map": false, "is_depth_map": false,
"video.g": 2, "video.g": 2,
"video.crf": 30, "video.crf": 30,
"video.preset": "fast", "video.preset": "fast",
@@ -97,7 +167,7 @@ After the first episode of a video stream is encoded, the encoder configuration
Two sources contribute to the `info` block: Two sources contribute to the `info` block:
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present. - **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`. - **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
<Tip> <Tip>
@@ -281,7 +281,7 @@ class VideoFrameProvider:
reencode_video( reencode_video(
src, src,
out_path, out_path,
camera_encoder=encoder, video_encoder=encoder,
overwrite=True, overwrite=True,
start_time_s=from_timestamp, start_time_s=from_timestamp,
end_time_s=to_timestamp, end_time_s=to_timestamp,
@@ -54,6 +54,7 @@ from typing import Any
import pyarrow as pa import pyarrow as pa
import pyarrow.parquet as pq import pyarrow.parquet as pq
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
from lerobot.datasets.language import ( from lerobot.datasets.language import (
EVENT_ONLY_STYLES, EVENT_ONLY_STYLES,
LANGUAGE_EVENTS, LANGUAGE_EVENTS,
@@ -274,12 +275,11 @@ class LanguageColumnsWriter:
new_table = self._materialize_table( new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
) )
# Atomic replace: write to a sibling tmp path and rename so a crash # Re-emit one row group per episode (a bulk pq.write_table would collapse
# mid-write can't leave a half-written shard that ``pq.read_table`` # them into one). Write to a sibling tmp path and atomically rename so a
# would then fail to open. ``Path.replace`` is atomic on POSIX + # crash mid-write can't leave a half-written shard.
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp") tmp_path = path.with_suffix(path.suffix + ".tmp")
pq.write_table(new_table, tmp_path) write_table_one_row_group_per_episode(new_table, tmp_path)
tmp_path.replace(path) tmp_path.replace(path)
def _materialize_table( def _materialize_table(
+3 -2
View File
@@ -105,8 +105,9 @@ def raw_observation_to_observation(
def prepare_image(image: torch.Tensor) -> torch.Tensor: def prepare_image(image: torch.Tensor) -> torch.Tensor:
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor""" """Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
image = image.type(torch.float32) / 255 if image.dtype == torch.uint8:
image = image.type(torch.float32) / 255
image = image.contiguous() image = image.contiguous()
return image return image
+5 -2
View File
@@ -436,7 +436,7 @@ class OpenCVCamera(Camera):
Internal loop run by the background thread for asynchronous reading. Internal loop run by the background thread for asynchronous reading.
On each iteration: On each iteration:
1. Reads a color frame 1. Reads a color frame (blocking call)
2. Stores result in latest_frame and updates timestamp (thread-safe) 2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners 3. Sets new_frame_event to notify listeners
@@ -445,8 +445,9 @@ class OpenCVCamera(Camera):
if self.stop_event is None: if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
stop_event = self.stop_event
failure_count = 0 failure_count = 0
while not self.stop_event.is_set(): while not stop_event.is_set():
try: try:
raw_frame = self._read_from_hardware() raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame) processed_frame = self._postprocess_image(raw_frame)
@@ -484,6 +485,8 @@ class OpenCVCamera(Camera):
if self.thread is not None and self.thread.is_alive(): if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0) self.thread.join(timeout=2.0)
if self.thread.is_alive():
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None self.thread = None
self.stop_event = None self.stop_event = None
@@ -268,13 +268,13 @@ class RealSenseCamera(Camera):
) )
if len(found_devices) > 1: if len(found_devices) > 1:
serial_numbers = [dev["serial_number"] for dev in found_devices] serial_numbers = [dev["id"] for dev in found_devices]
raise ValueError( raise ValueError(
f"Multiple RealSense cameras found with name '{name}'. " f"Multiple RealSense cameras found with name '{name}'. "
f"Please use a unique serial number instead. Found SNs: {serial_numbers}" f"Please use a unique serial number instead. Found SNs: {serial_numbers}"
) )
serial_number = str(found_devices[0]["serial_number"]) serial_number = str(found_devices[0]["id"])
return serial_number return serial_number
def _configure_rs_pipeline_config(self, rs_config: Any) -> None: def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
from the camera hardware via the RealSense pipeline. from the camera hardware via the RealSense pipeline.
Returns: Returns:
np.ndarray: The depth map as a NumPy array (height, width) np.ndarray: The depth map as a NumPy array (height, width, 1)
of type `np.uint16` (raw depth values in millimeters) and rotation. of type `np.uint16` (raw depth values in millimeters).
Raises: Raises:
DeviceNotConnectedError: If the camera is not connected. DeviceNotConnectedError: If the camera is not connected.
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
Internal loop run by the background thread for asynchronous reading. Internal loop run by the background thread for asynchronous reading.
On each iteration: On each iteration:
1. Reads a color frame with 500ms timeout 1. Reads a color/depth frame (blocking call with 10s timeout)
2. Stores result in latest_frame and updates timestamp (thread-safe) 2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners 3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues. Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -474,8 +474,9 @@ class RealSenseCamera(Camera):
if self.stop_event is None: if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
stop_event = self.stop_event
failure_count = 0 failure_count = 0
while not self.stop_event.is_set(): while not stop_event.is_set():
try: try:
frame = self._read_from_hardware() frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame() color_frame_raw = frame.get_color_frame()
@@ -486,6 +487,8 @@ class RealSenseCamera(Camera):
depth_frame_raw = frame.get_depth_frame() depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data()) depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True) processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
processed_depth_frame = processed_depth_frame[..., np.newaxis]
capture_time = time.perf_counter() capture_time = time.perf_counter()
@@ -522,6 +525,8 @@ class RealSenseCamera(Camera):
if self.thread is not None and self.thread.is_alive(): if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0) self.thread.join(timeout=2.0)
if self.thread.is_alive(): # pragma: no cover
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None self.thread = None
self.stop_event = None self.stop_event = None
@@ -532,7 +537,6 @@ class RealSenseCamera(Camera):
self.latest_timestamp = None self.latest_timestamp = None
self.new_frame_event.clear() self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected @check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
""" """
@@ -575,7 +579,6 @@ class RealSenseCamera(Camera):
return frame return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected @check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking). """Return the most recent (color) frame captured immediately (Peeking).
@@ -611,6 +614,73 @@ class RealSenseCamera(Camera):
return frame return frame
@check_if_not_connected
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[np.uint16]:
"""Read the latest depth frame asynchronously, in millimeters.
Mirrors :meth:`async_read` but returns the depth stream rather than the
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``, where each
pixel is the distance from the sensor in millimeters.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
the background read thread is not running.
TimeoutError: If no frame becomes available within ``timeout_ms``.
"""
if not self.use_depth:
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
self.new_frame_event.clear()
if depth_frame is None:
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
return depth_frame
@check_if_not_connected
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent depth frame in millimeters (peeking).
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
Output is ``np.uint16`` of shape ``(H, W, 1)``, where each pixel is the
distance from the sensor in millimeters.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
no depth frame has been captured yet.
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
"""
if not self.use_depth:
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
timestamp = self.latest_timestamp
if depth_frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any depth frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return depth_frame
def disconnect(self) -> None: def disconnect(self) -> None:
""" """
Disconnects from the camera, stops the pipeline, and cleans up resources. Disconnects from the camera, stops the pipeline, and cleans up resources.
+4 -1
View File
@@ -249,8 +249,9 @@ class ZMQCamera(Camera):
if self.stop_event is None: if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.") raise RuntimeError(f"{self}: stop_event is not initialized.")
stop_event = self.stop_event
failure_count = 0 failure_count = 0
while not self.stop_event.is_set(): while not stop_event.is_set():
try: try:
frame = self._read_from_hardware() frame = self._read_from_hardware()
capture_time = time.perf_counter() capture_time = time.perf_counter()
@@ -292,6 +293,8 @@ class ZMQCamera(Camera):
if self.thread is not None and self.thread.is_alive(): if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0) self.thread.join(timeout=2.0)
if self.thread.is_alive():
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None self.thread = None
self.stop_event = None self.stop_event = None
+9 -11
View File
@@ -180,26 +180,24 @@ class WandBLogger:
self._wandb_custom_step_key.add(new_custom_key) self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True) self._wandb.define_metric(new_custom_key, hidden=True)
batch_data = {}
for k, v in d.items(): for k, v in d.items():
# Skip the custom step key here, it's added to the batch below.
if custom_step_key is not None and k == custom_step_key:
continue
if not isinstance(v, (int | float | str)): if not isinstance(v, (int | float | str)):
logging.warning( logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
) )
continue continue
batch_data[f"{mode}/{k}"] = v # Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if batch_data:
if custom_step_key is not None: if custom_step_key is not None:
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key] value_custom_step = d[custom_step_key]
self._wandb.log(batch_data) data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
else: self._wandb.log(data)
self._wandb.log(data=batch_data, step=step) continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"): def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}: if mode not in {"train", "eval"}:
+7
View File
@@ -35,8 +35,11 @@ from .types import (
from .video import ( from .video import (
VALID_VIDEO_CODECS, VALID_VIDEO_CODECS,
VIDEO_ENCODER_INFO_KEYS, VIDEO_ENCODER_INFO_KEYS,
DepthEncoderConfig,
VideoEncoderConfig, VideoEncoderConfig,
camera_encoder_defaults, camera_encoder_defaults,
depth_encoder_defaults,
encoder_config_from_video_info,
) )
__all__ = [ __all__ = [
@@ -57,8 +60,12 @@ __all__ = [
"WandBConfig", "WandBConfig",
"load_recipe", "load_recipe",
"VideoEncoderConfig", "VideoEncoderConfig",
"DepthEncoderConfig",
# Defaults # Defaults
"camera_encoder_defaults", "camera_encoder_defaults",
"depth_encoder_defaults",
# Factories
"encoder_config_from_video_info",
# Constants # Constants
"VALID_VIDEO_CODECS", "VALID_VIDEO_CODECS",
"VIDEO_ENCODER_INFO_KEYS", "VIDEO_ENCODER_INFO_KEYS",
+3 -1
View File
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from .video import VideoEncoderConfig, camera_encoder_defaults from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
@dataclass @dataclass
@@ -60,6 +60,8 @@ class DatasetRecordConfig:
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys, # Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``). # e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults) camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
# Enable streaming video encoding: encode frames in real-time during capture instead # Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False streaming_encoding: bool = False
+6 -1
View File
@@ -35,12 +35,17 @@ class DatasetConfig:
revision: str | None = None revision: str | None = None
use_imagenet_stats: bool = True use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_video_backend) video_backend: str = field(default_factory=get_safe_default_video_backend)
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0). # When True, RGB video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion. # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False return_uint8: bool = False
# Physical unit depth maps are dequantized to at load time: "mm" (millimetres) or "m" (metres).
# Has no effect on datasets without depth cameras.
depth_output_unit: str = "mm"
streaming: bool = False streaming: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.depth_output_unit not in ("m", "mm"):
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
if self.episodes is not None: if self.episodes is not None:
if any(ep < 0 for ep in self.episodes): if any(ep < 0 for ep in self.episodes):
raise ValueError( raise ValueError(
+112 -8
View File
@@ -20,7 +20,7 @@ from __future__ import annotations
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any, ClassVar, Self
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import require_package
@@ -36,11 +36,12 @@ HW_VIDEO_CODECS = [
"h264_vaapi", # Linux Intel/AMD "h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync "h264_qsv", # Intel Quick Sync
] ]
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS}) VALID_VIDEO_CODECS: frozenset[str] = frozenset(
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
)
# Aliases for legacy video codec names. # Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"} VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
LIBSVTAV1_DEFAULT_PRESET: int = 12 LIBSVTAV1_DEFAULT_PRESET: int = 12
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`). # Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
@@ -52,6 +53,19 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
) )
# Default depth quantization and encoding parameters.
DEPTH_QUANT_BITS: int = 12
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
DEFAULT_DEPTH_MIN: float = 0.01
DEFAULT_DEPTH_MAX: float = 10.0
DEFAULT_DEPTH_SHIFT: float = 3.5
DEFAULT_DEPTH_USE_LOG: bool = True
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
@dataclass @dataclass
class VideoEncoderConfig: class VideoEncoderConfig:
@@ -86,6 +100,10 @@ class VideoEncoderConfig:
video_backend: str = "pyav" video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict) extra_options: dict[str, Any] = field(default_factory=dict)
# Source-data channel count this encoder is expected to handle (3 for RGB,
# 1 for depth, etc.)
_DEFAULT_CHANNELS: ClassVar[int] = 3
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.resolve_vcodec() self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work". # Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
@@ -94,9 +112,9 @@ class VideoEncoderConfig:
self.validate() self.validate()
@classmethod @classmethod
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig: def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block. """Parse the ``video.*`` keys of a feature ``info`` block into
Missing or ``None`` values fall back to the class defaults. constructor kwargs.
""" """
video_info = video_info or {} video_info = video_info or {}
kwargs: dict[str, Any] = {} kwargs: dict[str, Any] = {}
@@ -115,7 +133,15 @@ class VideoEncoderConfig:
continue continue
kwargs[field_name] = value kwargs[field_name] = value
return cls(**kwargs) return kwargs
@classmethod
def from_video_info(cls, video_info: dict | None) -> Self:
"""Reconstruct an encoder config from a video feature's ``info`` block.
Missing or ``None`` values fall back to the class defaults.
"""
return cls(**cls._kwargs_from_video_info(video_info))
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]: def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
"""Return the subset of available encoders based on the specified video backend. """Return the subset of available encoders based on the specified video backend.
@@ -138,7 +164,9 @@ class VideoEncoderConfig:
require_package("av", extra="dataset") require_package("av", extra="dataset")
from lerobot.datasets import check_video_encoder_parameters_pyav from lerobot.datasets import check_video_encoder_parameters_pyav
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options()) check_video_encoder_parameters_pyav(
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
)
def resolve_vcodec(self) -> None: def resolve_vcodec(self) -> None:
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder. """Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
@@ -218,6 +246,10 @@ class VideoEncoderConfig:
elif self.vcodec == "h264_qsv": elif self.vcodec == "h264_qsv":
set_if("global_quality", self.crf) set_if("global_quality", self.crf)
set_if("preset", self.preset) set_if("preset", self.preset)
elif self.vcodec == "ffv1":
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
# are not meaningful.
set_if("threads", encoder_threads)
else: else:
set_if("crf", self.crf) set_if("crf", self.crf)
set_if("preset", self.preset) set_if("preset", self.preset)
@@ -233,3 +265,75 @@ class VideoEncoderConfig:
def camera_encoder_defaults() -> VideoEncoderConfig: def camera_encoder_defaults() -> VideoEncoderConfig:
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults.""" """Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
return VideoEncoderConfig() return VideoEncoderConfig()
@dataclass
class DepthEncoderConfig(VideoEncoderConfig):
"""Encoder configuration for depth-map streams.
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
preset, ``extra_options``…) and adds the four parameters of the depth
quantizer.
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
to ``"gray12le"``.
Attributes:
depth_min: Minimum depth in physical units (e.g. metres) represented
by quantum ``0``.
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
shift: Pre-log offset for numerical stability near zero.
use_log: ``True`` for logarithmic quantization (default; matches
sensor error profile), ``False`` for linear.
"""
vcodec: str = "hevc"
pix_fmt: str = "gray12le"
depth_min: float = DEFAULT_DEPTH_MIN
depth_max: float = DEFAULT_DEPTH_MAX
shift: float = DEFAULT_DEPTH_SHIFT
use_log: bool = DEFAULT_DEPTH_USE_LOG
_DEFAULT_CHANNELS: ClassVar[int] = 1
@classmethod
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
"""Layer the depth-specific tuning (``depth_min`` / ``depth_max`` /
``shift`` / ``use_log``) on top of the base parser. Missing keys
fall back to the class defaults.
"""
kwargs = super()._kwargs_from_video_info(video_info)
video_info = video_info or {}
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
value = video_info.get(f"video.{name}")
if value is not None:
kwargs[name] = value
return kwargs
def depth_encoder_defaults() -> DepthEncoderConfig:
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
return DepthEncoderConfig()
def encoder_config_from_video_info(video_info: dict | None) -> VideoEncoderConfig:
"""Build the appropriate encoder config from a feature's ``info`` block.
Dispatches to :class:`DepthEncoderConfig` when the dict marks the feature
as a depth map and to :class:`VideoEncoderConfig`
otherwise.
Args:
video_info: A feature's ``info`` dict as persisted in ``info.json``,
or ``None`` (treated as an empty dict).
Returns:
A :class:`DepthEncoderConfig` for depth features, otherwise a
:class:`VideoEncoderConfig`.
"""
video_info = video_info or {}
is_depth = bool(video_info.get("is_depth_map") or video_info.get("video.is_depth_map"))
cls: type[VideoEncoderConfig] = DepthEncoderConfig if is_depth else VideoEncoderConfig
return cls.from_video_info(video_info)
+9
View File
@@ -32,6 +32,7 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
from .io_utils import ( from .io_utils import (
get_file_size_in_mb, get_file_size_in_mb,
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
to_parquet_one_row_group_per_episode,
to_parquet_with_hf_images, to_parquet_with_hf_images,
write_info, write_info,
write_stats, write_stats,
@@ -551,6 +552,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
aggr_root=dst_meta.root, aggr_root=dst_meta.root,
hf_features=hf_features, hf_features=hf_features,
concatenate=concatenate_data, concatenate=concatenate_data,
one_row_group_per_episode=True,
) )
# Record the mapping from source to actual destination # Record the mapping from source to actual destination
@@ -628,6 +630,7 @@ def append_or_create_parquet_file(
aggr_root: Path = None, aggr_root: Path = None,
hf_features: datasets.Features | None = None, hf_features: datasets.Features | None = None,
concatenate: bool = True, concatenate: bool = True,
one_row_group_per_episode: bool = False,
) -> tuple[dict[str, int], tuple[int, int]]: ) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints. """Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -645,6 +648,8 @@ def append_or_create_parquet_file(
aggr_root: Root path for the aggregated dataset. aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing. hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one. concatenate: When False, always rotate to a new file instead of appending to the current one.
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
the episodes-metadata parquet (already one row per episode).
Returns: Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -657,6 +662,8 @@ def append_or_create_parquet_file(
dst_path.parent.mkdir(parents=True, exist_ok=True) dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images: if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features) to_parquet_with_hf_images(df, dst_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(df, dst_path)
else: else:
df.to_parquet(dst_path) df.to_parquet(dst_path)
return idx, (dst_chunk, dst_file) return idx, (dst_chunk, dst_file)
@@ -683,6 +690,8 @@ def append_or_create_parquet_file(
if contains_images: if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features) to_parquet_with_hf_images(final_df, target_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(final_df, target_path)
else: else:
final_df.to_parquet(target_path) final_df.to_parquet(target_path)
+13 -5
View File
@@ -506,8 +506,10 @@ def compute_episode_stats(
Each statistics dictionary contains min, max, mean, std, count, and quantiles. Each statistics dictionary contains min, max, mean, std, count, and quantiles.
Note: Note:
Image statistics are normalized to [0,1] range and have shape (3,1,1) for For 'image'/'video' features, stats are computed per channel and kept with a
per-channel values when dtype is 'image' or 'video'. leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip
this rescaling and remain in their stored units.
""" """
if quantile_list is None: if quantile_list is None:
quantile_list = DEFAULT_QUANTILES quantile_list = DEFAULT_QUANTILES
@@ -531,8 +533,12 @@ def compute_episode_stats(
) )
if features[key]["dtype"] in ["image", "video"]: if features[key]["dtype"] in ["image", "video"]:
normalization_factor = (
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
)
ep_stats[key] = { ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
for k, v in ep_stats[key].items()
} }
return ep_stats return ep_stats
@@ -552,8 +558,10 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
if key == "count" and value.shape != (1,): if key == "count" and value.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.") raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1): if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.") raise ValueError(
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
)
def _assert_type_and_shape(stats_list: list[dict[str, dict]]): def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
+40 -8
View File
@@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib import contextlib
from collections.abc import Callable from collections.abc import Callable, Iterable
from copy import deepcopy
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@@ -337,6 +338,25 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities stored as videos.""" """Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"] return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def depth_keys(self) -> list[str]:
"""Keys to access depth-map modalities stored as videos or images.
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
"""
def _is_depth(ft: dict) -> bool:
info = ft.get("info") or {}
video_info = ft.get("video_info") or {}
return (
info.get("is_depth_map", False)
or info.get("video.is_depth_map", False)
or video_info.get("video.is_depth_map", False)
)
return [key for key, ft in self.features.items() if _is_depth(ft)]
@property @property
def camera_keys(self) -> list[str]: def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method).""" """Keys to access visual modalities (regardless of their storage method)."""
@@ -580,29 +600,41 @@ class LeRobotDatasetMetadata:
def update_video_info( def update_video_info(
self, self,
video_key: str | None = None, video_key: str | None = None,
camera_encoder: VideoEncoderConfig | None = None, video_encoder: VideoEncoderConfig | None = None,
preserve_keys: Iterable[str] | None = None,
) -> None: ) -> None:
"""Populate per-feature video info in ``info.json``. """Populate or refresh per-feature video info in ``info.json``.
Warning: this function writes info from first episode videos, implicitly assuming that all videos have Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists. been encoded the same way. Also, this means it assumes the first episode exists.
Always re-probes the videos and overwrites existing info for every recomputed
key. ``preserve_keys`` lists keys whose existing values must be kept (e.g.
data-intrinsic entries like ``is_depth_map`` and depth quantization params)
instead of being recomputed.
Args: Args:
video_key: If provided, only update this video key. Otherwise update video_key: If provided, only update this video key. Otherwise update
all video keys in the dataset. all video keys in the dataset.
camera_encoder: Encoder configuration used to produce the video_encoder: Encoder configuration used to produce the
videos. When provided, its fields are recorded as videos. When provided, its fields are recorded as
``video.<field>`` entries alongside the stream-derived ``video.<field>`` entries alongside the stream-derived
``video.*`` entries (see :func:`get_video_info`). ``video.*`` entries (see :func:`get_video_info`).
preserve_keys: Keys whose existing values are kept instead of being
recomputed. ``None`` (default) recomputes every key.
""" """
if video_key is not None and video_key not in self.video_keys: if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset") raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys video_keys = [video_key] if video_key is not None else self.video_keys
preserve_set = set(preserve_keys or ())
for key in video_keys: for key in video_keys:
if not self.features[key].get("info", None): existing = self.features[key].get("info") or {}
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder) new_info = get_video_info(video_path, video_encoder=video_encoder)
# Drop preserved keys so the existing values win on merge.
new_info = {k: v for k, v in new_info.items() if k not in preserve_set}
self.info.features[key]["info"] = {**existing, **new_info}
def update_chunk_settings( def update_chunk_settings(
self, self,
@@ -709,7 +741,7 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False) obj.root.mkdir(parents=True, exist_ok=False)
features = {**features, **DEFAULT_FEATURES} features = {**deepcopy(features), **DEFAULT_FEATURES}
_validate_feature_names(features) _validate_feature_names(features)
obj.tasks = None obj.tasks = None
+26
View File
@@ -22,7 +22,10 @@ from pathlib import Path
import datasets import datasets
import torch import torch
from lerobot.configs.video import DepthEncoderConfig
from .dataset_metadata import LeRobotDatasetMetadata from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .feature_utils import ( from .feature_utils import (
check_delta_timestamps, check_delta_timestamps,
get_delta_indices, get_delta_indices,
@@ -51,6 +54,7 @@ class DatasetReader:
delta_timestamps: dict[str, list[float]] | None, delta_timestamps: dict[str, list[float]] | None,
image_transforms: Callable | None, image_transforms: Callable | None,
return_uint8: bool = False, return_uint8: bool = False,
depth_output_unit: str = "mm",
): ):
"""Initialize the reader with metadata, filtering, and transform config. """Initialize the reader with metadata, filtering, and transform config.
@@ -68,6 +72,10 @@ class DatasetReader:
relative timestamp offsets for temporal context windows. relative timestamp offsets for temporal context windows.
image_transforms: Optional torchvision v2 transform applied to image_transforms: Optional torchvision v2 transform applied to
visual features. visual features.
return_uint8: If True, return RGB video frames as raw uint8 tensors
instead of normalized float32.
depth_output_unit: Physical unit depth maps are dequantized to
(``"m"`` or ``"mm"``). Defaults to ``"mm"``.
""" """
self._meta = meta self._meta = meta
self.root = root self.root = root
@@ -76,6 +84,7 @@ class DatasetReader:
self._video_backend = video_backend self._video_backend = video_backend
self._image_transforms = image_transforms self._image_transforms = image_transforms
self._return_uint8 = return_uint8 self._return_uint8 = return_uint8
self._depth_output_unit = depth_output_unit
self.hf_dataset: datasets.Dataset | None = None self.hf_dataset: datasets.Dataset | None = None
self._absolute_to_relative_idx: dict[int, int] | None = None self._absolute_to_relative_idx: dict[int, int] | None = None
@@ -86,6 +95,12 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
for vid_key in self._meta.depth_keys
}
def try_load(self) -> bool: def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient.""" """Attempt to load from local cache. Returns True if data is sufficient."""
try: try:
@@ -247,7 +262,18 @@ class DatasetReader:
self._tolerance_s, self._tolerance_s,
self._video_backend, self._video_backend,
return_uint8=self._return_uint8, return_uint8=self._return_uint8,
is_depth=vid_key in self._meta.depth_keys,
) )
if vid_key in self._meta.depth_keys:
depth_encoder = self._depth_encoder_configs[vid_key]
frames = dequantize_depth(
frames,
depth_min=depth_encoder.depth_min,
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_unit=self._depth_output_unit,
)
return vid_key, frames.squeeze(0) return vid_key, frames.squeeze(0)
items = list(query_timestamps.items()) items = list(query_timestamps.items())
+107 -65
View File
@@ -27,6 +27,7 @@ import logging
import shutil import shutil
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from copy import deepcopy
from pathlib import Path from pathlib import Path
import datasets import datasets
@@ -36,7 +37,14 @@ import pyarrow.parquet as pq
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
encoder_config_from_video_info,
)
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
from lerobot.utils.utils import flatten_dict from lerobot.utils.utils import flatten_dict
@@ -47,6 +55,7 @@ from .compute_stats import (
compute_relative_action_stats, compute_relative_action_stats,
) )
from .dataset_metadata import LeRobotDatasetMetadata from .dataset_metadata import LeRobotDatasetMetadata
from .image_writer import write_image
from .io_utils import ( from .io_utils import (
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
load_episodes, load_episodes,
@@ -61,12 +70,13 @@ from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH, DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH, DEFAULT_EPISODES_PATH,
DEPTH_FILE_PATTERN,
IMAGE_FILE_PATTERN,
VIDEO_DIR, VIDEO_DIR,
update_chunk_file_indices, update_chunk_file_indices,
) )
from .video_utils import ( from .video_utils import (
encode_video_frames, encode_video_frames,
get_video_info,
reencode_video, reencode_video,
) )
@@ -600,7 +610,7 @@ def _keep_episodes_from_video_with_av(
output_path: Path, output_path: Path,
episodes_to_keep: list[tuple[int, int]], episodes_to_keep: list[tuple[int, int]],
fps: float, fps: float,
camera_encoder: VideoEncoderConfig, video_encoder: VideoEncoderConfig,
) -> None: ) -> None:
"""Keep only specified episodes from a video file using PyAV. """Keep only specified episodes from a video file using PyAV.
@@ -614,7 +624,7 @@ def _keep_episodes_from_video_with_av(
Ranges are half-open intervals: [start_frame, end_frame), where start_frame Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive. is inclusive and end_frame is exclusive.
fps: Frame rate of the video. fps: Frame rate of the video.
camera_encoder: Video encoder settings used to re-encode the kept frames. video_encoder: Video encoder settings used to re-encode the kept frames.
""" """
from fractions import Fraction from fractions import Fraction
@@ -639,13 +649,13 @@ def _keep_episodes_from_video_with_av(
# Convert fps to Fraction for PyAV compatibility. # Convert fps to Fraction for PyAV compatibility.
fps_fraction = Fraction(fps).limit_denominator(1000) fps_fraction = Fraction(fps).limit_denominator(1000)
codec_options = camera_encoder.get_codec_options(as_strings=True) codec_options = video_encoder.get_codec_options(as_strings=True)
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options) v_out = out.add_stream(video_encoder.vcodec, rate=fps_fraction, options=codec_options)
# PyAV type stubs don't distinguish video streams from audio/subtitle streams. # PyAV type stubs don't distinguish video streams from audio/subtitle streams.
v_out.width = v_in.codec_context.width v_out.width = v_in.codec_context.width
v_out.height = v_in.codec_context.height v_out.height = v_in.codec_context.height
v_out.pix_fmt = camera_encoder.pix_fmt v_out.pix_fmt = video_encoder.pix_fmt
# Set time_base to match the frame rate for proper timestamp handling. # Set time_base to match the frame rate for proper timestamp handling.
v_out.time_base = Fraction(1, int(fps)) v_out.time_base = Fraction(1, int(fps))
@@ -732,7 +742,7 @@ def _copy_and_reindex_videos(
for video_key in src_dataset.meta.video_keys: for video_key in src_dataset.meta.video_keys:
logging.info(f"Processing videos for {video_key}") logging.info(f"Processing videos for {video_key}")
camera_encoder = VideoEncoderConfig.from_video_info( video_encoder = encoder_config_from_video_info(
src_dataset.meta.info.features.get(video_key, {}).get("info") src_dataset.meta.info.features.get(video_key, {}).get("info")
) )
@@ -816,7 +826,7 @@ def _copy_and_reindex_videos(
dst_video_path, dst_video_path,
episodes_to_keep_ranges, episodes_to_keep_ranges,
src_dataset.meta.fps, src_dataset.meta.fps,
camera_encoder, video_encoder,
) )
cumulative_ts = 0.0 cumulative_ts = 0.0
@@ -1101,7 +1111,9 @@ def _copy_episodes_metadata_and_stats(
if dst_meta.video_keys and src_dataset.meta.video_keys: if dst_meta.video_keys and src_dataset.meta.video_keys:
for key in dst_meta.video_keys: for key in dst_meta.video_keys:
if key in src_dataset.meta.features: if key in src_dataset.meta.features:
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {}) dst_meta.info.features[key]["info"] = deepcopy(
src_dataset.meta.info.features[key].get("info", {})
)
write_info(dst_meta.info, dst_meta.root) write_info(dst_meta.info, dst_meta.root)
@@ -1150,15 +1162,15 @@ def _save_episode_images_for_video(
# Get all items for this episode # Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
# Define function to save a single image # Define function to save a single image
def save_single_image(i_item_tuple): def save_single_image(i_item_tuple):
i, item = i_item_tuple i, item = i_item_tuple
img = item[img_key] write_image(item[img_key], imgs_dir / frame_pattern.format(frame_index=i))
# Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
return i return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
items = list(enumerate(episode_dataset)) items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor: with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -1190,13 +1202,14 @@ def _save_batch_episodes_images(
hf_dataset = dataset.hf_dataset.with_format(None) hf_dataset = dataset.hf_dataset.with_format(None)
imgs_dataset = hf_dataset.select_columns(img_key) imgs_dataset = hf_dataset.select_columns(img_key)
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
# Define function to save a single image with global frame index # Define function to save a single image with global frame index
# Defined once outside the loop to avoid repeated closure creation # Defined once outside the loop to avoid repeated closure creation
def save_single_image(i_item_tuple, base_frame_idx, img_key_param): def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
i, item = i_item_tuple i, item = i_item_tuple
img = item[img_key_param] write_image(item[img_key_param], imgs_dir / frame_pattern.format(frame_index=base_frame_idx + i))
# Use global frame index for naming
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
return i return i
episode_durations = [] episode_durations = []
@@ -1287,7 +1300,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: list[int], episode_indices: list[int],
temp_dir: Path, temp_dir: Path,
fps: int, fps: int,
camera_encoder: VideoEncoderConfig, video_encoder: VideoEncoderConfig,
num_calibration_frames: int = 30, num_calibration_frames: int = 30,
) -> float: ) -> float:
"""Estimate MB per frame by encoding a small calibration sample. """Estimate MB per frame by encoding a small calibration sample.
@@ -1301,7 +1314,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: List of episode indices being processed. episode_indices: List of episode indices being processed.
temp_dir: Temporary directory for calibration files. temp_dir: Temporary directory for calibration files.
fps: Frames per second for video encoding. fps: Frames per second for video encoding.
camera_encoder: Video encoder settings used for calibration encoding. video_encoder: Video encoder settings used for calibration encoding.
num_calibration_frames: Number of frames to use for calibration (default: 30). num_calibration_frames: Number of frames to use for calibration (default: 30).
Returns: Returns:
@@ -1326,10 +1339,11 @@ def _estimate_frame_size_via_calibration(
hf_dataset = dataset.hf_dataset.with_format(None) hf_dataset = dataset.hf_dataset.with_format(None)
sample_indices = range(from_idx, from_idx + num_frames) sample_indices = range(from_idx, from_idx + num_frames)
# Save calibration frames # Save calibration frames using the suffix/format the encoder expects.
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
for i, idx in enumerate(sample_indices): for i, idx in enumerate(sample_indices):
img = hf_dataset[idx][img_key] write_image(hf_dataset[idx][img_key], calibration_dir / frame_pattern.format(frame_index=i))
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
# Encode calibration video # Encode calibration video
calibration_video_path = calibration_dir / "calibration.mp4" calibration_video_path = calibration_dir / "calibration.mp4"
@@ -1337,7 +1351,7 @@ def _estimate_frame_size_via_calibration(
imgs_dir=calibration_dir, imgs_dir=calibration_dir,
video_path=calibration_video_path, video_path=calibration_video_path,
fps=fps, fps=fps,
camera_encoder=camera_encoder, video_encoder=video_encoder,
overwrite=True, overwrite=True,
) )
@@ -1610,6 +1624,7 @@ def recompute_stats(
raise ValueError(f"No parquet files found in {data_dir}") raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = [] all_episode_stats = []
# TODO: enable image and video stats re-computation
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]] numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"): for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
@@ -1656,6 +1671,7 @@ def convert_image_to_video_dataset(
output_dir: Path | None = None, output_dir: Path | None = None,
repo_id: str | None = None, repo_id: str | None = None,
camera_encoder: VideoEncoderConfig | None = None, camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
episode_indices: list[int] | None = None, episode_indices: list[int] | None = None,
num_workers: int = 4, num_workers: int = 4,
max_episodes_per_batch: int | None = None, max_episodes_per_batch: int | None = None,
@@ -1667,21 +1683,32 @@ def convert_image_to_video_dataset(
LeRobot dataset structure with videos stored in chunked MP4 files. LeRobot dataset structure with videos stored in chunked MP4 files.
Args: Args:
dataset: The source LeRobot dataset with images dataset: The source LeRobot dataset with images.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig. output_dir: Root directory where the converted dataset will be stored. When
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig. ``None``, defaults to ``$HF_LEROBOT_HOME/repo_id``. Equivalent to
camera_encoder: Video encoder settings ``new_root`` in ``EditDatasetConfig``.
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`). repo_id: Converted dataset identifier. Equivalent to ``new_repo_id`` in
episode_indices: List of episode indices to convert (None = all episodes) ``EditDatasetConfig``.
num_workers: Number of threads for parallel processing (default: 4) camera_encoder: Video encoder settings applied to RGB cameras. When ``None``,
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit) :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit) depth_encoder: Video encoder settings applied to depth-map cameras, including
the quantization parameters persisted to the dataset metadata. When
``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
episode_indices: Episode indices to convert. When ``None``, all episodes are
converted.
num_workers: Number of threads for parallel processing.
max_episodes_per_batch: Maximum episodes per video batch, to bound memory use.
``None`` means no limit.
max_frames_per_batch: Maximum frames per video batch, to bound memory use.
``None`` means no limit.
Returns: Returns:
New LeRobotDataset with images encoded as videos A new :class:`LeRobotDataset` with images encoded as videos.
""" """
if camera_encoder is None: if camera_encoder is None:
camera_encoder = camera_encoder_defaults() camera_encoder = camera_encoder_defaults()
if depth_encoder is None:
depth_encoder = depth_encoder_defaults()
# Check that it's an image dataset # Check that it's an image dataset
if len(dataset.meta.video_keys) > 0: if len(dataset.meta.video_keys) > 0:
@@ -1706,10 +1733,7 @@ def convert_image_to_video_dataset(
logging.info( logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}" f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
) )
logging.info( logging.info(f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}")
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
)
# Create new features dict, converting image features to video features # Create new features dict, converting image features to video features
new_features = {} new_features = {}
@@ -1771,6 +1795,8 @@ def convert_image_to_video_dataset(
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices} episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
for img_key in tqdm(img_keys, desc="Processing cameras"): for img_key in tqdm(img_keys, desc="Processing cameras"):
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
# Estimate size per frame by encoding a small calibration sample # Estimate size per frame by encoding a small calibration sample
# This provides accurate compression ratio for the specific codec parameters # This provides accurate compression ratio for the specific codec parameters
size_per_frame_mb = _estimate_frame_size_via_calibration( size_per_frame_mb = _estimate_frame_size_via_calibration(
@@ -1779,7 +1805,7 @@ def convert_image_to_video_dataset(
episode_indices=episode_indices, episode_indices=episode_indices,
temp_dir=temp_dir, temp_dir=temp_dir,
fps=fps, fps=fps,
camera_encoder=camera_encoder, video_encoder=target_encoder,
) )
logging.info(f"Processing camera: {img_key}") logging.info(f"Processing camera: {img_key}")
@@ -1821,7 +1847,7 @@ def convert_image_to_video_dataset(
imgs_dir=imgs_dir, imgs_dir=imgs_dir,
video_path=video_path, video_path=video_path,
fps=fps, fps=fps,
camera_encoder=camera_encoder, video_encoder=target_encoder,
overwrite=True, overwrite=True,
) )
@@ -1860,16 +1886,11 @@ def convert_image_to_video_dataset(
new_meta.info.total_tasks = dataset.meta.total_tasks new_meta.info.total_tasks = dataset.meta.total_tasks
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"} new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos) # Update video info for all image keys (now videos). They are registered as
# We need to manually set video info since update_video_info() checks video_keys first # video features above, so update_video_info populates their (still-empty) info.
for img_key in img_keys: for img_key in img_keys:
if not new_meta.features[img_key].get("info", None): target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
video_path = new_meta.root / new_meta.video_path.format( new_meta.update_video_info(video_key=img_key, video_encoder=target_encoder)
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info.features[img_key]["info"] = get_video_info(
video_path, camera_encoder=camera_encoder
)
write_info(new_meta.info, new_meta.root) write_info(new_meta.info, new_meta.root)
@@ -1896,11 +1917,11 @@ def convert_image_to_video_dataset(
def _reencode_video_worker(args: tuple) -> Path: def _reencode_video_worker(args: tuple) -> Path:
"""Picklable worker for :func:`reencode_dataset`'s process pool.""" """Picklable worker for :func:`reencode_dataset`'s process pool."""
video_path, camera_encoder, encoder_threads = args video_path, video_encoder, encoder_threads = args
reencode_video( reencode_video(
input_video_path=video_path, input_video_path=video_path,
output_video_path=video_path, output_video_path=video_path,
camera_encoder=camera_encoder, video_encoder=video_encoder,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
overwrite=True, overwrite=True,
) )
@@ -1909,7 +1930,8 @@ def _reencode_video_worker(args: tuple) -> Path:
def reencode_dataset( def reencode_dataset(
dataset: LeRobotDataset, dataset: LeRobotDataset,
camera_encoder: VideoEncoderConfig, camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None, encoder_threads: int | None = None,
num_workers: int | None = None, num_workers: int | None = None,
) -> LeRobotDataset: ) -> LeRobotDataset:
@@ -1920,8 +1942,11 @@ def reencode_dataset(
Args: Args:
dataset: An existing :class:`LeRobotDataset` whose videos will be dataset: An existing :class:`LeRobotDataset` whose videos will be
re-encoded. re-encoded.
camera_encoder: Target encoder configuration applied to every video camera_encoder: Target encoder configuration applied to every RGB video
file. file. If ``None``, re-encoding is skipped for RGB videos.
depth_encoder: Target encoder configuration applied to every depth video
file. If ``None``, re-encoding is skipped for depth videos.
Quantization parameters will not override the ones in the current dataset.
encoder_threads: Per-encoder thread count forwarded to encoder_threads: Per-encoder thread count forwarded to
:func:`reencode_video`. ``None`` lets the codec decide. :func:`reencode_video`. ``None`` lets the codec decide.
num_workers: Number of parallel processes. ``None`` or ``0`` means num_workers: Number of parallel processes. ``None`` or ``0`` means
@@ -1933,23 +1958,35 @@ def reencode_dataset(
on disk. on disk.
""" """
meta = dataset.meta meta = dataset.meta
video_paths_list = [] video_keys_encoders_dict = {}
video_keys_paths_dict = {}
if camera_encoder is None and depth_encoder is None:
raise ValueError("Either camera_encoder or depth_encoder must be provided")
# Only re-encode if the videos are not already encoded with the given video encoding parameters # Only re-encode if the videos are not already encoded with the given video encoding parameters
for video_key in meta.video_keys: for video_key in meta.video_keys:
current_info = meta.info.features[video_key].get("info", {}) current_info = meta.info.features[video_key].get("info", {})
current_encoder = VideoEncoderConfig.from_video_info(current_info) current_encoder = encoder_config_from_video_info(current_info)
if current_encoder != camera_encoder: target_encoder = depth_encoder if video_key in meta.depth_keys else camera_encoder
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4")) if target_encoder is None:
logging.info(f"No encoder provided for {video_key} video. Skipping re-encoding.")
elif current_encoder != target_encoder:
video_keys_paths_dict[video_key] = list((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
video_keys_encoders_dict[video_key] = target_encoder
else: else:
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.") logging.info(f"{video_key} videos are already encoded with {target_encoder}. Nothing to do.")
if len(video_paths_list) == 0: if len(video_keys_paths_dict) == 0:
logging.warning("Dataset has no videos to re-encode.") logging.warning("Dataset has no videos to re-encode.")
return dataset return dataset
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}") logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list] worker_args = [
(path, encoder, encoder_threads)
for video_key, encoder in video_keys_encoders_dict.items()
for path in video_keys_paths_dict[video_key]
]
if num_workers and num_workers > 1: if num_workers and num_workers > 1:
with ProcessPoolExecutor(max_workers=num_workers) as pool: with ProcessPoolExecutor(max_workers=num_workers) as pool:
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args] futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
@@ -1963,10 +2000,15 @@ def reencode_dataset(
for args in tqdm(worker_args, desc="Re-encoding videos"): for args in tqdm(worker_args, desc="Re-encoding videos"):
_reencode_video_worker(args) _reencode_video_worker(args)
# Refresh video info in metadata for every video key. # Refresh video info in metadata for every re-encoded key. Re-encoding only
for vid_key in meta.video_keys: # changes codec/container params, so for depth videos we preserve ``is_depth_map``
video_path = meta.root / meta.get_video_file_path(0, vid_key) # and the depth quantization params (``video.depth_min`` / ``video.depth_max`` /
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder) # ...), which describe the data rather than the codec and must survive a transcode.
# RGB videos pass an empty set: still a refresh, but nothing to preserve.
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
for video_key, encoder in video_keys_encoders_dict.items():
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else set()
meta.update_video_info(video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys)
write_info(meta.info, meta.root) write_info(meta.info, meta.root)
logging.info("Dataset metadata updated.") logging.info("Dataset metadata updated.")
+41 -12
View File
@@ -31,7 +31,12 @@ import PIL.Image
import pyarrow.parquet as pq import pyarrow.parquet as pq
import torch import torch
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from .compute_stats import compute_episode_stats from .compute_stats import compute_episode_stats
from .dataset_metadata import LeRobotDatasetMetadata from .dataset_metadata import LeRobotDatasetMetadata
@@ -48,6 +53,7 @@ from .io_utils import (
write_info, write_info,
) )
from .utils import ( from .utils import (
DEFAULT_DEPTH_PATH,
DEFAULT_EPISODES_PATH, DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH, DEFAULT_IMAGE_PATH,
update_chunk_file_indices, update_chunk_file_indices,
@@ -67,17 +73,22 @@ def _encode_video_worker(
episode_index: int, episode_index: int,
root: Path, root: Path,
fps: int, fps: int,
camera_encoder: VideoEncoderConfig | None = None, video_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None, encoder_threads: int | None = None,
) -> Path: ) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) path_template = (
DEFAULT_DEPTH_PATH
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
else DEFAULT_IMAGE_PATH
)
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent img_dir = (root / fpath).parent
encode_video_frames( encode_video_frames(
img_dir, img_dir,
temp_path, temp_path,
fps, fps,
camera_encoder=camera_encoder, video_encoder=video_encoder,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
overwrite=True, overwrite=True,
) )
@@ -97,6 +108,7 @@ class DatasetWriter:
meta: LeRobotDatasetMetadata, meta: LeRobotDatasetMetadata,
root: Path, root: Path,
camera_encoder: VideoEncoderConfig | None, camera_encoder: VideoEncoderConfig | None,
depth_encoder: DepthEncoderConfig | None,
encoder_threads: int | None, encoder_threads: int | None,
batch_encoding_size: int, batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None, streaming_encoder: StreamingVideoEncoder | None = None,
@@ -108,8 +120,11 @@ class DatasetWriter:
meta: Dataset metadata instance (used for feature schema, chunk meta: Dataset metadata instance (used for feature schema, chunk
settings, and episode persistence). settings, and episode persistence).
root: Local dataset root directory. root: Local dataset root directory.
camera_encoder: Video encoder settings applied to all cameras. camera_encoder: Video encoder settings applied to RGB cameras. When
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`. ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings applied to depth cameras, including
the quantization parameters. When ``None``,
:func:`~lerobot.configs.video.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None`` encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide. lets the codec decide.
batch_encoding_size: Number of episodes to accumulate before batch_encoding_size: Number of episodes to accumulate before
@@ -121,6 +136,7 @@ class DatasetWriter:
self._meta = meta self._meta = meta
self._root = root self._root = root
self._camera_encoder = camera_encoder or camera_encoder_defaults() self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._depth_encoder = depth_encoder or depth_encoder_defaults()
self._encoder_threads = encoder_threads self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder self._streaming_encoder = streaming_encoder
@@ -145,7 +161,8 @@ class DatasetWriter:
return ep_buffer return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format( path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
fpath = path_template.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index image_key=image_key, episode_index=episode_index, frame_index=frame_index
) )
return self._root / fpath return self._root / fpath
@@ -195,6 +212,7 @@ class DatasetWriter:
if frame_index == 0 and self._streaming_encoder is not None: if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode( self._streaming_encoder.start_episode(
video_keys=list(self._meta.video_keys), video_keys=list(self._meta.video_keys),
depth_video_keys=list(self._meta.depth_keys),
temp_dir=self._root, temp_dir=self._root,
) )
@@ -282,10 +300,13 @@ class DatasetWriter:
if use_streaming: if use_streaming:
streaming_results = self._streaming_encoder.finish_episode() streaming_results = self._streaming_encoder.finish_episode()
for video_key in self._meta.video_keys: for video_key in self._meta.video_keys:
normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0
temp_path, video_stats = streaming_results[video_key] temp_path, video_stats = streaming_results[video_key]
if video_stats is not None: if video_stats is not None:
ep_stats[video_key] = { ep_stats[video_key] = {
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) k: v
if k == "count"
else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
for k, v in video_stats.items() for k, v in video_stats.items()
} }
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
@@ -300,7 +321,9 @@ class DatasetWriter:
episode_index, episode_index,
self._root, self._root,
self._meta.fps, self._meta.fps,
self._camera_encoder, self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
self._encoder_threads, self._encoder_threads,
): video_key ): video_key
for video_key in self._meta.video_keys for video_key in self._meta.video_keys
@@ -511,7 +534,12 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded) # Update video info (only needed when first episode is encoded)
if episode_index == 0: if episode_index == 0:
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder) self._meta.update_video_info(
video_key,
video_encoder=self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
)
write_info(self._meta.info, self._meta.root) write_info(self._meta.info, self._meta.root)
metadata = { metadata = {
@@ -578,13 +606,14 @@ class DatasetWriter:
self.image_writer.wait_until_done() self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png into mp4 videos.""" """Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
is_depth = video_key in self._meta.depth_keys
return _encode_video_worker( return _encode_video_worker(
video_key, video_key,
episode_index, episode_index,
self._root, self._root,
self._meta.fps, self._meta.fps,
self._camera_encoder, self._depth_encoder if is_depth else self._camera_encoder,
self._encoder_threads, self._encoder_threads,
) )
+256
View File
@@ -0,0 +1,256 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
"""
import math
from typing import Literal
import av
import numpy as np
import torch
from numpy.typing import NDArray
from lerobot.configs.video import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_PIX_FMT,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
DEPTH_QMAX,
)
from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0
_UINT16_MAX = 65535
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite."""
if depth_min + shift <= 0:
raise ValueError(
f"depth_min + shift must be positive for logarithmic quantization, "
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
)
def _depth_input_to_float32_and_unit(
depth: NDArray[np.integer] | NDArray[np.floating],
input_unit: Literal["auto", "m", "mm"],
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
resolved_unit = (
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
)
return depth.astype(np.float32, order="K"), resolved_unit
def quantize_depth(
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
video_backend: str | None = "pyav",
input_unit: Literal["auto", "m", "mm"] = "auto",
) -> NDArray[np.uint16] | av.VideoFrame:
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
Depth maps are packed into 12-bit integer frames so they fit in standard
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
Logarithmic quantization is the default because it allocates more quanta
to near-range depth, which matches the (1/depth) error profile of typical
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
**Input units**:
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
- ``input_unit="mm"``: interpret input values as millimetres.
- ``input_unit="m"``: interpret input values as metres.
Quantization math runs in the **resolved input unit**.
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
Args:
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
depth_min: Depth (metres) at quantum ``0``.
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
use_log: If ``True`` (default), quantize in log space.
video_backend: Video backend to use for encoding. Defaults to "pyav".
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
Returns:
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
``[0, DEPTH_QMAX]``.
Raises:
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if input_unit not in ("auto", "m", "mm"):
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
if isinstance(depth, torch.Tensor):
depth = depth.detach().cpu().numpy()
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
depth = depth.squeeze()
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
# Convert depth_min, depth_max, and shift to the resolved input unit.
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
# Normalization and quantization is performed in the resolved input unit.
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_u + shift_u))
log_max = math.log(float(depth_max_u + shift_u))
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
else:
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
if video_backend == "pyav":
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
write_u16_plane(frame.planes[0], quantized)
return frame
else:
return quantized
def dequantize_depth(
quantized: NDArray[np.uint16] | av.VideoFrame | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
output_unit: Literal["m", "mm"] = "mm",
output_tensor: bool = True,
output_channel_last: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
"""Inverse of :func:`quantize_depth`.
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
the requested output unit. Tuning arguments **must match** :func:`quantize_depth`.
Accepted input layouts :
- ``(H, W, 1)`` or ``(H, W)`` — single frame with channel-last.
- ``(..., 1, H, W)`` — batched frames with channel-first.
- ``(..., H, W, 1)`` — batched frames with channel-last.
Output layout is determined by ``output_channel_last``.
Args:
quantized: 12-bit codes in ``[0, DEPTH_QMAX]``. ``np.ndarray``,
``av.VideoFrame``, or ``torch.Tensor`` (any integer or float dtype).
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
pix_fmt: Pixel format used to extract the plane from an ``av.VideoFrame``.
output_unit: ``"mm"`` returns ``uint16`` millimetres (rint, clip
``[0, 65535]``) when returning a numpy array, or ``float32`` mm when
``output_tensor=True``. ``"m"`` returns ``float32`` metres in
``[depth_min, depth_max]``.
output_tensor: If True, return a ``torch.Tensor`` instead of a numpy array.
Returns:
Depth map in the requested unit and dtype.
Raises:
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if output_unit not in ("m", "mm"):
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if use_log:
_validate_log_quant_params(depth_min, shift)
if isinstance(quantized, av.VideoFrame):
quantized = quantized.to_ndarray(format=pix_fmt)
# Compute the scale and offset first.
depth_min_m = float(depth_min)
depth_max_m = float(depth_max)
shift_m = float(shift)
if use_log:
log_min = math.log(depth_min_m + shift_m)
log_max = math.log(depth_max_m + shift_m)
scale = (log_max - log_min) / DEPTH_QMAX
offset = log_min
else:
scale = (depth_max_m - depth_min_m) / DEPTH_QMAX
offset = depth_min_m
# ── Torch path: stay on the input device, single fp32 allocation. ────────
if isinstance(quantized, torch.Tensor):
if quantized.ndim >= 3:
# Drop the single-channel dimension so the math runs on (..., H, W).
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
# Single allocation we own; everything else is in-place.
buf = quantized.to(dtype=torch.float32, copy=True)
buf.mul_(scale).add_(offset)
if use_log:
buf.exp_().sub_(shift_m)
buf.clamp_(depth_min_m, depth_max_m)
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
if output_unit == "m":
return buf if output_tensor else buf.cpu().numpy()
# mm path: round + clamp in float32, skipping the uint16 round-trip
# when returning a tensor (torch.uint16 is poorly supported).
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
if output_tensor:
return buf
return buf.cpu().numpy().astype(np.uint16, copy=False)
# ── NumPy path: single fp32 allocation, ``out=`` for in-place math. ─────
arr = np.asarray(quantized)
if arr.ndim >= 3:
# Drop the single-channel dimension so the math runs on (..., H, W).
arr = np.squeeze(arr, axis=-3) if arr.shape[-3] == 1 else np.squeeze(arr, axis=-1)
buf = np.empty(arr.shape, dtype=np.float32)
np.multiply(arr, scale, out=buf)
np.add(buf, offset, out=buf)
if use_log:
np.exp(buf, out=buf)
np.subtract(buf, shift_m, out=buf)
np.clip(buf, depth_min_m, depth_max_m, out=buf)
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
if output_unit == "m":
return torch.from_numpy(buf) if output_tensor else buf
np.multiply(buf, _MM_PER_METRE, out=buf)
np.rint(buf, out=buf)
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
if output_tensor:
# torch.uint16 support is very limited; return float32 millimetres.
return torch.from_numpy(buf)
return buf.astype(np.uint16, copy=False)
+1
View File
@@ -96,6 +96,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
revision=cfg.dataset.revision, revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend, video_backend=cfg.dataset.video_backend,
return_uint8=True, return_uint8=True,
depth_output_unit=cfg.dataset.depth_output_unit,
tolerance_s=cfg.tolerance_s, tolerance_s=cfg.tolerance_s,
) )
else: else:
+1 -1
View File
@@ -336,7 +336,7 @@ def validate_feature_image_or_video(
Args: Args:
name (str): The name of the feature. name (str): The name of the feature.
expected_shape (list[str]): The expected shape (C, H, W). expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
value: The image data to validate. value: The image data to validate.
Returns: Returns:
+51 -5
View File
@@ -42,10 +42,41 @@ def safe_stop_image_writer(func):
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images """Convert a NumPy array to a PIL Image, preserving precision for grayscale.
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
Behaviour by shape:
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
The native dtype is preserved using the matching PIL mode
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
(existing behaviour, gated by ``range_check``).
Other shapes / channel counts raise ``NotImplementedError`` or
``ValueError``.
"""
# TODO(CarolinePascal): 4 dimensions RGB-D images
if image_array.ndim not in (2, 3):
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
# caller emits (H, W), (1, H, W), or (H, W, 1).
if image_array.ndim == 3:
if image_array.shape[0] == 1:
image_array = image_array[0]
elif image_array.shape[-1] == 1:
image_array = image_array[..., 0]
if image_array.ndim == 2:
if image_array.dtype not in [np.uint16, np.float32]:
raise ValueError(
f"Unsupported single-channel image dtype: {image_array.dtype}. "
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
)
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
# 3D path: must be RGB (3 channels), channels-first or channels-last.
if image_array.shape[0] == 3: if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C) # Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0) image_array = image_array.transpose(1, 2, 0)
@@ -71,13 +102,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
return PIL.Image.fromarray(image_array) return PIL.Image.fromarray(image_array)
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
"""
suffix = Path(fpath).suffix.lower()
if suffix == ".png":
return {"compress_level": compress_level}
if suffix in (".tif", ".tiff"):
return {"compression": "raw"}
return {}
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1): def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
""" """
Saves a NumPy array or PIL Image to a file. Saves a NumPy array or PIL Image to a file.
This function handles both NumPy arrays and PIL Image objects, converting This function handles both NumPy arrays and PIL Image objects, converting
the former to a PIL Image before saving. It includes error handling for the former to a PIL Image before saving. It includes error handling for
the save operation. the save operation. The output format is inferred from the *fpath*
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
→ lossless raw depth maps (TIFF).
Args: Args:
image (np.ndarray | PIL.Image.Image): The image data to save. image (np.ndarray | PIL.Image.Image): The image data to save.
@@ -101,7 +147,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
img = image img = image
else: else:
raise TypeError(f"Unsupported image type: {type(image)}") raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, compress_level=compress_level) img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
except Exception as e: except Exception as e:
logger.error("Error writing image %s: %s", fpath, e) logger.error("Error writing image %s: %s", fpath, e)
+39 -10
View File
@@ -20,6 +20,7 @@ import datasets
import numpy as np import numpy as np
import pandas import pandas
import pandas as pd import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq import pyarrow.parquet as pq
import torch import torch
@@ -153,7 +154,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
Returns: Returns:
dict: The statistics dictionary with values cast to numpy arrays. dict: The statistics dictionary with values cast to numpy arrays.
""" """
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()} stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats) return unflatten_dict(stats)
@@ -270,21 +271,49 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
return items_dict return items_dict
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
"""Write ``table`` with one parquet row group per episode (in episode order).
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
mirroring the recording writer. ``table`` must carry a contiguous
``episode_index`` column.
"""
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
try:
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
finally:
writer.close()
def to_parquet_with_hf_images( def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None: ) -> None:
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. """Write a DataFrame with HF-encoded images to parquet, one row group per episode.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Args: Images are embedded into the arrow table first (``ParquetWriter.write_table``
df: DataFrame to write to parquet. does not embed external image files like ``Dataset.to_parquet`` does).
path: Path to write the parquet file. ``features`` types image columns as ``Image()`` in the parquet schema.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
""" """
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds.to_parquet(path) ds = embed_images(ds)
table = ds.with_format("arrow")[:]
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
# No episode boundaries to align row groups to — keep a single write.
pq.write_table(table, str(path))
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
table = pa.Table.from_pandas(df, preserve_index=False)
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
pq.write_table(table, str(path))
def item_to_torch(item: dict) -> dict: def item_to_torch(item: dict) -> dict:
+25 -3
View File
@@ -24,7 +24,7 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError from huggingface_hub.errors import RevisionNotFoundError
from lerobot.configs import VideoEncoderConfig from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
@@ -58,8 +58,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos: bool = True, download_videos: bool = True,
video_backend: str | None = None, video_backend: str | None = None,
return_uint8: bool = False, return_uint8: bool = False,
depth_output_unit: str = "mm",
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None, camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None, encoder_threads: int | None = None,
streaming_encoding: bool = False, streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30, encoder_queue_maxsize: int = 30,
@@ -186,6 +188,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` (codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
is used by the writer. is used by the writer.
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults`
is used by the writer.
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
codec decide. codec decide.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
@@ -208,6 +213,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_video_backend() self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
self._return_uint8 = return_uint8 self._return_uint8 = return_uint8
self._depth_output_unit = depth_output_unit
self._batch_encoding_size = batch_encoding_size self._batch_encoding_size = batch_encoding_size
self._encoder_threads = encoder_threads self._encoder_threads = encoder_threads
@@ -248,6 +254,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
image_transforms=image_transforms, image_transforms=image_transforms,
return_uint8=self._return_uint8, return_uint8=self._return_uint8,
depth_output_unit=self._depth_output_unit,
) )
# Load actual data # Load actual data
@@ -273,6 +280,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = self._build_streaming_encoder( streaming_enc = self._build_streaming_encoder(
self.meta.fps, self.meta.fps,
camera_encoder, camera_encoder,
depth_encoder,
encoder_queue_maxsize, encoder_queue_maxsize,
encoder_threads, encoder_threads,
) )
@@ -280,6 +288,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
meta=self.meta, meta=self.meta,
root=self.root, root=self.root,
camera_encoder=camera_encoder, camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size, batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc, streaming_encoder=streaming_enc,
@@ -315,6 +324,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=self.delta_timestamps, delta_timestamps=self.delta_timestamps,
image_transforms=self.image_transforms, image_transforms=self.image_transforms,
return_uint8=self._return_uint8, return_uint8=self._return_uint8,
depth_output_unit=self._depth_output_unit,
) )
return self.reader return self.reader
@@ -322,12 +332,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _build_streaming_encoder( def _build_streaming_encoder(
fps: int, fps: int,
camera_encoder: VideoEncoderConfig | None, camera_encoder: VideoEncoderConfig | None,
depth_encoder: DepthEncoderConfig | None,
encoder_queue_maxsize: int, encoder_queue_maxsize: int,
encoder_threads: int | None, encoder_threads: int | None,
) -> StreamingVideoEncoder: ) -> StreamingVideoEncoder:
return StreamingVideoEncoder( return StreamingVideoEncoder(
fps=fps, fps=fps,
camera_encoder=camera_encoder, camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
queue_maxsize=encoder_queue_maxsize, queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
) )
@@ -646,6 +658,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None, video_backend: str | None = None,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None, camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
metadata_buffer_size: int = 10, metadata_buffer_size: int = 10,
streaming_encoding: bool = False, streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30, encoder_queue_maxsize: int = 30,
@@ -678,6 +691,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch-encoding videos. ``1`` means encode immediately. batch-encoding videos. ``1`` means encode immediately.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.). camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used. When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None`` encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide. lets the codec decide.
metadata_buffer_size: Number of episode metadata records to buffer metadata_buffer_size: Number of episode metadata records to buffer
@@ -712,6 +727,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.episodes = None obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend() obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
obj._return_uint8 = False obj._return_uint8 = False
obj._depth_output_unit = "mm"
obj._batch_encoding_size = batch_encoding_size obj._batch_encoding_size = batch_encoding_size
obj._encoder_threads = encoder_threads obj._encoder_threads = encoder_threads
@@ -721,12 +737,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0: if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder( streaming_enc = cls._build_streaming_encoder(
fps, camera_encoder, encoder_queue_maxsize, encoder_threads fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
) )
obj.writer = DatasetWriter( obj.writer = DatasetWriter(
meta=obj.meta, meta=obj.meta,
root=obj.root, root=obj.root,
camera_encoder=camera_encoder, camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size, batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc, streaming_encoder=streaming_enc,
@@ -750,6 +767,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None, video_backend: str | None = None,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None, camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None, encoder_threads: int | None = None,
image_writer_processes: int = 0, image_writer_processes: int = 0,
image_writer_threads: int = 0, image_writer_threads: int = 0,
@@ -779,6 +797,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch-encoding videos. batch-encoding videos.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.). camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used. When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None`` encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide. lets the codec decide.
image_writer_processes: Subprocesses for async image writing. image_writer_processes: Subprocesses for async image writing.
@@ -806,6 +826,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.episodes = None obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend() obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
obj._return_uint8 = False obj._return_uint8 = False
obj._depth_output_unit = "mm"
obj._batch_encoding_size = batch_encoding_size obj._batch_encoding_size = batch_encoding_size
if obj._requested_root is not None: if obj._requested_root is not None:
@@ -825,12 +846,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0: if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder( streaming_enc = cls._build_streaming_encoder(
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
) )
obj.writer = DatasetWriter( obj.writer = DatasetWriter(
meta=obj.meta, meta=obj.meta,
root=obj.root, root=obj.root,
camera_encoder=camera_encoder, camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size, batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc, streaming_encoder=streaming_enc,
+5 -3
View File
@@ -70,19 +70,21 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]], initial_features: dict[PipelineFeatureType, dict[str, Any]],
*, *,
use_videos: bool = True, use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None, patterns: Sequence[str] | None = None,
) -> dict[str, dict]: ) -> dict[str, dict]:
""" """
Aggregates and filters pipeline features to create a dataset-ready features dictionary. Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `use_videos` and `patterns`, and finally (image or state), filters them based on `exclude_images` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset. formats them for use with a Hugging Face LeRobot Dataset.
Args: Args:
pipeline: The DataProcessorPipeline to apply. pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations. initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: If False, image features are excluded. use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
exclude_images: If True, image features are dropped entirely from the output.
patterns: A sequence of regex patterns to filter action and state features. patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter. Image features are not affected by this filter.
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
) )
# 2. Apply filtering rules. # 2. Apply filtering rules.
if is_image and not use_videos: if is_image and exclude_images:
continue continue
if not is_image and not should_keep(key, compiled_patterns): if not is_image and not should_keep(key, compiled_patterns):
continue continue
+37 -2
View File
@@ -24,6 +24,7 @@ import logging
from typing import Any from typing import Any
import av import av
import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,6 +32,22 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64") FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
height, width = src.shape
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
if fill_value is not None:
dst.fill(fill_value)
dst[:, :width] = src
@functools.cache
def get_pix_fmt_channels(pix_fmt: str) -> int:
"""Return the number of components (channels) for *pix_fmt*."""
return len(av.VideoFormat(pix_fmt).components)
@functools.cache @functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None: def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable.""" """PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
@@ -92,7 +109,7 @@ def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Opti
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option." f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
) from e ) from e
elif isinstance(value, (float, int)): elif isinstance(value, (float, int)):
num_val = value num_val = float(value)
else: else:
raise ValueError( raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option." f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
@@ -142,6 +159,16 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
) )
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
"""Ensure *pix_fmt* can carry at least *channels* components."""
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
if pix_fmt_channels < channels:
raise ValueError(
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
f"but the source data has {channels} channel(s)."
)
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None: def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
"""Validate merged encoder options (typed) against the codec's published AVOptions.""" """Validate merged encoder options (typed) against the codec's published AVOptions."""
supported_options = _get_codec_options_by_name(vcodec) supported_options = _get_codec_options_by_name(vcodec)
@@ -156,12 +183,18 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
_check_option_value(vcodec, key, value, supported_options[key]) _check_option_value(vcodec, key, value, supported_options[key])
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None: def check_video_encoder_parameters_pyav(
vcodec: str,
pix_fmt: str,
codec_options: dict[str, Any],
channels: int | None = None,
) -> None:
"""Verify *config* is compatible with the bundled FFmpeg build. """Verify *config* is compatible with the bundled FFmpeg build.
Checks pixel format, abstract tuning-field compatibility, and each merged Checks pixel format, abstract tuning-field compatibility, and each merged
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options` encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
against PyAV (including numeric ``extra_options`` present in that dict). against PyAV (including numeric ``extra_options`` present in that dict).
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
No-op when ``config.vcodec`` isn't in the local FFmpeg build. No-op when ``config.vcodec`` isn't in the local FFmpeg build.
Raises: Raises:
@@ -171,4 +204,6 @@ def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options
if not options: if not options:
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build") raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
_check_pixel_format(vcodec, pix_fmt) _check_pixel_format(vcodec, pix_fmt)
if channels is not None:
_check_pix_fmt_channels(pix_fmt, channels)
_check_codec_options(vcodec, codec_options) _check_codec_options(vcodec, codec_options)
+4 -1
View File
@@ -87,11 +87,14 @@ DATA_DIR = "data"
VIDEO_DIR = "videos" VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
IMAGE_FILE_PATTERN = "frame-{frame_index:06d}.png"
DEPTH_FILE_PATTERN = "frame-{frame_index:06d}.tiff"
DEFAULT_TASKS_PATH = "meta/tasks.parquet" DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png" DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/" + IMAGE_FILE_PATTERN
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/" + DEPTH_FILE_PATTERN
LEGACY_EPISODES_PATH = "meta/episodes.jsonl" LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
+151 -71
View File
@@ -39,11 +39,16 @@ from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from lerobot.configs import ( from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig, VideoEncoderConfig,
camera_encoder_defaults, camera_encoder_defaults,
depth_encoder_defaults,
) )
from lerobot.utils.import_utils import get_safe_default_video_backend from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import quantize_depth
from .pyav_utils import get_pix_fmt_channels
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -53,6 +58,7 @@ def decode_video_frames(
tolerance_s: float, tolerance_s: float,
backend: str | None = None, backend: str | None = None,
return_uint8: bool = False, return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Decodes video frames using the specified backend. Decodes video frames using the specified backend.
@@ -64,23 +70,35 @@ def decode_video_frames(
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
accepted for one release as an alias for "pyav" and will be removed in a future version. accepted for one release as an alias for "pyav" and will be removed in a future version.
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization. return_uint8 (bool): For RGB videos, if True return raw uint8 frames without float32 normalization.
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward. This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
is_depth (bool): Set to True if the video is a depth map (1 channel, uint12).
Returns: Returns:
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True). torch.Tensor: Decoded frames (RGB: float32 in [0,1] by default, or uint8 if return_uint8=True, Depth: uint12).
Currently supports torchcodec on cpu and pyav. Currently supports torchcodec on cpu and pyav.
""" """
if backend != "pyav" and is_depth:
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
# We do not actually return uint8 here, but we avoid the 255 normalization step.
return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True
)
if backend is None: if backend is None:
backend = get_safe_default_video_backend() backend = get_safe_default_video_backend()
if backend == "torchcodec": if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8) return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "pyav": elif backend == "pyav":
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8) return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
)
elif backend == "video_reader": elif backend == "video_reader":
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.") logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8) return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
)
else: else:
raise ValueError(f"Unsupported video backend: {backend}") raise ValueError(f"Unsupported video backend: {backend}")
@@ -91,6 +109,7 @@ def decode_video_frames_pyav(
tolerance_s: float, tolerance_s: float,
log_loaded_timestamps: bool = False, log_loaded_timestamps: bool = False,
return_uint8: bool = False, return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video using PyAV. """Loads frames associated to the requested timestamps of a video using PyAV.
@@ -109,8 +128,9 @@ def decode_video_frames_pyav(
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
decoded frame. decoded frame.
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level. log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
[0, 1] range. Otherwise, return float32 in [0, 1] range.
is_depth: Set to True if the video is a depth map (1 channel, uint12).
Returns: Returns:
torch.Tensor of shape (len(timestamps), C, H, W). torch.Tensor of shape (len(timestamps), C, H, W).
@@ -140,9 +160,13 @@ def decode_video_frames_pyav(
current_ts = float(frame.pts * stream.time_base) current_ts = float(frame.pts * stream.time_base)
if log_loaded_timestamps: if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}") logger.info(f"frame loaded at timestamp={current_ts:.4f}")
# Convert to CHW uint8 to match torchcodec's output layout. if is_depth:
arr = frame.to_ndarray(format="rgb24") # H, W, 3 arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous()) loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
else:
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
# Convert to CHW uint8 to match torchcodec's output layout.
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
loaded_ts.append(current_ts) loaded_ts.append(current_ts)
if current_ts >= last_ts: if current_ts >= last_ts:
break break
@@ -185,7 +209,7 @@ def decode_video_frames_pyav(
f"number of queried timestamps ({len(timestamps)})" f"number of queried timestamps ({len(timestamps)})"
) )
if return_uint8: if return_uint8 or is_depth:
return closest_frames return closest_frames
# convert to the pytorch format which is float32 in [0,1] range (and channel first) # convert to the pytorch format which is float32 in [0,1] range (and channel first)
@@ -406,17 +430,38 @@ def encode_video_frames(
imgs_dir: Path | str, imgs_dir: Path | str,
video_path: Path | str, video_path: Path | str,
fps: int, fps: int,
camera_encoder: VideoEncoderConfig | None = None, video_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None, encoder_threads: int | None = None,
*, *,
log_level: int | None = av.logging.WARNING, log_level: int | None = av.logging.WARNING,
overwrite: bool = False, overwrite: bool = False,
) -> None: ) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" """Encode a directory of image frames into an MP4 video.
if camera_encoder is None:
camera_encoder = camera_encoder_defaults() When ``video_encoder`` is a :class:`~lerobot.configs.video.DepthEncoderConfig`,
vcodec = camera_encoder.vcodec frames are read from ``.tiff`` files and quantized to 12-bit depth codes using the
pix_fmt = camera_encoder.pix_fmt encoder's ``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``; otherwise ``.png``
RGB frames are encoded directly.
Args:
imgs_dir: Directory containing the frames to encode, named ``frame-000000``
onwards (``.png`` for RGB, ``.tiff`` for depth).
video_path: Output path for the encoded ``.mp4`` file.
fps: Frame rate of the output video.
video_encoder: Encoder settings (codec, pixel format, quality, ...). When
``None``, :func:`camera_encoder_defaults` is used. Pass a
:class:`~lerobot.configs.video.DepthEncoderConfig` to encode depth frames.
encoder_threads: Per-encoder thread count forwarded to the codec. ``None``
lets the codec decide.
log_level: libav log level to set while encoding, or ``None`` to leave the
current logging configuration unchanged.
overwrite: When ``False`` and ``video_path`` already exists, skip encoding and
log a warning. When ``True``, re-encode and replace the existing file.
"""
if video_encoder is None:
video_encoder = camera_encoder_defaults()
vcodec = video_encoder.vcodec
pix_fmt = video_encoder.pix_fmt
video_path = Path(video_path) video_path = Path(video_path)
imgs_dir = Path(imgs_dir) imgs_dir = Path(imgs_dir)
@@ -428,17 +473,19 @@ def encode_video_frames(
video_path.parent.mkdir(parents=True, exist_ok=True) video_path.parent.mkdir(parents=True, exist_ok=True)
# Get input frames # Get input frames
template = "frame-" + ("[0-9]" * 6) + ".png" is_depth = isinstance(video_encoder, DepthEncoderConfig)
suffix = ".png" if not is_depth else ".tiff"
template = "frame-" + ("[0-9]" * 6) + suffix
input_list = sorted( input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0]) glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
) )
if len(input_list) == 0: if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.") raise FileNotFoundError(f"No images with suffix {suffix} found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image: with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size width, height = dummy_image.size
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True) video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
# Set logging level # Set logging level
if log_level is not None: if log_level is not None:
@@ -455,8 +502,19 @@ def encode_video_frames(
# Loop through input frames and encode them # Loop through input frames and encode them
for input_data in input_list: for input_data in input_list:
with Image.open(input_data) as input_image: with Image.open(input_data) as input_image:
input_image = input_image.convert("RGB") if is_depth:
input_frame = av.VideoFrame.from_image(input_image) input_frame = quantize_depth(
np.array(input_image),
depth_min=video_encoder.depth_min,
depth_max=video_encoder.depth_max,
shift=video_encoder.shift,
use_log=video_encoder.use_log,
pix_fmt=video_encoder.pix_fmt,
video_backend="pyav",
)
else:
input_image = input_image.convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame) packet = output_stream.encode(input_frame)
if packet: if packet:
output.mux(packet) output.mux(packet)
@@ -477,7 +535,7 @@ def encode_video_frames(
def reencode_video( def reencode_video(
input_video_path: Path | str, input_video_path: Path | str,
output_video_path: Path | str, output_video_path: Path | str,
camera_encoder: VideoEncoderConfig | None = None, video_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None, encoder_threads: int | None = None,
log_level: int | None = av.logging.WARNING, log_level: int | None = av.logging.WARNING,
overwrite: bool = False, overwrite: bool = False,
@@ -489,7 +547,7 @@ def reencode_video(
Args: Args:
input_video_path: Existing video file to read. input_video_path: Existing video file to read.
output_video_path: Path for the re-encoded file. output_video_path: Path for the re-encoded file.
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`. video_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`. encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING. log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning. overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
@@ -497,7 +555,7 @@ def reencode_video(
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive). end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
""" """
camera_encoder = camera_encoder or camera_encoder_defaults() video_encoder = video_encoder or camera_encoder_defaults()
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0): if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.") raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
@@ -512,9 +570,9 @@ def reencode_video(
output_video_path.parent.mkdir(parents=True, exist_ok=True) output_video_path.parent.mkdir(parents=True, exist_ok=True)
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True) video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
vcodec = camera_encoder.vcodec vcodec = video_encoder.vcodec
pix_fmt = camera_encoder.pix_fmt pix_fmt = video_encoder.pix_fmt
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file: with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name tmp_output_video_path = tmp_named_file.name
@@ -696,22 +754,21 @@ class _CameraEncoderThread(threading.Thread):
self, self,
video_path: Path, video_path: Path,
fps: int, fps: int,
vcodec: str, video_encoder: VideoEncoderConfig,
pix_fmt: str,
codec_options: dict[str, str],
frame_queue: queue.Queue, frame_queue: queue.Queue,
result_queue: queue.Queue, result_queue: queue.Queue,
stop_event: threading.Event, stop_event: threading.Event,
encoder_threads: int | None = None,
): ):
super().__init__(daemon=True) super().__init__(daemon=True)
self.video_path = video_path self.video_path = video_path
self.fps = fps self.fps = fps
self.vcodec = vcodec self.video_encoder = video_encoder
self.pix_fmt = pix_fmt self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
self.codec_options = codec_options
self.frame_queue = frame_queue self.frame_queue = frame_queue
self.result_queue = result_queue self.result_queue = result_queue
self.stop_event = stop_event self.stop_event = stop_event
self.encoder_threads = encoder_threads
def run(self) -> None: def run(self) -> None:
from .compute_stats import RunningQuantileStats, auto_downsample_height_width from .compute_stats import RunningQuantileStats, auto_downsample_height_width
@@ -736,12 +793,12 @@ class _CameraEncoderThread(threading.Thread):
# Sentinel: flush and close # Sentinel: flush and close
break break
# Ensure HWC uint8 numpy array # Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
if isinstance(frame_data, np.ndarray): if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] == 3: if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
# CHW -> HWC # CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0) frame_data = frame_data.transpose(1, 2, 0)
if frame_data.dtype != np.uint8: if not self.is_depth and frame_data.dtype != np.uint8:
frame_data = (frame_data * 255).astype(np.uint8) frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height) # Open container on first frame (to get width/height)
@@ -749,15 +806,29 @@ class _CameraEncoderThread(threading.Thread):
height, width = frame_data.shape[:2] height, width = frame_data.shape[:2]
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True) Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w") container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options) output_stream = container.add_stream(
output_stream.pix_fmt = self.pix_fmt self.video_encoder.vcodec,
self.fps,
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
)
output_stream.pix_fmt = self.video_encoder.pix_fmt
output_stream.width = width output_stream.width = width
output_stream.height = height output_stream.height = height
output_stream.time_base = Fraction(1, self.fps) output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps # Encode frame with explicit timestamps
pil_img = Image.fromarray(frame_data) if not self.is_depth:
video_frame = av.VideoFrame.from_image(pil_img) pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
else:
video_frame = quantize_depth(
frame_data,
depth_min=self.video_encoder.depth_min,
depth_max=self.video_encoder.depth_max,
shift=self.video_encoder.shift,
use_log=self.video_encoder.use_log,
video_backend=self.video_encoder.video_backend,
)
video_frame.pts = frame_count video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps) video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame) packet = output_stream.encode(video_frame)
@@ -816,21 +887,26 @@ class StreamingVideoEncoder:
self, self,
fps: int, fps: int,
camera_encoder: VideoEncoderConfig | None = None, camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
queue_maxsize: int = 30, queue_maxsize: int = 30,
encoder_threads: int | None = None, encoder_threads: int | None = None,
): ):
""" """
Args: Args:
fps: Frames per second for the output videos. fps: Frames per second for the output videos.
camera_encoder: Video encoder settings applied to all cameras. camera_encoder: Video encoder settings applied to all RGB cameras.
When ``None``, :func:`camera_encoder_defaults` is used. When ``None``, :func:`camera_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global setting). depth_encoder: Video encoder settings applied to all depth cameras,
``None`` lets the codec decide. including the depth quantization parameters. When ``None``,
:func:`depth_encoder_defaults` is used.
queue_maxsize: Max frames to buffer per camera before queue_maxsize: Max frames to buffer per camera before
back-pressure drops frames. back-pressure drops frames.
encoder_threads: Number of encoder threads (global setting).
``None`` lets the codec decide.
""" """
self.fps = fps self.fps = fps
self._camera_encoder = camera_encoder or camera_encoder_defaults() self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._depth_encoder = depth_encoder or depth_encoder_defaults()
self._encoder_threads = encoder_threads self._encoder_threads = encoder_threads
self.queue_maxsize = queue_maxsize self.queue_maxsize = queue_maxsize
@@ -843,18 +919,25 @@ class StreamingVideoEncoder:
self._episode_active = False self._episode_active = False
self._closed = False self._closed = False
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None: def start_episode(
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
) -> None:
"""Start encoder threads for a new episode. """Start encoder threads for a new episode.
Args: Args:
video_keys: List of video feature keys (e.g. ["observation.images.laptop"]) video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
temp_dir: Base directory for temporary MP4 files temp_dir: Base directory for temporary MP4 files
depth_video_keys: List of video or image feature keys that carry depth maps (e.g.
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
""" """
if self._episode_active: if self._episode_active:
self.cancel_episode() self.cancel_episode()
self._dropped_frames.clear() self._dropped_frames.clear()
if depth_video_keys is None:
depth_video_keys = []
for video_key in video_keys: for video_key in video_keys:
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize) frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
result_queue: queue.Queue = queue.Queue(maxsize=1) result_queue: queue.Queue = queue.Queue(maxsize=1)
@@ -863,17 +946,15 @@ class StreamingVideoEncoder:
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir)) temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4" video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
vcodec = self._camera_encoder.vcodec encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
encoder_thread = _CameraEncoderThread( encoder_thread = _CameraEncoderThread(
video_path=video_path, video_path=video_path,
fps=self.fps, fps=self.fps,
vcodec=vcodec, video_encoder=encoder,
pix_fmt=self._camera_encoder.pix_fmt,
codec_options=codec_options,
frame_queue=frame_queue, frame_queue=frame_queue,
result_queue=result_queue, result_queue=result_queue,
stop_event=stop_event, stop_event=stop_event,
encoder_threads=self._encoder_threads,
) )
encoder_thread.start() encoder_thread.start()
@@ -1080,15 +1161,23 @@ def get_audio_info(video_path: Path | str) -> dict:
def get_video_info( def get_video_info(
video_path: Path | str, video_path: Path | str,
camera_encoder: VideoEncoderConfig | None = None, video_encoder: VideoEncoderConfig | None = None,
) -> dict: ) -> dict:
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``. """Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
Args: Args:
video_path: Path to the encoded video file to probe. video_path: Path to the encoded video file to probe.
camera_encoder: If provided, record the exact encoder settings used to encode this video_encoder: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence — encoder fields are only written for keys video. Stream-derived values take precedence — encoder fields are only written for keys
not already populated from the video file itself. not already populated from the video file itself. When a
:class:`~lerobot.configs.video.DepthEncoderConfig` is passed, the depth
quantization parameters (``depth_min`` / ``depth_max`` / ``shift`` /
``use_log``) are recorded so frames can be dequantized on read.
Returns:
The ``video.*`` / ``audio.*`` info dict, including ``is_depth_map`` which is
``True`` only when ``video_encoder`` is a
:class:`~lerobot.configs.video.DepthEncoderConfig`.
""" """
logging.getLogger("libav").setLevel(av.logging.WARNING) logging.getLogger("libav").setLevel(av.logging.WARNING)
@@ -1106,13 +1195,10 @@ def get_video_info(
video_info["video.width"] = video_stream.width video_info["video.width"] = video_stream.width
video_info["video.codec"] = video_stream.codec.canonical_name video_info["video.codec"] = video_stream.codec.canonical_name
video_info["video.pix_fmt"] = video_stream.pix_fmt video_info["video.pix_fmt"] = video_stream.pix_fmt
video_info["video.is_depth_map"] = False
# Calculate fps from r_frame_rate # Calculate fps from r_frame_rate
video_info["video.fps"] = int(video_stream.base_rate) video_info["video.fps"] = int(video_stream.base_rate)
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
video_info["video.channels"] = pixel_channels
# Reset logging level # Reset logging level
av.logging.restore_default_callback() av.logging.restore_default_callback()
@@ -1121,27 +1207,18 @@ def get_video_info(
video_info.update(**get_audio_info(video_path)) video_info.update(**get_audio_info(video_path))
# Add additional encoder configuration if provided # Add additional encoder configuration if provided
if camera_encoder is not None: if video_encoder is not None:
for field_name, field_value in asdict(camera_encoder).items(): for field_name, field_value in asdict(video_encoder).items():
# vcodec is already populated from the video stream # vcodec is already populated from the video stream
if field_name == "vcodec": if field_name == "vcodec":
continue continue
video_info.setdefault(f"video.{field_name}", field_value) video_info.setdefault(f"video.{field_name}", field_value)
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
return video_info return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float: def get_video_duration_in_s(video_path: Path | str) -> float:
""" """
Get the duration of a video file in seconds using PyAV. Get the duration of a video file in seconds using PyAV.
@@ -1202,10 +1279,13 @@ class VideoEncodingManager:
img_dir = self.dataset.root / "images" img_dir = self.dataset.root / "images"
if img_dir.exists(): if img_dir.exists():
png_files = list(img_dir.rglob("*.png")) png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0: tiff_files = list(img_dir.rglob("*.tiff"))
if len(png_files) == 0 and len(tiff_files) == 0:
shutil.rmtree(img_dir) shutil.rmtree(img_dir)
logger.debug("Cleaned up empty images directory") logger.debug("Cleaned up empty images directory")
else: else:
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") logger.debug(
f"Images directory is not empty, containing {len(png_files)} PNG and {len(tiff_files)} TIFF files"
)
return False # Don't suppress the original exception return False # Don't suppress the original exception
+2 -1
View File
@@ -126,7 +126,8 @@ def prepare_observation_for_inference(
for name in observation: for name in observation:
observation[name] = torch.from_numpy(observation[name]) observation[name] = torch.from_numpy(observation[name])
if "image" in name: if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255 if observation[name].dtype == torch.uint8:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous() observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device) observation[name] = observation[name].to(device)
@@ -18,7 +18,8 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot from ..robot import Robot
@@ -27,7 +28,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiOpenArmFollower(Robot): class BiOpenArmFollower(BimanualMixin, Robot):
""" """
Bimanual OpenArm Follower Arms Bimanual OpenArm Follower Arms
""" """
@@ -39,15 +40,17 @@ class BiOpenArmFollower(Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are distributed evenly: each arm's OpenArmFollower # Top-level cameras are opened by `left_arm` for convenience, but their
# will only open the cameras assigned to it. Per-arm cameras are used # keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
# as fallback when top-level cameras are empty. self._top_level_cam_keys = set(config.cameras)
if config.cameras: _collisions = self._top_level_cam_keys & set(
left_cameras = config.cameras config.left_arm_config.cameras
right_cameras = {} ) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
else: if _collisions:
left_cameras = config.left_arm_config.cameras raise ValueError(
right_cameras = config.right_arm_config.cameras f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = OpenArmFollowerConfig( left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
@@ -56,7 +59,7 @@ class BiOpenArmFollower(Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque, use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_cameras, cameras=left_arm_cameras,
side=config.left_arm_config.side, side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface, can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd, use_can_fd=config.left_arm_config.use_can_fd,
@@ -75,7 +78,7 @@ class BiOpenArmFollower(Robot):
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque, use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target, max_relative_target=config.right_arm_config.max_relative_target,
cameras=right_cameras, cameras=config.right_arm_config.cameras,
side=config.right_arm_config.side, side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface, can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd, use_can_fd=config.right_arm_config.use_can_fd,
@@ -95,22 +98,19 @@ class BiOpenArmFollower(Robot):
@property @property
def _motors_ft(self) -> dict[str, type]: def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return { return {
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()}, **{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()}, **{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
} }
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base", out: dict[str, tuple] = {}
# "right_wrist"), so we merge them directly — unlike motors which need the for k, v in self.left_arm._cameras_ft.items():
# left_/right_ prefix to disambiguate identical per-arm joint names. out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft} for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -120,27 +120,6 @@ class BiOpenArmFollower(Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors." "Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -148,21 +127,15 @@ class BiOpenArmFollower(Robot):
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict = {} obs_dict: RobotObservation = {}
# Camera keys that should NOT get the arm prefix (they already have unique names) # Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
left_cam_keys = set(self.left_arm.cameras.keys()) for key, value in self.left_arm.get_observation().items():
right_cam_keys = set(self.right_arm.cameras.keys()) obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Right first, then left — matches the teleoperator (OpenArmMini) ordering # Add "right_" prefix
# and the dataset feature names recorded during data collection. for key, value in self.right_arm.get_observation().items():
right_obs = self.right_arm.get_observation() obs_dict[f"right_{key}"] = value
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict return obs_dict
@@ -189,9 +162,4 @@ class BiOpenArmFollower(Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()} prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_right, **prefixed_sent_action_left} return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -32,5 +32,7 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras shared across both arms. # Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict) cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,7 +18,8 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot from ..robot import Robot
@@ -27,7 +28,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiRebotB601Follower(Robot): class BiRebotB601Follower(BimanualMixin, Robot):
"""Bimanual Seeed Studio reBot B601-DM follower. """Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and Composes two single-arm :class:`RebotB601Follower` instances. Observation and
@@ -41,6 +42,18 @@ class BiRebotB601Follower(Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = RebotB601FollowerRobotConfig( left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir, calibration_dir=config.calibration_dir,
@@ -49,7 +62,7 @@ class BiRebotB601Follower(Robot):
dm_serial_baud=config.left_arm_config.dm_serial_baud, dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
cameras=config.left_arm_config.cameras, cameras=left_arm_cameras,
motor_can_ids=config.left_arm_config.motor_can_ids, motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity, pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio, gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
@@ -86,10 +99,12 @@ class BiRebotB601Follower(Robot):
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
return { out: dict[str, tuple] = {}
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()}, for k, v in self.left_arm._cameras_ft.items():
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()}, out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
} for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -99,32 +114,13 @@ class BiRebotB601Follower(Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict = {} obs_dict: RobotObservation = {}
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()}) for k, v in self.left_arm.get_observation().items():
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()}) obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm.get_observation().items():
obs_dict[f"right_{k}"] = v
return obs_dict return obs_dict
@check_if_not_connected @check_if_not_connected
@@ -143,8 +139,3 @@ class BiRebotB601Follower(Robot):
**{f"left_{k}": v for k, v in sent_action_left.items()}, **{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()}, **{f"right_{k}": v for k, v in sent_action_right.items()},
} }
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig from ..rebot_b601_follower import RebotB601FollowerConfig
@@ -27,3 +29,8 @@ class BiRebotB601FollowerConfig(RobotConfig):
left_arm_config: RebotB601FollowerConfig left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig right_arm_config: RebotB601FollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,7 +18,8 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..robot import Robot from ..robot import Robot
from ..so_follower import SOFollower, SOFollowerRobotConfig from ..so_follower import SOFollower, SOFollowerRobotConfig
@@ -27,7 +28,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiSOFollower(Robot): class BiSOFollower(BimanualMixin, Robot):
""" """
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio [Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
""" """
@@ -39,6 +40,18 @@ class BiSOFollower(Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = SOFollowerRobotConfig( left_arm_config = SOFollowerRobotConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir, calibration_dir=config.calibration_dir,
@@ -46,7 +59,7 @@ class BiSOFollower(Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
use_degrees=config.left_arm_config.use_degrees, use_degrees=config.left_arm_config.use_degrees,
cameras=config.left_arm_config.cameras, cameras=left_arm_cameras,
) )
right_arm_config = SOFollowerRobotConfig( right_arm_config = SOFollowerRobotConfig(
@@ -77,13 +90,12 @@ class BiSOFollower(Robot):
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
left_arm_cameras_ft = self.left_arm._cameras_ft out: dict[str, tuple] = {}
right_arm_cameras_ft = self.right_arm._cameras_ft for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
return { for k, v in self.right_arm._cameras_ft.items():
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()}, out[f"right_{k}"] = v
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()}, return out
}
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -93,42 +105,21 @@ class BiSOFollower(Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
self.left_arm.setup_motors() self.left_arm.setup_motors()
self.right_arm.setup_motors() self.right_arm.setup_motors()
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict = {} obs_dict: RobotObservation = {}
# Add "left_" prefix # Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
left_obs = self.left_arm.get_observation() for key, value in self.left_arm.get_observation().items():
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()}) obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Add "right_" prefix # Add "right_" prefix
right_obs = self.right_arm.get_observation() for key, value in self.right_arm.get_observation().items():
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()}) obs_dict[f"right_{key}"] = value
return obs_dict return obs_dict
@@ -151,8 +142,3 @@ class BiSOFollower(Robot):
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right} return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig from ..config import RobotConfig
from ..so_follower import SOFollowerConfig from ..so_follower import SOFollowerConfig
@@ -27,3 +29,8 @@ class BiSOFollowerConfig(RobotConfig):
left_arm_config: SOFollowerConfig left_arm_config: SOFollowerConfig
right_arm_config: SOFollowerConfig right_arm_config: SOFollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
+12 -3
View File
@@ -68,9 +68,12 @@ class SOFollower(Robot):
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
return { features: dict[str, tuple] = {}
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras for cam in self.cameras:
} features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3)
if getattr(self.cameras[cam], "use_depth", False):
features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width, 1)
return features
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -190,6 +193,12 @@ class SOFollower(Robot):
dt_ms = (time.perf_counter() - start) * 1e3 dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
if getattr(cam, "use_depth", False):
start = time.perf_counter()
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
return obs_dict return obs_dict
@check_if_not_connected @check_if_not_connected
+2
View File
@@ -333,6 +333,7 @@ def build_rollout_context(
root=cfg.dataset.root, root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size, batch_encoding_size=cfg.dataset.video_encoding_batch_size,
camera_encoder=cfg.dataset.camera_encoder, camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
streaming_encoding=cfg.dataset.streaming_encoding, streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads, encoder_threads=cfg.dataset.encoder_threads,
@@ -368,6 +369,7 @@ def build_rollout_context(
* len(robot.cameras if hasattr(robot, "cameras") else []), * len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size, batch_encoding_size=cfg.dataset.video_encoding_batch_size,
camera_encoder=cfg.dataset.camera_encoder, camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
streaming_encoding=cfg.dataset.streaming_encoding, streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads, encoder_threads=cfg.dataset.encoder_threads,
+1
View File
@@ -54,6 +54,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
+31 -2
View File
@@ -133,6 +133,15 @@ Convert image dataset to video format and save locally:
--new_root /path/to/output/pusht_video \ --new_root /path/to/output/pusht_video \
--operation.type convert_image_to_video --operation.type convert_image_to_video
Convert image dataset (with depth maps) to video format, customizing the depth encoder:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_root /path/to/output/pusht_video \
--operation.type convert_image_to_video \
--operation.depth_encoder.depth_min 0.01 \
--operation.depth_encoder.depth_max 10.0 \
--operation.depth_encoder.use_log true
Convert image dataset to video format and save with new repo_id: Convert image dataset to video format and save with new repo_id:
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
@@ -211,6 +220,13 @@ Re-encode videos in-place (overwrites original dataset):
--operation.camera_encoder.vcodec h264 \ --operation.camera_encoder.vcodec h264 \
--operation.overwrite true --operation.overwrite true
Re-encode both RGB and depth videos in a dataset (depth quantization params are preserved):
lerobot-edit-dataset \
--repo_id lerobot/pusht_depth \
--operation.type reencode_videos \
--operation.camera_encoder.vcodec libx264 \
--operation.depth_encoder.vcodec ffv1
Using JSON config file: Using JSON config file:
lerobot-edit-dataset \ lerobot-edit-dataset \
--config_path path/to/edit_config.json --config_path path/to/edit_config.json
@@ -225,7 +241,13 @@ from pathlib import Path
import draccus import draccus
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, parser from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
parser,
)
from lerobot.datasets import ( from lerobot.datasets import (
LeRobotDataset, LeRobotDataset,
convert_image_to_video_dataset, convert_image_to_video_dataset,
@@ -288,6 +310,7 @@ class ModifyTasksConfig(OperationConfig):
class ConvertImageToVideoConfig(OperationConfig): class ConvertImageToVideoConfig(OperationConfig):
output_dir: str | None = None output_dir: str | None = None
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults) camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
episode_indices: list[int] | None = None episode_indices: list[int] | None = None
num_workers: int = 4 num_workers: int = 4
max_episodes_per_batch: int | None = None max_episodes_per_batch: int | None = None
@@ -309,6 +332,7 @@ class RecomputeStatsConfig(OperationConfig):
@dataclass @dataclass
class ReencodeVideosConfig(OperationConfig): class ReencodeVideosConfig(OperationConfig):
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults) camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
num_workers: int = 0 num_workers: int = 0
encoder_threads: int | None = None encoder_threads: int | None = None
overwrite: bool = False overwrite: bool = False
@@ -602,6 +626,7 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
output_dir=output_dir, output_dir=output_dir,
repo_id=output_repo_id, repo_id=output_repo_id,
camera_encoder=getattr(cfg.operation, "camera_encoder", None) or camera_encoder_defaults(), camera_encoder=getattr(cfg.operation, "camera_encoder", None) or camera_encoder_defaults(),
depth_encoder=getattr(cfg.operation, "depth_encoder", None) or depth_encoder_defaults(),
episode_indices=getattr(cfg.operation, "episode_indices", None), episode_indices=getattr(cfg.operation, "episode_indices", None),
num_workers=getattr(cfg.operation, "num_workers", 4), num_workers=getattr(cfg.operation, "num_workers", 4),
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None), max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
@@ -719,10 +744,14 @@ def handle_reencode_videos(cfg: EditDatasetConfig) -> None:
shutil.copytree(input_root, output_root) shutil.copytree(input_root, output_root)
dataset = LeRobotDataset(output_repo_id, root=output_root) dataset = LeRobotDataset(output_repo_id, root=output_root)
logging.info(f"Re-encoding videos in {output_repo_id} with {cfg.operation.camera_encoder}") logging.info(
f"Re-encoding videos in {output_repo_id} with RGB encoder {cfg.operation.camera_encoder} "
f"and depth encoder {cfg.operation.depth_encoder}"
)
reencode_dataset( reencode_dataset(
dataset, dataset,
camera_encoder=cfg.operation.camera_encoder, camera_encoder=cfg.operation.camera_encoder,
depth_encoder=cfg.operation.depth_encoder,
encoder_threads=cfg.operation.encoder_threads, encoder_threads=cfg.operation.encoder_threads,
num_workers=cfg.operation.num_workers, num_workers=cfg.operation.num_workers,
) )
@@ -57,6 +57,7 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
gamepad, gamepad,
+3
View File
@@ -137,6 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
@@ -403,6 +404,7 @@ def record(
root=cfg.dataset.root, root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size, batch_encoding_size=cfg.dataset.video_encoding_batch_size,
camera_encoder=cfg.dataset.camera_encoder, camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
encoder_threads=cfg.dataset.encoder_threads, encoder_threads=cfg.dataset.encoder_threads,
streaming_encoding=cfg.dataset.streaming_encoding, streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
@@ -432,6 +434,7 @@ def record(
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
batch_encoding_size=cfg.dataset.video_encoding_batch_size, batch_encoding_size=cfg.dataset.video_encoding_batch_size,
camera_encoder=cfg.dataset.camera_encoder, camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
encoder_threads=cfg.dataset.encoder_threads, encoder_threads=cfg.dataset.encoder_threads,
streaming_encoding=cfg.dataset.streaming_encoding, streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
+1
View File
@@ -174,6 +174,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
@@ -41,6 +41,7 @@ from lerobot.robots import ( # noqa: F401
) )
from lerobot.teleoperators import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
koch_leader, koch_leader,
@@ -89,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
gamepad, gamepad,
@@ -18,7 +18,8 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
@@ -27,7 +28,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiOpenArmLeader(Teleoperator): class BiOpenArmLeader(BimanualMixin, Teleoperator):
""" """
Bimanual OpenArm Leader Arms Bimanual OpenArm Leader Arms
""" """
@@ -86,27 +87,6 @@ class BiOpenArmLeader(Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors." "Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -129,8 +109,3 @@ class BiOpenArmLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback # TODO: Implement force feedback
raise NotImplementedError raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_leader") @TeleoperatorConfig.register_subclass("bi_openarm_leader")
@dataclass @dataclass
class BiOpenArmLeaderConfig(TeleoperatorConfig): class BiOpenArmLeaderConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Follower robots.""" """Configuration class for Bi OpenArm Leader teleoperators."""
left_arm_config: OpenArmLeaderConfigBase left_arm_config: OpenArmLeaderConfigBase
right_arm_config: OpenArmLeaderConfigBase right_arm_config: OpenArmLeaderConfigBase
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_mini import BiOpenArmMini
from .config_bi_openarm_mini import BiOpenArmMiniConfig
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
from ..teleoperator import Teleoperator
from .config_bi_openarm_mini import BiOpenArmMiniConfig
logger = logging.getLogger(__name__)
class BiOpenArmMini(BimanualMixin, Teleoperator):
"""Bimanual OpenArm Mini teleoperator.
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
bimanual leader can teleoperate a bimanual OpenArm follower.
"""
config_class = BiOpenArmMiniConfig
name = "bi_openarm_mini"
def __init__(self, config: BiOpenArmMiniConfig):
super().__init__(config)
self.config = config
# `side` is forced to match left/right regardless of what the user passed
# on the per-arm base config — the bimanual wrapper owns the side semantics.
left_arm_config = OpenArmMiniConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
side="left",
use_degrees=config.left_arm_config.use_degrees,
)
right_arm_config = OpenArmMiniConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
side="right",
use_degrees=config.right_arm_config.use_degrees,
)
self.left_arm = OpenArmMini(left_arm_config)
self.right_arm = OpenArmMini(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
}
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
action: RobotAction = {}
for k, v in self.left_arm.get_action().items():
action[f"left_{k}"] = v
for k, v in self.right_arm.get_action().items():
action[f"right_{k}"] = v
return action
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
if left_fb:
self.left_arm.send_feedback(left_fb)
if right_fb:
self.right_arm.send_feedback(right_fb)
@@ -0,0 +1,29 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
from ..openarm_mini import OpenArmMiniConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
@dataclass
class BiOpenArmMiniConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Mini teleoperators."""
left_arm_config: OpenArmMiniConfigBase
right_arm_config: OpenArmMiniConfigBase
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .bi_rebot_102_leader import BiRebotArm102Leader from .bi_rebot_102_leader import BiRebot102Leader
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"] __all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
@@ -18,16 +18,17 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiRebotArm102Leader(Teleoperator): class BiRebot102Leader(BimanualMixin, Teleoperator):
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader. """Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
@@ -35,10 +36,10 @@ class BiRebotArm102Leader(Teleoperator):
leader can teleoperate a bimanual reBot B601 follower. leader can teleoperate a bimanual reBot B601 follower.
""" """
config_class = BiRebotArm102LeaderConfig config_class = BiRebot102LeaderConfig
name = "bi_rebot_102_leader" name = "bi_rebot_102_leader"
def __init__(self, config: BiRebotArm102LeaderConfig): def __init__(self, config: BiRebot102LeaderConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
@@ -76,27 +77,6 @@ class BiRebotArm102Leader(Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected @check_if_not_connected
def get_action(self) -> RobotAction: def get_action(self) -> RobotAction:
action_dict = {} action_dict = {}
@@ -106,8 +86,3 @@ class BiRebotArm102Leader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.") raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader") @TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
@dataclass @dataclass
class BiRebotArm102LeaderConfig(TeleoperatorConfig): class BiRebot102LeaderConfig(TeleoperatorConfig):
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator.""" """Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
left_arm_config: RebotArm102LeaderConfig left_arm_config: RebotArm102LeaderConfig
@@ -17,7 +17,9 @@
import logging import logging
from functools import cached_property from functools import cached_property
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..so_leader import SOLeader, SOLeaderTeleopConfig from ..so_leader import SOLeader, SOLeaderTeleopConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
@@ -26,7 +28,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiSOLeader(Teleoperator): class BiSOLeader(BimanualMixin, Teleoperator):
""" """
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio [Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
""" """
@@ -67,33 +69,12 @@ class BiSOLeader(Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
self.left_arm.setup_motors() self.left_arm.setup_motors()
self.right_arm.setup_motors() self.right_arm.setup_motors()
@check_if_not_connected @check_if_not_connected
def get_action(self) -> dict[str, float]: def get_action(self) -> RobotAction:
action_dict = {} action_dict = {}
# Add "left_" prefix # Add "left_" prefix
@@ -109,8 +90,3 @@ class BiSOLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback # TODO: Implement force feedback
raise NotImplementedError raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
from .openarm_mini import OpenArmMini from .openarm_mini import OpenArmMini
__all__ = ["OpenArmMini", "OpenArmMiniConfig"] __all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
@@ -19,12 +19,21 @@ from dataclasses import dataclass
from ..config import TeleoperatorConfig from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass @dataclass
class OpenArmMiniConfig(TeleoperatorConfig): class OpenArmMiniConfigBase:
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms).""" """Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
port_right: str = "/dev/ttyUSB0" # Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
port_left: str = "/dev/ttyUSB1" port: str
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
# during readout. If `None`, no flipping is applied.
side: str | None = None
use_degrees: bool = True use_degrees: bool = True
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
pass
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Motors whose direction is inverted during readout # Per-side motor direction flips applied during readout.
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"] SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"] "left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
}
# Leader joint 6 maps to follower joint 7 and vice versa # Leader joint 6 follower joint 7 (symmetric — its own inverse).
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"} JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65 GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator): class OpenArmMini(Teleoperator):
""" """OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos. For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
""" """
config_class = OpenArmMiniConfig config_class = OpenArmMiniConfig
@@ -56,9 +56,12 @@ class OpenArmMini(Teleoperator):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
norm_mode_body = MotorNormMode.DEGREES norm_mode_body = MotorNormMode.DEGREES
motors = {
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body), "joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body), "joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body), "joint_3": Motor(3, "sts3215", norm_mode_body),
@@ -69,46 +72,15 @@ class OpenArmMini(Teleoperator):
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100), "gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
} }
motors_left = { self.bus = FeetechMotorsBus(
"joint_1": Motor(1, "sts3215", norm_mode_body), port=self.config.port,
"joint_2": Motor(2, "sts3215", norm_mode_body), motors=motors,
"joint_3": Motor(3, "sts3215", norm_mode_body), calibration=self.calibration,
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
) )
@property @property
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
# Right first, then left — matches the robot (BiOpenArmFollower) ordering return {f"{motor}.pos": float for motor in self.bus.motors}
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property @property
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
@@ -116,14 +88,12 @@ class OpenArmMini(Teleoperator):
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self.bus_right.is_connected and self.bus_left.is_connected return self.bus.is_connected
@check_if_already_connected @check_if_already_connected
def connect(self, calibrate: bool = True) -> None: def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting right arm on {self.config.port_right}...") logger.info(f"Connecting arm on {self.config.port}...")
self.bus_right.connect() self.bus.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate: if calibrate:
self.calibrate() self.calibrate()
@@ -133,14 +103,14 @@ class OpenArmMini(Teleoperator):
@property @property
def is_calibrated(self) -> bool: def is_calibrated(self) -> bool:
return self.bus_right.is_calibrated and self.bus_left.is_calibrated return self.bus.is_calibrated
def calibrate(self) -> None: def calibrate(self) -> None:
""" """
Run calibration procedure for OpenArm Mini. Run calibration procedure for a single OpenArm Mini arm.
1. Disable torque 1. Disable torque
2. Ask user to position arms in hanging position with grippers closed 2. Ask user to position arm in hanging position with gripper closed
3. Set this as zero position via half-turn homing 3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions) 4. Interactive gripper calibration (open/close positions)
5. Save calibration 5. Save calibration
@@ -152,70 +122,51 @@ class OpenArmMini(Teleoperator):
) )
if user_input.strip().lower() != "c": if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}") logger.info(f"Using existing calibration for {self.id}")
cal_right = { self.bus.write_calibration(self.calibration)
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return return
logger.info(f"\nRunning calibration for {self}") logger.info(f"\nRunning calibration for {self}")
self._calibrate_arm("right", self.bus_right) self.bus.disable_torque()
self._calibrate_arm("left", self.bus_left)
self._save_calibration() logger.info("Setting Phase to 12 for all motors...")
print(f"\nCalibration complete and saved to {self.calibration_fpath}") for motor in self.bus.motors:
self.bus.write("Phase", motor, 12)
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None: for motor in self.bus.motors:
"""Calibrate a single arm with Feetech motors.""" self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input( input(
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n" "\nCalibration: Zero Position\n"
"Position the arm in the following configuration:\n" "Position the arm in the following configuration:\n"
" - Arm hanging straight down\n" " - Arm hanging straight down\n"
" - Gripper closed\n" " - Gripper closed\n"
"Press ENTER when ready..." "Press ENTER when ready..."
) )
homing_offsets = bus.set_half_turn_homings() homing_offsets = self.bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.") logger.info("Arm zero position set.")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n") print("\nSetting motor ranges\n")
if self.calibration is None: if self.calibration is None:
self.calibration = {} self.calibration = {}
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model] motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
max_res = motor_resolution - 1 max_res = motor_resolution - 1
for motor_name, motor in bus.motors.items(): for motor_name, motor in self.bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper": if motor_name == "gripper":
input( input(
f"\nGripper Calibration ({arm_name.upper()} arm)\n" "\nGripper Calibration\n"
f"Step 1: CLOSE the gripper fully\n" "Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..." "Press ENTER when gripper is closed..."
) )
closed_pos = bus.read("Present_Position", motor_name, normalize=False) closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}") logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...") input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = bus.read("Present_Position", motor_name, normalize=False) open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}") logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos: if closed_pos < open_pos:
@@ -228,16 +179,16 @@ class OpenArmMini(Teleoperator):
drive_mode = 1 drive_mode = 1
logger.info( logger.info(
f" {prefixed_name}: range set to [{range_min}, {range_max}] " f" {motor_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})" f"(0=closed, 100=open, drive_mode={drive_mode})"
) )
else: else:
range_min = 0 range_min = 0
range_max = max_res range_max = max_res
drive_mode = 0 drive_mode = 0
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)") logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[prefixed_name] = MotorCalibration( self.calibration[motor_name] = MotorCalibration(
id=motor.id, id=motor.id,
drive_mode=drive_mode, drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name], homing_offset=homing_offsets[motor_name],
@@ -245,108 +196,68 @@ class OpenArmMini(Teleoperator):
range_max=range_max, range_max=range_max,
) )
cal_for_bus = { self.bus.write_calibration(self.calibration)
k.replace(f"{arm_name}_", ""): v self._save_calibration()
for k, v in self.calibration.items() print(f"\nCalibration complete and saved to {self.calibration_fpath}")
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None: def configure(self) -> None:
self.bus_right.disable_torque() self.bus.disable_torque()
self.bus_right.configure_motors() self.bus.configure_motors()
for motor in self.bus_right.motors: for motor in self.bus.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value) self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None: def setup_motors(self) -> None:
print("\nSetting up RIGHT arm motors...") for motor in reversed(self.bus.motors):
for motor in reversed(self.bus_right.motors): input(f"Connect the controller board to the '{motor}' motor only and press enter.")
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.") self.bus.setup_motor(motor)
self.bus_right.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected @check_if_not_connected
def get_action(self) -> RobotAction: def get_action(self) -> RobotAction:
"""Get current action from both arms (read positions from all motors).""" """Get current action (read positions from all motors)."""
start = time.perf_counter() start = time.perf_counter()
right_positions = self.bus_right.sync_read("Present_Position") positions = self.bus.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa. # Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
# Per-side direction flip is applied based on the configured `side`.
action: dict[str, Any] = {} action: dict[str, Any] = {}
for motor, val in right_positions.items(): for motor, val in positions.items():
target = JOINT_REMAP.get(motor, motor) target = JOINT_REMAP.get(motor, motor)
if motor == "gripper": if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65° # Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else: else:
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
for motor, val in left_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3 dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms") logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action return action
def enable_torque(self) -> None: def enable_torque(self) -> None:
"""Enable torque on both arms for position control.""" self.bus.enable_torque()
self.bus_right.enable_torque()
self.bus_left.enable_torque()
def disable_torque(self) -> None: def disable_torque(self) -> None:
"""Disable torque on both arms for free movement.""" self.bus.disable_torque()
self.bus_right.disable_torque()
self.bus_left.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None: def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic).""" """Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
right_goals: dict[str, float] = {} goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
for key, val in positions.items(): for key, val in positions.items():
if not key.endswith(".pos"): if not key.endswith(".pos"):
continue continue
motor_name = key.removesuffix(".pos") base = key.removesuffix(".pos")
if motor_name.startswith("right_"): # JOINT_REMAP is symmetric (its own inverse).
base = motor_name.removeprefix("right_") target = JOINT_REMAP.get(base, base)
# Reverse remap: follower joint_7 → leader joint_6 and vice versa if base == "gripper":
target = JOINT_REMAP_REVERSE.get(base, base) # Convert robot degrees to teleop 0-100: 0°→0, -65°→100
if base == "gripper": goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100 else:
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES # Un-flip using the ORIGINAL motor name (target = leader motor)
else: goals[target] = -val if target in self._motors_to_flip else val
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
if right_goals: if goals:
self.bus_right.sync_write("Goal_Position", right_goals) self.bus.sync_write("Goal_Position", goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
@check_if_not_connected @check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
@@ -354,6 +265,5 @@ class OpenArmMini(Teleoperator):
@check_if_not_connected @check_if_not_connected
def disconnect(self) -> None: def disconnect(self) -> None:
self.bus_right.disconnect() self.bus.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.") logger.info(f"{self} disconnected.")
+6 -2
View File
@@ -99,14 +99,18 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .openarm_mini import OpenArmMini from .openarm_mini import OpenArmMini
return OpenArmMini(config) return OpenArmMini(config)
elif config.type == "bi_openarm_mini":
from .bi_openarm_mini import BiOpenArmMini
return BiOpenArmMini(config)
elif config.type == "rebot_102_leader": elif config.type == "rebot_102_leader":
from .rebot_102_leader import RebotArm102Leader from .rebot_102_leader import RebotArm102Leader
return RebotArm102Leader(config) return RebotArm102Leader(config)
elif config.type == "bi_rebot_102_leader": elif config.type == "bi_rebot_102_leader":
from .bi_rebot_102_leader import BiRebotArm102Leader from .bi_rebot_102_leader import BiRebot102Leader
return BiRebotArm102Leader(config) return BiRebot102Leader(config)
else: else:
try: try:
return cast("Teleoperator", make_device_from_device_class(config)) return cast("Teleoperator", make_device_from_device_class(config))
+63
View File
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
class BimanualMixin:
"""Lifecycle delegation for bimanual robots and teleoperators.
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
their own ``__init__``. They retain ownership of feature dicts and the
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
``send_feedback``), which vary per-embodiment.
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
take precedence in the MRO::
class BiFooFollower(BimanualMixin, Robot): ...
"""
left_arm: Any
right_arm: Any
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
+23 -12
View File
@@ -51,7 +51,9 @@ def hw_to_dataset_features(
This function takes a dictionary describing hardware outputs (like joint states This function takes a dictionary describing hardware outputs (like joint states
or camera image shapes) and formats it into the standard LeRobot feature or camera image shapes) and formats it into the standard LeRobot feature
specification. specification. Single-channel cameras (shape ``(H, W, 1)``) are flagged as depth
maps via ``info["is_depth_map"] = True``; three-channel cameras ``(H, W, 3)`` are
treated as RGB.
Args: Args:
hw_features (dict): Dictionary mapping feature names to their type (float for hw_features (dict): Dictionary mapping feature names to their type (float for
@@ -61,7 +63,7 @@ def hw_to_dataset_features(
use_video (bool): If True, image features are marked as "video", otherwise "image". use_video (bool): If True, image features are marked as "video", otherwise "image".
Returns: Returns:
dict: A LeRobot features dictionary. dict: A LeRobot features dictionary. Depth cameras carry ``info["is_depth_map"] = True``.
""" """
features = {} features = {}
joint_fts = { joint_fts = {
@@ -69,6 +71,7 @@ def hw_to_dataset_features(
for key, ftype in hw_features.items() for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
} }
# TODO(CarolinePascal): we should not rely on the shape to determine if a feature is a camera !
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == ACTION: if joint_fts and prefix == ACTION:
@@ -86,11 +89,19 @@ def hw_to_dataset_features(
} }
for key, shape in cam_fts.items(): for key, shape in cam_fts.items():
features[f"{prefix}.images.{key}"] = { dtype = "video" if use_video else "image"
"dtype": "video" if use_video else "image", if len(shape) == 3 and shape[2] in (1, 3):
"shape": shape, features[f"{prefix}.images.{key}"] = {
"names": ["height", "width", "channels"], "dtype": dtype,
} "shape": shape,
"names": ["height", "width", "channels"],
"info": {"is_depth_map": shape[2] == 1},
}
else:
raise ValueError(
f"Camera feature '{key}' has shape {shape}. "
f"Expected a 3-tuple (H, W, C), e.g. (480, 640, 3) for RGB or (480, 640, 1) for depth."
)
_validate_feature_names(features) _validate_feature_names(features)
return features return features
@@ -149,11 +160,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
type = FeatureType.VISUAL type = FeatureType.VISUAL
if len(shape) != 3: if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
else:
names = ft["names"] names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1]) shape = (shape[2], shape[0], shape[1])
elif key == OBS_ENV_STATE: elif key == OBS_ENV_STATE:
type = FeatureType.ENV type = FeatureType.ENV
elif key.startswith(OBS_STR): elif key.startswith(OBS_STR):
+8 -1
View File
@@ -107,7 +107,14 @@ def log_rerun_data(
for i, vi in enumerate(arr): for i, vi in enumerate(arr):
rr.log(f"{key}_{i}", rr.Scalars(float(vi))) rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
else: else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr) if arr.shape[-1] == 1:
img_entity = (
rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis).compress()
if compress_images
else rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
)
else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True) rr.log(key, entity=img_entity, static=True)
if action: if action:
+2 -2
View File
@@ -208,14 +208,14 @@ def test_episode_clip_path_trims_via_reencode_video(tmp_path: Path, monkeypatch)
def fake_reencode( def fake_reencode(
input_video_path, input_video_path,
output_video_path, output_video_path,
camera_encoder=None, video_encoder=None,
overwrite=False, overwrite=False,
start_time_s=None, start_time_s=None,
end_time_s=None, end_time_s=None,
): ):
captured.update( captured.update(
src=Path(input_video_path), src=Path(input_video_path),
encoder=camera_encoder, encoder=video_encoder,
start_time_s=start_time_s, start_time_s=start_time_s,
end_time_s=end_time_s, end_time_s=end_time_s,
) )
+73
View File
@@ -28,6 +28,7 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])") pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import pyarrow.parquet as pq # noqa: E402 import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402 from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
@@ -344,6 +345,78 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
assert len(dataset) == 24 assert len(dataset) == 24
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
from lerobot.datasets.io_utils import write_tasks
from lerobot.utils.io_utils import write_json
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
for ep, length in enumerate(episode_lengths):
episode_index += [ep] * length
frame_index += list(range(length))
timestamp += [round(i / fps, 6) for i in range(length)]
task_index += [0] * length
subtask_index += [0] * length # legacy column the writer must drop
pd.DataFrame(
{
"episode_index": episode_index,
"frame_index": frame_index,
"timestamp": timestamp,
"task_index": task_index,
"subtask_index": subtask_index,
}
).to_parquet(data_dir / "file-000.parquet", index=False)
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
write_tasks(tasks_df, root)
write_json(
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
root / "meta" / "info.json",
)
return root
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
"""Rewriting a packed shard must keep one row group per episode, not collapse
every episode into a single giant row group."""
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
shard = root / "data" / "chunk-000" / "file-000.parquet"
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
staging_dir = tmp_path / "stage"
for ep in range(len(episode_lengths)):
_stage_episode(
staging_dir,
ep,
plan=[
{
"role": "assistant",
"content": f"subtask for ep {ep}",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
],
)
records = list(iter_episodes(root))
LanguageColumnsWriter().write_all(records, staging_dir, root)
# One row group per episode, with row counts matching the episode lengths.
md = pq.ParquetFile(shard).metadata
assert md.num_row_groups == len(episode_lengths)
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
# Language columns are still present after the per-episode rewrite.
table = pq.read_table(shard)
assert "language_persistent" in table.column_names
assert "language_events" in table.column_names
def test_speech_atom_shape_matches_plan_spec() -> None: def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!") atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant" assert atom["role"] == "assistant"
+129 -8
View File
@@ -29,7 +29,30 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
from lerobot.datasets.aggregate import aggregate_datasets from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.feature_utils import features_equal_for_merge from lerobot.datasets.feature_utils import features_equal_for_merge
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID from tests.fixtures.constants import (
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
DUMMY_REPO_ID,
)
def assert_data_shards_one_row_group_per_episode(root):
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
import pyarrow.parquet as pq
shards = sorted((root / "data").rglob("*.parquet"))
assert shards, f"no data shards found under {root}/data"
n_episodes = 0
for shard in shards:
pf = pq.ParquetFile(shard)
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
assert pf.metadata.num_row_groups == len(set(episodes)), shard
for i in range(pf.metadata.num_row_groups):
rg_episodes = set(
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
)
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
n_episodes += len(set(episodes))
return n_episodes
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames): def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
@@ -191,6 +214,26 @@ def assert_dataset_iteration_works(aggr_ds):
pass pass
def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
"""Test that depth keys are correctly preserved after aggregation.
Ensures that the ``is_depth_map`` marker on visual features survives
aggregation, so that downstream consumers (e.g. the dataset reader's
depth decoding path) keep working on the merged dataset.
"""
expected_depth_keys = set(ds_0.meta.depth_keys)
assert expected_depth_keys == set(ds_1.meta.depth_keys), (
"Source datasets disagree on depth_keys; test setup is inconsistent"
)
actual_depth_keys = set(aggr_ds.meta.depth_keys)
assert actual_depth_keys == expected_depth_keys, (
f"Expected depth_keys {expected_depth_keys}, got {actual_depth_keys}"
)
for key in expected_depth_keys:
info = aggr_ds.meta.info.features[key].get("info") or {}
assert info.get("is_depth_map") is True, f"Depth marker lost on feature {key!r} after aggregation"
def assert_video_timestamps_within_bounds(aggr_ds): def assert_video_timestamps_within_bounds(aggr_ds):
"""Test that all video timestamps are within valid bounds for their respective video files. """Test that all video timestamps are within valid bounds for their respective video files.
@@ -240,7 +283,11 @@ def assert_video_timestamps_within_bounds(aggr_ds):
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
"""Test basic aggregation functionality with standard parameters.""" """Test basic aggregation functionality with standard parameters.
Source datasets include both RGB and depth video features so the same
aggregation flow is exercised on the ``is_depth_map`` branch.
"""
ds_0_num_frames = 400 ds_0_num_frames = 400
ds_1_num_frames = 800 ds_1_num_frames = 800
ds_0_num_episodes = 10 ds_0_num_episodes = 10
@@ -252,14 +299,21 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
repo_id=f"{DUMMY_REPO_ID}_0", repo_id=f"{DUMMY_REPO_ID}_0",
total_episodes=ds_0_num_episodes, total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames, total_frames=ds_0_num_frames,
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
) )
ds_1 = lerobot_dataset_factory( ds_1 = lerobot_dataset_factory(
root=tmp_path / "test_1", root=tmp_path / "test_1",
repo_id=f"{DUMMY_REPO_ID}_1", repo_id=f"{DUMMY_REPO_ID}_1",
total_episodes=ds_1_num_episodes, total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames, total_frames=ds_1_num_frames,
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
) )
# Confirm depth was actually wired into the source datasets so the
# rest of the assertions exercise the depth aggregation path.
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
aggregate_datasets( aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id], repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root], roots=[ds_0.root, ds_1.root],
@@ -286,6 +340,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_video_timestamps_within_bounds(aggr_ds) assert_video_timestamps_within_bounds(aggr_ds)
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
assert_dataset_iteration_works(aggr_ds) assert_dataset_iteration_works(aggr_ds)
@@ -403,7 +458,11 @@ def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
"""Test aggregation with small file size limits to force file rotation/sharding.""" """Test aggregation with small file size limits to force file rotation/sharding.
Depth video features are included to verify that file rotation/concat
correctly handles depth-marked features alongside regular RGB ones.
"""
ds_0_num_episodes = ds_1_num_episodes = 10 ds_0_num_episodes = ds_1_num_episodes = 10
ds_0_num_frames = ds_1_num_frames = 400 ds_0_num_frames = ds_1_num_frames = 400
@@ -412,14 +471,19 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
repo_id=f"{DUMMY_REPO_ID}_small_0", repo_id=f"{DUMMY_REPO_ID}_small_0",
total_episodes=ds_0_num_episodes, total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames, total_frames=ds_0_num_frames,
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
) )
ds_1 = lerobot_dataset_factory( ds_1 = lerobot_dataset_factory(
root=tmp_path / "small_1", root=tmp_path / "small_1",
repo_id=f"{DUMMY_REPO_ID}_small_1", repo_id=f"{DUMMY_REPO_ID}_small_1",
total_episodes=ds_1_num_episodes, total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames, total_frames=ds_1_num_frames,
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
) )
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
# Use the new configurable parameters to force file rotation # Use the new configurable parameters to force file rotation
aggregate_datasets( aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id], repo_ids=[ds_0.repo_id, ds_1.repo_id],
@@ -450,6 +514,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_video_timestamps_within_bounds(aggr_ds) assert_video_timestamps_within_bounds(aggr_ds)
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
assert_dataset_iteration_works(aggr_ds) assert_dataset_iteration_works(aggr_ds)
# Check that multiple files were actually created due to small size limits # Check that multiple files were actually created due to small size limits
@@ -469,7 +534,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
"""Regression test for video timestamp bug when merging datasets. """Regression test for video timestamp bug when merging datasets.
This test specifically checks that video timestamps are correctly calculated This test specifically checks that video timestamps are correctly calculated
and accumulated when merging multiple datasets. and accumulated when merging multiple datasets. Depth video features are
included so depth timestamps are also covered by the regression.
""" """
datasets = [] datasets = []
for i in range(3): for i in range(3):
@@ -478,9 +544,13 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
repo_id=f"{DUMMY_REPO_ID}_regression_{i}", repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
total_episodes=2, total_episodes=2,
total_frames=100, total_frames=100,
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
) )
datasets.append(ds) datasets.append(ds)
for i, ds in enumerate(datasets):
assert len(ds.meta.depth_keys) > 0, f"Dataset {i} should expose at least one depth key"
aggregate_datasets( aggregate_datasets(
repo_ids=[ds.repo_id for ds in datasets], repo_ids=[ds.repo_id for ds in datasets],
roots=[ds.root for ds in datasets], roots=[ds.root for ds in datasets],
@@ -497,12 +567,21 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr") aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
assert_video_timestamps_within_bounds(aggr_ds) assert_video_timestamps_within_bounds(aggr_ds)
# Depth keys must survive the merge for the regression to cover the
# ``is_depth_map`` decoding branch.
assert set(aggr_ds.meta.depth_keys) == set(datasets[0].meta.depth_keys)
depth_keys = set(aggr_ds.meta.depth_keys)
for i in range(len(aggr_ds)): for i in range(len(aggr_ds)):
item = aggr_ds[i] item = aggr_ds[i]
for key in aggr_ds.meta.video_keys: for key in aggr_ds.meta.video_keys:
assert key in item, f"Video key {key} missing from item {i}" assert key in item, f"Video key {key} missing from item {i}"
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}" # Depth frames are single-channel (1, H, W) after dequantization;
# standard RGB frames keep the 3-channel layout.
expected_channels = 1 if key in depth_keys else 3
assert item[key].shape[0] == expected_channels, (
f"Expected {expected_channels} channels for video key {key}, got {item[key].shape}"
)
def assert_image_schema_preserved(aggr_ds): def assert_image_schema_preserved(aggr_ds):
@@ -566,6 +645,41 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
) )
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
Covers both the non-image (``df.to_parquet``) and image
(``to_parquet_with_hf_images``) write branches, including the merge-into-
existing-file branch via a low file-size threshold that forces packing.
"""
ds_0 = lerobot_dataset_factory(
root=tmp_path / "rg_0",
repo_id=f"{DUMMY_REPO_ID}_rg_0",
total_episodes=3,
total_frames=60,
use_videos=use_videos,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "rg_1",
repo_id=f"{DUMMY_REPO_ID}_rg_1",
total_episodes=4,
total_frames=80,
use_videos=use_videos,
)
aggr_root = tmp_path / "rg_aggr"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
aggr_root=aggr_root,
)
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
"""Test aggregation of image-based datasets preserves HuggingFace Image schema. """Test aggregation of image-based datasets preserves HuggingFace Image schema.
@@ -584,25 +698,31 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
ds_0_num_episodes = 2 ds_0_num_episodes = 2
ds_1_num_episodes = 3 ds_1_num_episodes = 3
# Create two image-based datasets (use_videos=False) # Create two image-based datasets (use_videos=False) with a mix of RGB
# and depth-marked cameras so the depth path is exercised in image mode.
ds_0 = lerobot_dataset_factory( ds_0 = lerobot_dataset_factory(
root=tmp_path / "image_0", root=tmp_path / "image_0",
repo_id=f"{DUMMY_REPO_ID}_image_0", repo_id=f"{DUMMY_REPO_ID}_image_0",
total_episodes=ds_0_num_episodes, total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames, total_frames=ds_0_num_frames,
use_videos=False, # Image-based dataset use_videos=False,
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
) )
ds_1 = lerobot_dataset_factory( ds_1 = lerobot_dataset_factory(
root=tmp_path / "image_1", root=tmp_path / "image_1",
repo_id=f"{DUMMY_REPO_ID}_image_1", repo_id=f"{DUMMY_REPO_ID}_image_1",
total_episodes=ds_1_num_episodes, total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames, total_frames=ds_1_num_frames,
use_videos=False, # Image-based dataset use_videos=False,
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
) )
# Verify source datasets have image keys # Verify source datasets have image keys
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys" assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys" assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
# And that the depth marker actually made it onto an image feature.
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
# Aggregate the datasets # Aggregate the datasets
aggregate_datasets( aggregate_datasets(
@@ -637,6 +757,7 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
# Image-specific assertions # Image-specific assertions
assert_image_schema_preserved(aggr_ds) assert_image_schema_preserved(aggr_ds)
assert_image_frames_integrity(aggr_ds, ds_0, ds_1) assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
# Verify images can be accessed and have correct shape # Verify images can be accessed and have correct shape
sample_item = aggr_ds[0] sample_item = aggr_ds[0]
+45 -4
View File
@@ -59,11 +59,13 @@ def _make_dummy_stats(features: dict) -> dict:
stats = {} stats = {}
for key, ft in features.items(): for key, ft in features.items():
if ft["dtype"] in ("image", "video"): if ft["dtype"] in ("image", "video"):
channels = ft["shape"][-1]
stat_shape = (channels, 1, 1)
stats[key] = { stats[key] = {
"max": np.ones((3, 1, 1), dtype=np.float32), "max": np.ones(stat_shape, dtype=np.float32),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32), "mean": np.full(stat_shape, 0.5, dtype=np.float32),
"min": np.zeros((3, 1, 1), dtype=np.float32), "min": np.zeros(stat_shape, dtype=np.float32),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32), "std": np.full(stat_shape, 0.25, dtype=np.float32),
"count": np.array([5]), "count": np.array([5]),
} }
elif ft["dtype"] in ("float32", "float64", "int64"): elif ft["dtype"] in ("float32", "float64", "int64"):
@@ -142,6 +144,45 @@ def test_create_without_videos_has_no_video_path(tmp_path):
assert meta.video_keys == [] assert meta.video_keys == []
@pytest.mark.parametrize(
("marker_field", "marker_key"),
[
("info", "is_depth_map"),
("info", "video.is_depth_map"),
("video_info", "video.is_depth_map"),
],
ids=["info.is_depth_map", "info.video.is_depth_map_legacy", "video_info.video.is_depth_map_legacy"],
)
def test_depth_keys_property_filters_by_marker(tmp_path, marker_field, marker_key):
"""``depth_keys`` recognises the canonical and the two legacy marker variants."""
depth_feature = {
"dtype": "video",
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
marker_field: {marker_key: True},
}
features = {
**VIDEO_FEATURES,
"observation.images.laptop_depth": depth_feature,
}
meta = LeRobotDatasetMetadata.create(
repo_id="test/depth_keys",
fps=DEFAULT_FPS,
features=features,
root=tmp_path / f"depth_keys_{marker_field}_{marker_key.replace('.', '_')}",
)
assert set(meta.video_keys) == {"observation.images.laptop", "observation.images.laptop_depth"}
assert meta.depth_keys == ["observation.images.laptop_depth"]
def test_depth_keys_empty_when_no_marker(tmp_path):
meta = LeRobotDatasetMetadata.create(
repo_id="test/no_depth", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=tmp_path / "no_depth"
)
assert meta.depth_keys == []
def test_create_raises_on_existing_directory(tmp_path): def test_create_raises_on_existing_directory(tmp_path):
"""create() raises if root directory already exists.""" """create() raises if root directory already exists."""
root = tmp_path / "existing" root = tmp_path / "existing"
+128 -4
View File
@@ -24,7 +24,7 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.configs import VideoEncoderConfig from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
from lerobot.datasets.dataset_tools import ( from lerobot.datasets.dataset_tools import (
add_features, add_features,
convert_image_to_video_dataset, convert_image_to_video_dataset,
@@ -37,7 +37,9 @@ from lerobot.datasets.dataset_tools import (
split_dataset, split_dataset,
) )
from lerobot.datasets.io_utils import load_info from lerobot.datasets.io_utils import load_info
from tests.datasets.test_video_encoding import _add_frames, require_h264, require_libsvtav1 from tests.datasets.test_video_encoding import require_h264, require_hevc, require_libsvtav1
from tests.fixtures.constants import DUMMY_DEPTH_FEATURES, DUMMY_DEPTH_KEY
from tests.fixtures.dataset_factories import add_frames
@pytest.fixture @pytest.fixture
@@ -1332,9 +1334,131 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
@require_libsvtav1
@require_hevc
def test_convert_image_to_video_dataset_depth(tmp_path, empty_lerobot_dataset_factory):
"""Depth image features convert to depth videos using the depth encoder.
Mirrors :func:`test_convert_image_to_video_dataset` but with a small local
image dataset that mixes an RGB camera with a depth camera, so the
``depth_keys`` ``depth_encoder`` routing and ``is_depth_map`` preservation
are exercised end-to-end.
"""
features = {
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
"observation.images.cam": {
"dtype": "image",
"shape": (64, 96, 3),
"names": ["height", "width", "channels"],
},
"observation.images.depth": {
"dtype": "image",
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
"info": {"is_depth_map": True},
},
}
source_dataset = empty_lerobot_dataset_factory(
root=tmp_path / "img_ds",
features=features,
use_videos=False,
)
add_frames(source_dataset, num_frames=4)
source_dataset.save_episode()
source_dataset.finalize()
# Source is an image dataset with the depth marker on the depth camera.
assert len(source_dataset.meta.video_keys) == 0
assert "observation.images.depth" in source_dataset.meta.depth_keys
output_dir = tmp_path / "video_ds"
with (
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
# Use non-default quantization params so the persisted metadata must
# come from the depth encoder (not RGB encoder defaults).
depth_encoder = DepthEncoderConfig(
vcodec="hevc",
pix_fmt="gray12le",
g=2,
crf=30,
depth_min=0.05,
depth_max=8.0,
shift=2.0,
use_log=False,
)
video_dataset = convert_image_to_video_dataset(
dataset=source_dataset,
output_dir=output_dir,
repo_id="dummy/depth_video",
camera_encoder=VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30),
depth_encoder=depth_encoder,
num_workers=1,
)
# Both cameras are now videos, and the depth marker survived the conversion.
assert "observation.images.cam" in video_dataset.meta.video_keys
assert "observation.images.depth" in video_dataset.meta.video_keys
assert "observation.images.depth" in video_dataset.meta.depth_keys
assert "observation.images.cam" not in video_dataset.meta.depth_keys
depth_path = video_dataset.root / video_dataset.meta.get_video_file_path(0, "observation.images.depth")
assert depth_path.exists(), f"Depth video file should exist: {depth_path}"
# The persisted depth-video metadata must carry the depth quantization params
# from the depth encoder (so frames dequantize correctly on read), and the RGB
# camera must not be marked as a depth map.
persisted_info = load_info(video_dataset.root)
depth_info = persisted_info.features["observation.images.depth"]["info"]
assert depth_info["is_depth_map"] is True
assert DepthEncoderConfig.from_video_info(depth_info) == depth_encoder
cam_info = persisted_info.features["observation.images.cam"]["info"]
assert cam_info.get("is_depth_map") is False
assert "video.codec" in cam_info
# ─── reencode_dataset ───────────────────────────────────────────────── # ─── reencode_dataset ─────────────────────────────────────────────────
@require_hevc
def test_reencode_dataset_depth_uses_depth_encoder(tmp_path, empty_lerobot_dataset_factory):
"""Depth videos are re-encoded with the depth encoder and keep their depth metadata.
Depth-focused companion to :func:`test_reencode_dataset_multi_key_multiprocessing`.
"""
initial_cfg = DepthEncoderConfig(vcodec="hevc", pix_fmt="gray12le", g=2, crf=30)
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds",
features=DUMMY_DEPTH_FEATURES,
use_videos=True,
depth_encoder=initial_cfg,
)
add_frames(dataset, num_frames=4)
dataset.save_episode()
dataset.finalize()
assert DUMMY_DEPTH_KEY in dataset.meta.depth_keys
target_cfg = DepthEncoderConfig(vcodec="hevc", pix_fmt="gray12le", g=6, crf=23)
result = reencode_dataset(dataset, depth_encoder=target_cfg, num_workers=0)
assert result is dataset
persisted_info = load_info(dataset.root)
depth_info = persisted_info.features[DUMMY_DEPTH_KEY].get("info", {})
# Re-encode applied the new codec parameters to the depth video ...
assert DepthEncoderConfig.from_video_info(depth_info) == target_cfg
# ... while preserving the depth marker.
assert depth_info["is_depth_map"] is True
@require_libsvtav1 @require_libsvtav1
@require_h264 @require_h264
def test_reencode_dataset_multi_key_multiprocessing( def test_reencode_dataset_multi_key_multiprocessing(
@@ -1350,9 +1474,9 @@ def test_reencode_dataset_multi_key_multiprocessing(
camera_encoder=initial_cfg, camera_encoder=initial_cfg,
) )
_add_frames(dataset, num_frames=4) add_frames(dataset, num_frames=4)
dataset.save_episode() dataset.save_episode()
_add_frames(dataset, num_frames=4) add_frames(dataset, num_frames=4)
dataset.save_episode() dataset.save_episode()
dataset.finalize() dataset.finalize()
+7 -7
View File
@@ -53,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
# ── Existing encode_video_worker tests ─────────────────────────────── # ── Existing encode_video_worker tests ───────────────────────────────
def test_encode_video_worker_forwards_camera_encoder(tmp_path): def test_encode_video_worker_forwards_video_encoder(tmp_path):
"""_encode_video_worker forwards camera_encoder to encode_video_frames.""" """_encode_video_worker forwards video_encoder to encode_video_frames."""
video_key = "observation.images.laptop" video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0) fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent img_dir = tmp_path / Path(fpath).parent
@@ -74,16 +74,16 @@ def test_encode_video_worker_forwards_camera_encoder(tmp_path):
0, 0,
tmp_path, tmp_path,
fps=30, fps=30,
camera_encoder=VideoEncoderConfig(vcodec="h264", preset=None), video_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
encoder_threads=4, encoder_threads=4,
) )
assert captured_kwargs["camera_encoder"].vcodec == "h264" assert captured_kwargs["video_encoder"].vcodec == "h264"
assert captured_kwargs["encoder_threads"] == 4 assert captured_kwargs["encoder_threads"] == 4
def test_encode_video_worker_default_camera_encoder(tmp_path): def test_encode_video_worker_default_video_encoder(tmp_path):
"""_encode_video_worker passes None camera_encoder which encode_video_frames defaults.""" """_encode_video_worker passes None video_encoder which encode_video_frames defaults."""
video_key = "observation.images.laptop" video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0) fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent img_dir = tmp_path / Path(fpath).parent
@@ -100,7 +100,7 @@ def test_encode_video_worker_default_camera_encoder(tmp_path):
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode): with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
_encode_video_worker(video_key, 0, tmp_path, fps=30) _encode_video_worker(video_key, 0, tmp_path, fps=30)
assert captured_kwargs["camera_encoder"] is None assert captured_kwargs["video_encoder"] is None
assert captured_kwargs["encoder_threads"] is None assert captured_kwargs["encoder_threads"] is None
+22 -2
View File
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
from lerobot.transforms import ImageTransforms, ImageTransformsConfig from lerobot.transforms import ImageTransforms, ImageTransformsConfig
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
from lerobot.utils.feature_utils import hw_to_dataset_features from lerobot.utils.feature_utils import hw_to_dataset_features
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel from tests.utils import require_x86_64_kernel
@@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error():
) )
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
# ``create`` must deep-copy features so a dataset built from another's features stays independent.
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
)
dataset_copy = empty_lerobot_dataset_factory(
root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
)
original_shape = dataset.meta.info.features["state"]["shape"]
dataset_copy.meta.info.features["state"]["shape"] = (999,)
assert dataset.meta.info.features["state"]["shape"] == original_shape
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
@@ -1516,10 +1531,15 @@ def test_valid_video_codecs_constant():
assert "h264" in VALID_VIDEO_CODECS assert "h264" in VALID_VIDEO_CODECS
assert "hevc" in VALID_VIDEO_CODECS assert "hevc" in VALID_VIDEO_CODECS
assert "libsvtav1" in VALID_VIDEO_CODECS assert "libsvtav1" in VALID_VIDEO_CODECS
assert "ffv1" in VALID_VIDEO_CODECS
assert "auto" in VALID_VIDEO_CODECS assert "auto" in VALID_VIDEO_CODECS
assert "h264_videotoolbox" in VALID_VIDEO_CODECS assert "h264_videotoolbox" in VALID_VIDEO_CODECS
assert "h264_nvenc" in VALID_VIDEO_CODECS assert "h264_nvenc" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 10 assert "h264_vaapi" in VALID_VIDEO_CODECS
assert "h264_qsv" in VALID_VIDEO_CODECS
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
assert "hevc_nvenc" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 11
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory): def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
+241
View File
@@ -0,0 +1,241 @@
"""Tests for the depth-integration feature.
Covers:
- ``depth_utils`` quantize/dequantize round-trips and backend agreement.
- Image-writer support for single-channel depth.
- Hardware-feature depth flag routing.
- Feature-to-file-format routing through the dataset writer.
Depth metadata detection on ``LeRobotDatasetMetadata.depth_keys`` lives in
``test_dataset_metadata.py``. Depth video encoding/decoding lives in
``test_video_encoding.py``.
"""
from pathlib import Path
import pytest
pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
import av
import numpy as np
import PIL.Image
import torch
from lerobot.configs import DepthEncoderConfig
from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
DUMMY_CHW,
DUMMY_DEPTH_CAMERA_FEATURES,
DUMMY_REPO_ID,
)
from tests.fixtures.dataset_factories import add_frames
_, H, W = DUMMY_CHW
def _depth_metres_ramp() -> np.ndarray:
"""Linearly-spaced float32 depth in metres covering the default range."""
return np.linspace(DEFAULT_DEPTH_MIN, DEFAULT_DEPTH_MAX, H * W, dtype=np.float32).reshape(H, W)
# ── 1. Quantize / dequantize round-trips ──────────────────────────────
class TestQuantizeDequantize:
"""Numerical contract of ``quantize_depth`` / ``dequantize_depth``."""
@pytest.mark.parametrize("use_log", [False, True])
@pytest.mark.parametrize("output_unit", ["m", "mm"])
@pytest.mark.parametrize("output_channel_last", [False, True])
def test_roundtrip(self, use_log, output_unit, output_channel_last):
"""quantize → dequantize recovers depth; layout and unit are honored."""
depth = _depth_metres_ramp()
quantized = quantize_depth(depth, use_log=use_log, video_backend=None)
recovered = dequantize_depth(
quantized,
use_log=use_log,
output_unit=output_unit,
output_tensor=False,
output_channel_last=output_channel_last,
)
expected_shape = (H, W, 1) if output_channel_last else (1, H, W)
assert recovered.shape == expected_shape
recovered_m = recovered.astype(np.float32)
if output_unit == "mm":
recovered_m = recovered_m / 1000.0
recovered_2d = recovered_m[..., 0] if output_channel_last else recovered_m[0]
if use_log:
# Log mode: tighter near-range error than far-range (the whole point).
near = depth < 1.0
far = depth > 8.0
err_near = np.abs(recovered_2d[near] - depth[near])
err_far = np.abs(recovered_2d[far] - depth[far])
assert err_near.mean() < err_far.mean()
else:
# Linear mode: bounded by quant step + 1 mm of unit-conversion rounding.
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
np.testing.assert_allclose(recovered_2d, depth, atol=tol)
@pytest.mark.parametrize("use_log", [False, True])
@pytest.mark.parametrize("output_unit", ["m", "mm"])
def test_numpy_torch_agree(self, use_log, output_unit):
"""Batched torch path produces the same values as the numpy path."""
batch_size = 3
per_frame = np.linspace(0, DEPTH_QMAX, H * W, dtype=np.uint16).reshape(H, W)
batch_np = np.broadcast_to(per_frame[None, None, ...], (batch_size, 1, H, W)).copy()
batch_t = torch.from_numpy(batch_np.astype(np.int32)) # torch.uint16 support is patchy.
ref = dequantize_depth(batch_np, use_log=use_log, output_unit=output_unit, output_tensor=False)
out = dequantize_depth(batch_t, use_log=use_log, output_unit=output_unit, output_tensor=True)
assert isinstance(out, torch.Tensor)
assert out.shape == (batch_size, 1, H, W)
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
atol = 1e-5 if output_unit == "m" else 1.0
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
@pytest.mark.parametrize(
"input_shape,output_shape",
[
((H, W), (1, H, W)),
((1, H, W), (1, H, W)),
((H, W, 1), (1, H, W)),
((3, 1, H, W), (3, 1, H, W)),
((3, H, W, 1), (3, 1, H, W)),
],
)
def test_input_layouts_accepted(self, input_shape, output_shape):
"""All documented input layouts decode to the channel-first default."""
quantized = np.full(input_shape, DEPTH_QMAX // 2, dtype=np.uint16)
out = dequantize_depth(quantized, output_unit="m", output_tensor=False)
assert out.shape == output_shape
def test_pyav_frame_roundtrip(self):
"""quantize → av.VideoFrame → dequantize works."""
depth = _depth_metres_ramp()
frame = quantize_depth(depth, use_log=False, video_backend="pyav")
assert isinstance(frame, av.VideoFrame)
recovered = dequantize_depth(frame, use_log=False, output_unit="m", output_tensor=False)
assert recovered.shape == (1, H, W)
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
np.testing.assert_allclose(recovered[0], depth, atol=tol)
def test_invalid_log_params_raises(self):
with pytest.raises(ValueError, match=r"depth_min \+ shift must be positive"):
quantize_depth(_depth_metres_ramp(), depth_min=1.0, shift=-2.0, use_log=True, video_backend=None)
# ── 2. Image writer depth support ─────────────────────────────────────
class TestImageWriterDepth:
"""``image_array_to_pil_image`` and ``write_image`` for depth maps."""
@pytest.mark.parametrize("dtype,expected_mode", [(np.uint16, "I;16"), (np.float32, "F")])
@pytest.mark.parametrize("shape", [(H, W), (H, W, 1), (1, H, W)])
def test_pil_depth_modes_and_squeeze(self, dtype, expected_mode, shape):
"""Single-channel depth converts to PIL with the right mode and (W, H) size."""
arr = np.zeros(shape, dtype=dtype)
img = image_array_to_pil_image(arr)
assert img.mode == expected_mode
assert img.size == (W, H)
def test_write_image_tiff_roundtrip(self, tmp_path):
"""uint16 depth round-trips through .tiff."""
arr = np.arange(H * W, dtype=np.uint16).reshape(H, W)
fpath = tmp_path / "depth.tiff"
write_image(arr, fpath)
with PIL.Image.open(fpath) as loaded:
recovered = np.array(loaded)
np.testing.assert_array_equal(recovered, arr)
# ── 3. Hardware-feature → depth flag ──────────────────────────────────
class TestHwToDatasetFeaturesDepth:
"""``hw_to_dataset_features`` flags single-channel cameras as depth."""
@pytest.mark.parametrize("channels,is_depth", [(1, True), (3, False)])
def test_depth_marker_by_channels(self, channels, is_depth):
from lerobot.utils.feature_utils import hw_to_dataset_features
features = hw_to_dataset_features({"cam": (480, 640, channels)}, prefix="observation")
assert features["observation.images.cam"]["info"]["is_depth_map"] is is_depth
def test_invalid_channel_count_raises(self):
from lerobot.utils.feature_utils import hw_to_dataset_features
with pytest.raises(ValueError, match="Expected a 3-tuple"):
hw_to_dataset_features({"cam": (480, 640, 2)}, prefix="observation")
# ── 4. Feature-to-file-format routing ────────────────────────────────
# Keys derived from DUMMY_CAMERA_FEATURES_WITH_DEPTH; pick one RGB and the depth camera.
RGB_KEY = next(iter(DUMMY_CAMERA_FEATURES))
DEPTH_KEY = next(iter(DUMMY_DEPTH_CAMERA_FEATURES))
class TestFeatureFileRouting:
"""Depth vs RGB features route to the correct file format."""
NUM_FRAMES = 5
def test_image_mode_depth_tiff_rgb_png(self, tmp_path, features_factory):
"""Without video encoding: depth → .tiff, RGB → .png."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=False)
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID,
fps=DEFAULT_FPS,
features=features,
root=tmp_path / "ds",
use_videos=False,
)
add_frames(dataset, num_frames=self.NUM_FRAMES)
buf = dataset.writer.episode_buffer
assert all(Path(p).suffix == ".tiff" for p in buf[DEPTH_KEY])
assert all(Path(p).suffix == ".png" for p in buf[RGB_KEY])
dataset.save_episode()
dataset.finalize()
def test_video_mode_depth_uses_depth_encoder(self, tmp_path, features_factory):
"""With streaming video encoding: depth → DepthEncoderConfig, RGB does not."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=True)
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID,
fps=DEFAULT_FPS,
features=features,
root=tmp_path / "ds",
use_videos=True,
streaming_encoding=True,
)
add_frames(dataset, num_frames=self.NUM_FRAMES)
encoder = dataset.writer._streaming_encoder
assert encoder is not None
assert isinstance(encoder._threads[DEPTH_KEY].video_encoder, DepthEncoderConfig)
assert not isinstance(encoder._threads[RGB_KEY].video_encoder, DepthEncoderConfig)
dataset.save_episode()
dataset.finalize()
+1 -1
View File
@@ -94,7 +94,7 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory):
def test_image_array_to_pil_image_single_channel(img_array_factory): def test_image_array_to_pil_image_single_channel(img_array_factory):
img_array = img_array_factory(channels=1) img_array = img_array_factory(channels=1)
with pytest.raises(NotImplementedError): with pytest.raises(ValueError, match="Unsupported single-channel image dtype"):
image_array_to_pil_image(img_array) image_array_to_pil_image(img_array)
+5 -10
View File
@@ -61,9 +61,7 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread( encoder_thread = _CameraEncoderThread(
video_path=video_path, video_path=video_path,
fps=fps, fps=fps,
vcodec=enc_cfg.vcodec, video_encoder=enc_cfg,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
frame_queue=frame_queue, frame_queue=frame_queue,
result_queue=result_queue, result_queue=result_queue,
stop_event=stop_event, stop_event=stop_event,
@@ -112,9 +110,7 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread( encoder_thread = _CameraEncoderThread(
video_path=video_path, video_path=video_path,
fps=fps, fps=fps,
vcodec=enc_cfg.vcodec, video_encoder=enc_cfg,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
frame_queue=frame_queue, frame_queue=frame_queue,
result_queue=result_queue, result_queue=result_queue,
stop_event=stop_event, stop_event=stop_event,
@@ -146,9 +142,7 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread( encoder_thread = _CameraEncoderThread(
video_path=video_path, video_path=video_path,
fps=fps, fps=fps,
vcodec=enc_cfg.vcodec, video_encoder=enc_cfg,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
frame_queue=frame_queue, frame_queue=frame_queue,
result_queue=result_queue, result_queue=result_queue,
stop_event=stop_event, stop_event=stop_event,
@@ -391,7 +385,8 @@ class TestStreamingVideoEncoder:
# Verify codec options include thread tuning for libsvtav1 (lp=…) # Verify codec options include thread tuning for libsvtav1 (lp=…)
thread = encoder._threads[f"{OBS_IMAGES}.cam"] thread = encoder._threads[f"{OBS_IMAGES}.cam"]
assert "svtav1-params" in thread.codec_options or "threads" in thread.codec_options codec_opts = thread.video_encoder.get_codec_options(encoder_threads=thread.encoder_threads)
assert "svtav1-params" in codec_opts or "threads" in codec_opts
# Feed some frames and finish to ensure it works end-to-end # Feed some frames and finish to ensure it works end-to-end
num_frames = 10 num_frames = 10
+302 -73
View File
@@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
import av # noqa: E402 import av # noqa: E402
from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig
from lerobot.datasets.image_writer import write_image from lerobot.datasets.image_writer import write_image
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pyav_utils import get_codec from lerobot.datasets.pyav_utils import get_codec
@@ -37,7 +37,15 @@ from lerobot.datasets.video_utils import (
get_video_info, get_video_info,
reencode_video, reencode_video,
) )
from tests.fixtures.constants import DUMMY_VIDEO_INFO from tests.fixtures.constants import (
DUMMY_DEPTH_FEATURES,
DUMMY_DEPTH_KEY,
DUMMY_DEPTH_VIDEO_INFO_FULL,
DUMMY_VIDEO_FEATURES,
DUMMY_VIDEO_INFO,
DUMMY_VIDEO_KEY,
)
from tests.fixtures.dataset_factories import add_frames
# Per-codec skip markers — validation tests only fire when the codec is available # Per-codec skip markers — validation tests only fire when the codec is available
@@ -48,12 +56,67 @@ def _require_encoder(vcodec: str) -> pytest.MarkDecorator:
require_libsvtav1 = _require_encoder("libsvtav1") require_libsvtav1 = _require_encoder("libsvtav1")
require_h264 = _require_encoder("h264") require_h264 = _require_encoder("h264")
require_hevc = _require_encoder("hevc")
require_videotoolbox = _require_encoder("h264_videotoolbox") require_videotoolbox = _require_encoder("h264_videotoolbox")
require_nvenc = _require_encoder("h264_nvenc") require_nvenc = _require_encoder("h264_nvenc")
require_vaapi = _require_encoder("h264_vaapi") require_vaapi = _require_encoder("h264_vaapi")
require_qsv = _require_encoder("h264_qsv") require_qsv = _require_encoder("h264_qsv")
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "encoded_videos"
def _write_color_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
imgs_dir.mkdir(parents=True, exist_ok=True)
for i in range(num_frames):
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
write_image(arr, imgs_dir / f"frame-{i:06d}.png")
def _write_depth_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
"""Write synthetic uint16 depth TIFFs (millimetres) for depth encoder tests.
Uses a smooth linear ramp + per-frame offset (not white noise) so HEVC Main 12
on ``gray12le`` compresses well. Values span ~100 mm to 10 m, covering most
of the default ``[DEPTH_MIN, DEPTH_MAX]`` metres range after
``quantize_depth(input_unit="auto"="mm")``.
"""
imgs_dir.mkdir(parents=True, exist_ok=True)
base = np.linspace(100.0, 10_000.0, height * width, dtype=np.float32).reshape(height, width)
for i in range(num_frames):
arr = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
write_image(arr, imgs_dir / f"frame-{i:06d}.tiff")
def _encode_video(
path: Path,
num_frames: int = 4,
fps: int = 30,
cfg: VideoEncoderConfig | None = None,
depth: bool = False,
) -> Path:
"""Write synthetic frames to a temp dir and encode them to ``path``.
``depth=False`` writes uint8 RGB PNG noise and encodes with ``cfg``
(defaulting to the library default). ``depth=True`` writes synthetic uint16
depth TIFFs and encodes with ``cfg`` or a default :class:`DepthEncoderConfig`
(HEVC Main 12 / ``gray12le``).
"""
imgs_dir = path.parent / f"imgs_{path.stem}"
if depth:
_write_depth_frames(imgs_dir, num_frames=num_frames)
cfg = cfg or DepthEncoderConfig()
else:
_write_color_frames(imgs_dir, num_frames=num_frames)
encode_video_frames(imgs_dir, path, fps=fps, video_encoder=cfg, overwrite=True)
return path
def _read_feature_info(dataset: LeRobotDataset, key: str = DUMMY_VIDEO_KEY) -> dict:
info = json.loads((dataset.root / INFO_PATH).read_text())
return info["features"][key]["info"]
# ─── VideoEncoderConfig / codec options ────────────────────────────── # ─── VideoEncoderConfig / codec options ──────────────────────────────
@@ -87,7 +150,7 @@ class TestCodecOptions:
assert opts["q:v"] == 40 assert opts["q:v"] == 40
assert "crf" not in opts assert "crf" not in opts
@_require_encoder("h264_nvenc") @require_nvenc
def test_nvenc_options(self): def test_nvenc_options(self):
cfg = VideoEncoderConfig(vcodec="h264_nvenc", g=2, crf=25, preset=None) cfg = VideoEncoderConfig(vcodec="h264_nvenc", g=2, crf=25, preset=None)
opts = cfg.get_codec_options() opts = cfg.get_codec_options()
@@ -96,12 +159,12 @@ class TestCodecOptions:
assert "crf" not in opts assert "crf" not in opts
assert opts["g"] == 2 assert opts["g"] == 2
@_require_encoder("h264_vaapi") @require_vaapi
def test_vaapi_options(self): def test_vaapi_options(self):
cfg = VideoEncoderConfig(vcodec="h264_vaapi", crf=28, preset=None) cfg = VideoEncoderConfig(vcodec="h264_vaapi", crf=28, preset=None)
assert cfg.get_codec_options()["qp"] == 28 assert cfg.get_codec_options()["qp"] == 28
@_require_encoder("h264_qsv") @require_qsv
def test_qsv_options(self): def test_qsv_options(self):
cfg = VideoEncoderConfig(vcodec="h264_qsv", crf=25, preset=None) cfg = VideoEncoderConfig(vcodec="h264_qsv", crf=25, preset=None)
assert cfg.get_codec_options()["global_quality"] == 25 assert cfg.get_codec_options()["global_quality"] == 25
@@ -313,59 +376,6 @@ class TestEncoderDetection:
assert "h264_nvenc" in VALID_VIDEO_CODECS assert "h264_nvenc" in VALID_VIDEO_CODECS
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "encoded_videos"
# Default video feature set used by persistence tests.
VIDEO_FEATURES = {
"observation.images.cam": {
"dtype": "video",
"shape": (64, 96, 3),
"names": ["height", "width", "channels"],
},
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
}
VIDEO_KEY = "observation.images.cam"
def _write_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
imgs_dir.mkdir(parents=True, exist_ok=True)
for i in range(num_frames):
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
write_image(arr, imgs_dir / f"frame-{i:06d}.png")
def _encode_video(
path: Path, num_frames: int = 4, fps: int = 30, cfg: VideoEncoderConfig | None = None
) -> Path:
imgs_dir = path.parent / f"imgs_{path.stem}"
_write_frames(imgs_dir, num_frames=num_frames)
encode_video_frames(imgs_dir, path, fps=fps, camera_encoder=cfg, overwrite=True)
return path
def _read_feature_info(dataset: LeRobotDataset) -> dict:
info = json.loads((dataset.root / INFO_PATH).read_text())
return info["features"][VIDEO_KEY]["info"]
def _add_frames(dataset: LeRobotDataset, num_frames: int, video_keys: list[str] | None = None) -> None:
from lerobot.utils.constants import DEFAULT_FEATURES
if video_keys is None:
video_keys = dataset.meta.video_keys
for _ in range(num_frames):
frame: dict = {"task": "test"}
for key, ft in dataset.meta.features.items():
if key in DEFAULT_FEATURES:
continue
shape = ft["shape"]
if key in video_keys:
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
else:
frame[key] = np.zeros(shape, dtype=np.float32)
dataset.add_frame(frame)
class TestGetVideoInfo: class TestGetVideoInfo:
def test_returns_all_stream_fields(self): def test_returns_all_stream_fields(self):
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4") info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4")
@@ -375,7 +385,7 @@ class TestGetVideoInfo:
assert info["video.pix_fmt"] == "yuv420p" assert info["video.pix_fmt"] == "yuv420p"
assert info["video.fps"] == 30 assert info["video.fps"] == 30
assert info["video.channels"] == 3 assert info["video.channels"] == 3
assert info["video.is_depth_map"] is False assert info["is_depth_map"] is False
assert info["has_audio"] is False assert info["has_audio"] is False
assert "video.g" not in info assert "video.g" not in info
assert "video.crf" not in info assert "video.crf" not in info
@@ -385,7 +395,7 @@ class TestGetVideoInfo:
def test_merges_encoder_config_as_video_prefixed_entries(self): def test_merges_encoder_config_as_video_prefixed_entries(self):
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12) cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg) info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
assert info["video.g"] == 2 assert info["video.g"] == 2
assert info["video.crf"] == 30 assert info["video.crf"] == 30
@@ -398,11 +408,16 @@ class TestGetVideoInfo:
def test_stream_derived_keys_take_precedence_over_config(self): def test_stream_derived_keys_take_precedence_over_config(self):
cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p") cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p")
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg) info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
assert info["video.codec"] # populated from stream, not from config's vcodec assert info["video.codec"] # populated from stream, not from config's vcodec
assert info["video.pix_fmt"] == "yuv420p" assert info["video.pix_fmt"] == "yuv420p"
def test_depth_encoder_config_sets_is_depth_map_true(self):
"""A ``DepthEncoderConfig`` causes ``get_video_info`` to mark the stream as depth."""
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=DepthEncoderConfig())
assert info["is_depth_map"] is True
class TestEncodeVideoFrames: class TestEncodeVideoFrames:
@require_libsvtav1 @require_libsvtav1
@@ -434,7 +449,7 @@ class TestEncodeVideoFrames:
def test_overwrite_false_skips_existing_file(self, tmp_path): def test_overwrite_false_skips_existing_file(self, tmp_path):
imgs_dir = tmp_path / "imgs" imgs_dir = tmp_path / "imgs"
_write_frames(imgs_dir) _write_color_frames(imgs_dir)
video_path = tmp_path / "out.mp4" video_path = tmp_path / "out.mp4"
sentinel = b"pre-existing content" sentinel = b"pre-existing content"
video_path.write_bytes(sentinel) video_path.write_bytes(sentinel)
@@ -446,7 +461,7 @@ class TestEncodeVideoFrames:
@require_libsvtav1 @require_libsvtav1
def test_overwrite_true_replaces_existing_file(self, tmp_path): def test_overwrite_true_replaces_existing_file(self, tmp_path):
imgs_dir = tmp_path / "imgs" imgs_dir = tmp_path / "imgs"
_write_frames(imgs_dir) _write_color_frames(imgs_dir)
video_path = tmp_path / "out.mp4" video_path = tmp_path / "out.mp4"
video_path.write_bytes(b"stale content") video_path.write_bytes(b"stale content")
@@ -461,7 +476,7 @@ class TestEncodeVideoFrames:
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=4, crf=25, preset=10) cfg = VideoEncoderConfig(vcodec="libsvtav1", g=4, crf=25, preset=10)
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg) video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg)
info = get_video_info(video_path, camera_encoder=cfg) info = get_video_info(video_path, video_encoder=cfg)
# Stream-derived # Stream-derived
assert info["video.height"] == 64 assert info["video.height"] == 64
@@ -470,7 +485,7 @@ class TestEncodeVideoFrames:
assert info["video.codec"] == "av1" assert info["video.codec"] == "av1"
assert info["video.pix_fmt"] == "yuv420p" assert info["video.pix_fmt"] == "yuv420p"
assert info["video.fps"] == 30 assert info["video.fps"] == 30
assert info["video.is_depth_map"] is False assert info["is_depth_map"] is False
assert info["has_audio"] is False assert info["has_audio"] is False
# Encoder config # Encoder config
assert info["video.g"] == 4 assert info["video.g"] == 4
@@ -488,14 +503,14 @@ class TestReencodeVideo:
src = TEST_ARTIFACTS_DIR / "clip_4frames.mp4" src = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
out = tmp_path / "reencoded.mp4" out = tmp_path / "reencoded.mp4"
cfg = VideoEncoderConfig(vcodec="h264", g=6, crf=23, pix_fmt="yuv444p") cfg = VideoEncoderConfig(vcodec="h264", g=6, crf=23, pix_fmt="yuv444p")
reencode_video(src, out, camera_encoder=cfg, overwrite=True) reencode_video(src, out, video_encoder=cfg, overwrite=True)
assert out.exists() assert out.exists()
with av.open(str(out)) as container: with av.open(str(out)) as container:
n_frames = sum(1 for _ in container.decode(video=0)) n_frames = sum(1 for _ in container.decode(video=0))
assert n_frames == 4 assert n_frames == 4
info = get_video_info(out, camera_encoder=cfg) info = get_video_info(out, video_encoder=cfg)
assert info["video.codec"] == "h264" assert info["video.codec"] == "h264"
assert info["video.pix_fmt"] == "yuv444p" assert info["video.pix_fmt"] == "yuv444p"
assert info["video.height"] == 64 assert info["video.height"] == 64
@@ -509,7 +524,7 @@ class TestReencodeVideo:
src = TEST_ARTIFACTS_DIR / "clip_6frames.mp4" src = TEST_ARTIFACTS_DIR / "clip_6frames.mp4"
out = tmp_path / "trim_window.mp4" out = tmp_path / "trim_window.mp4"
cfg = VideoEncoderConfig(vcodec="h264") cfg = VideoEncoderConfig(vcodec="h264")
reencode_video(src, out, camera_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True) reencode_video(src, out, video_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True)
with av.open(str(out)) as container: with av.open(str(out)) as container:
frames = list(container.decode(video=0)) frames = list(container.decode(video=0))
@@ -580,10 +595,10 @@ class TestEncoderConfigPersistence:
def test_first_episode_save_persists_encoder_config(self, tmp_path, empty_lerobot_dataset_factory): def test_first_episode_save_persists_encoder_config(self, tmp_path, empty_lerobot_dataset_factory):
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12) cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
dataset = empty_lerobot_dataset_factory( dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder=cfg root=tmp_path / "ds", features=DUMMY_VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
) )
_add_frames(dataset, num_frames=4) add_frames(dataset, num_frames=4)
dataset.save_episode() dataset.save_episode()
dataset.finalize() dataset.finalize()
@@ -603,14 +618,14 @@ class TestEncoderConfigPersistence:
def test_second_episode_does_not_overwrite_encoder_fields(self, tmp_path, empty_lerobot_dataset_factory): def test_second_episode_does_not_overwrite_encoder_fields(self, tmp_path, empty_lerobot_dataset_factory):
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12) cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
dataset = empty_lerobot_dataset_factory( dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder=cfg root=tmp_path / "ds", features=DUMMY_VIDEO_FEATURES, use_videos=True, camera_encoder=cfg
) )
_add_frames(dataset, num_frames=4) add_frames(dataset, num_frames=4)
dataset.save_episode() dataset.save_episode()
first_info = dict(_read_feature_info(dataset)) first_info = dict(_read_feature_info(dataset))
_add_frames(dataset, num_frames=4) add_frames(dataset, num_frames=4)
dataset.save_episode() dataset.save_episode()
dataset.finalize() dataset.finalize()
@@ -637,3 +652,217 @@ class TestFromVideoInfo:
# ``{}`` placeholder (typical after a merge with disagreeing sources) # ``{}`` placeholder (typical after a merge with disagreeing sources)
# must not leak into the reconstructed config. # must not leak into the reconstructed config.
assert cfg.extra_options == VideoEncoderConfig().extra_options assert cfg.extra_options == VideoEncoderConfig().extra_options
# ─── Depth-specific encoding tests ────────────────────────────────────
class TestEncodeDepthVideoFrames:
"""Depth mirror of :class:`TestEncodeVideoFrames`.
Exercises ``encode_video_frames`` end-to-end through
:class:`DepthEncoderConfig` (HEVC Main 12 / ``gray12le``) on synthetic
uint16 depth TIFFs.
"""
@require_hevc
def test_produces_readable_file(self, tmp_path):
video_path = _encode_video(tmp_path / "out.mp4", depth=True)
assert video_path.exists()
info = get_video_info(video_path, video_encoder=DepthEncoderConfig())
assert info["video.height"] == 64
assert info["video.width"] == 96
assert info["video.codec"] == "hevc"
assert info["video.pix_fmt"] == "gray12le"
assert info["video.channels"] == 1
assert info["is_depth_map"] is True
@require_hevc
def test_frame_count_and_duration_match_input(self, tmp_path):
num_frames = 10
fps = 30
video_path = _encode_video(tmp_path / "out.mp4", num_frames=num_frames, fps=fps, depth=True)
with av.open(str(video_path)) as container:
stream = container.streams.video[0]
actual_frames = sum(1 for _ in container.decode(stream))
duration = (
float(stream.duration * stream.time_base)
if stream.duration is not None
else float(container.duration / av.time_base)
)
assert actual_frames == num_frames
assert abs(duration - num_frames / fps) < 0.1
def test_overwrite_false_skips_existing_file(self, tmp_path):
"""Codec-agnostic: file-system semantics must hold even without an HEVC encoder."""
imgs_dir = tmp_path / "imgs"
_write_depth_frames(imgs_dir)
video_path = tmp_path / "out.mp4"
sentinel = b"pre-existing depth content"
video_path.write_bytes(sentinel)
encode_video_frames(imgs_dir, video_path, fps=30, video_encoder=DepthEncoderConfig(), overwrite=False)
assert video_path.read_bytes() == sentinel
@require_hevc
def test_overwrite_true_replaces_existing_file(self, tmp_path):
imgs_dir = tmp_path / "imgs"
_write_depth_frames(imgs_dir)
video_path = tmp_path / "out.mp4"
video_path.write_bytes(b"stale content")
encode_video_frames(imgs_dir, video_path, fps=30, video_encoder=DepthEncoderConfig(), overwrite=True)
info = get_video_info(video_path, video_encoder=DepthEncoderConfig())
assert info["video.height"] == 64
assert info["video.pix_fmt"] == "gray12le"
assert info["is_depth_map"] is True
@require_hevc
def test_custom_encoder_config_fields_stored_in_info(self, tmp_path):
"""All stream-derived and depth-encoder config fields are present after encoding."""
cfg = DepthEncoderConfig(
vcodec="hevc",
pix_fmt="gray12le",
g=4,
crf=25,
depth_min=0.05,
depth_max=8.0,
shift=2.5,
use_log=False,
)
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg, depth=True)
info = get_video_info(video_path, video_encoder=cfg)
# Stream-derived
assert info["video.height"] == 64
assert info["video.width"] == 96
assert info["video.channels"] == 1
assert info["video.codec"] == "hevc"
assert info["video.pix_fmt"] == "gray12le"
assert info["video.fps"] == 30
assert info["is_depth_map"] is True
assert info["has_audio"] is False
# Base encoder config
assert info["video.g"] == 4
assert info["video.crf"] == 25
assert info["video.fast_decode"] == 0
assert info["video.video_backend"] == "pyav"
assert info["video.extra_options"] == {}
# Depth-specific tuning
assert info["video.depth_min"] == 0.05
assert info["video.depth_max"] == 8.0
assert info["video.shift"] == 2.5
assert info["video.use_log"] is False
class TestDepthEncoderConfigPersistence:
"""Depth mirror of :class:`TestEncoderConfigPersistence`.
``DepthEncoderConfig`` must be stored as ``video.<field>`` entries
(including the depth-specific ``depth_min`` / ``depth_max`` / ``shift`` /
``use_log``) under ``info["features"][<depth_key>]["info"]`` when the
first episode is saved.
"""
@require_hevc
def test_first_episode_save_persists_depth_encoder_config(self, tmp_path, empty_lerobot_dataset_factory):
cfg = DepthEncoderConfig(
vcodec="hevc",
pix_fmt="gray12le",
g=2,
crf=30,
depth_min=0.05,
depth_max=8.0,
shift=2.5,
use_log=False,
)
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds", features=DUMMY_DEPTH_FEATURES, use_videos=True, depth_encoder=cfg
)
add_frames(dataset, num_frames=4)
dataset.save_episode()
dataset.finalize()
info = _read_feature_info(dataset, key=DUMMY_DEPTH_KEY)
# Stream-derived
assert info["video.height"] == 64
assert info["video.width"] == 96
assert info["video.fps"] == 30
assert info["video.codec"] == "hevc"
assert info["video.pix_fmt"] == "gray12le"
assert info["is_depth_map"] is True
# Base encoder config
assert info["video.g"] == 2
assert info["video.crf"] == 30
assert info["video.fast_decode"] == 0
assert info["video.video_backend"] == "pyav"
assert info["video.extra_options"] == {}
# Depth-specific tuning
assert info["video.depth_min"] == 0.05
assert info["video.depth_max"] == 8.0
assert info["video.shift"] == 2.5
assert info["video.use_log"] is False
@require_hevc
def test_second_episode_does_not_overwrite_depth_encoder_fields(
self, tmp_path, empty_lerobot_dataset_factory
):
cfg = DepthEncoderConfig(
vcodec="hevc",
pix_fmt="gray12le",
g=2,
crf=30,
depth_min=0.05,
depth_max=8.0,
shift=2.5,
use_log=False,
)
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds", features=DUMMY_DEPTH_FEATURES, use_videos=True, depth_encoder=cfg
)
add_frames(dataset, num_frames=4)
dataset.save_episode()
first_info = dict(_read_feature_info(dataset, key=DUMMY_DEPTH_KEY))
add_frames(dataset, num_frames=4)
dataset.save_episode()
dataset.finalize()
assert _read_feature_info(dataset, key=DUMMY_DEPTH_KEY) == first_info
class TestDepthFromVideoInfo:
"""``DepthEncoderConfig.from_video_info`` reconstructs a depth encoder
config from the ``video.*`` keys persisted in a dataset's ``info.json``.
Depth mirror of :class:`TestFromVideoInfo`.
"""
@require_hevc
def test_reconstructs_from_dummy_depth_video_info(self):
cfg = DepthEncoderConfig.from_video_info(DUMMY_DEPTH_VIDEO_INFO_FULL)
# No alias for ``"hevc"``; the canonical stream codec is reused as-is.
assert cfg.vcodec == "hevc"
assert cfg.pix_fmt == DUMMY_DEPTH_VIDEO_INFO_FULL["video.pix_fmt"]
assert cfg.g == DUMMY_DEPTH_VIDEO_INFO_FULL["video.g"]
assert cfg.crf == DUMMY_DEPTH_VIDEO_INFO_FULL["video.crf"]
assert cfg.fast_decode == DUMMY_DEPTH_VIDEO_INFO_FULL["video.fast_decode"]
assert cfg.video_backend == DUMMY_DEPTH_VIDEO_INFO_FULL["video.video_backend"]
# ``{}`` placeholder (typical after a merge with disagreeing sources)
# must not leak into the reconstructed config.
assert cfg.extra_options == DepthEncoderConfig().extra_options
# Depth-specific tuning round-trips through ``info.json``.
assert cfg.depth_min == DUMMY_DEPTH_VIDEO_INFO_FULL["video.depth_min"]
assert cfg.depth_max == DUMMY_DEPTH_VIDEO_INFO_FULL["video.depth_max"]
assert cfg.shift == DUMMY_DEPTH_VIDEO_INFO_FULL["video.shift"]
assert cfg.use_log == DUMMY_DEPTH_VIDEO_INFO_FULL["video.use_log"]
+45 -1
View File
@@ -39,12 +39,56 @@ DUMMY_VIDEO_INFO = {
"video.crf": 30, "video.crf": 30,
"video.preset": 12, "video.preset": 12,
"video.fast_decode": 0, "video.fast_decode": 0,
"video.is_depth_map": False, "is_depth_map": False,
"has_audio": False, "has_audio": False,
} }
DUMMY_CAMERA_FEATURES = { DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO}, "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO}, "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
} }
DUMMY_DEPTH_VIDEO_INFO = {
**DUMMY_VIDEO_INFO,
"is_depth_map": True,
}
DUMMY_DEPTH_VIDEO_INFO_FULL = {
**{k: v for k, v in DUMMY_VIDEO_INFO.items() if k != "video.preset"},
"video.codec": "hevc",
"video.pix_fmt": "gray12le",
"is_depth_map": True,
"video.depth_min": 0.05,
"video.depth_max": 8.0,
"video.shift": 2.5,
"video.use_log": True,
}
DUMMY_DEPTH_CAMERA_FEATURES = {
"laptop_depth": {
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
"info": DUMMY_DEPTH_VIDEO_INFO,
},
}
DUMMY_CAMERA_FEATURES_WITH_DEPTH = {**DUMMY_CAMERA_FEATURES, **DUMMY_DEPTH_CAMERA_FEATURES}
DUMMY_CHW = (3, 96, 128) DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3) DUMMY_HWC = (96, 128, 3)
# Default video feature set used by video-encoding persistence tests.
DUMMY_VIDEO_FEATURES = {
"observation.images.cam": {
"dtype": "video",
"shape": (64, 96, 3),
"names": ["height", "width", "channels"],
},
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
}
DUMMY_VIDEO_KEY = "observation.images.cam"
DUMMY_DEPTH_FEATURES = {
"observation.images.depth": {
"dtype": "video",
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
"info": {"is_depth_map": True},
},
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
}
DUMMY_DEPTH_KEY = "observation.images.depth"
+38
View File
@@ -49,6 +49,39 @@ from tests.fixtures.constants import (
) )
def add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
"""Append ``num_frames`` synthetic frames to ``dataset``.
Generates per-feature payloads from ``dataset.meta``: uint16 depth ramps for
keys in ``dataset.meta.depth_keys``, uint8 random noise for video/image keys,
and float32 zeros for everything else. ``DEFAULT_FEATURES`` (timestamp,
frame_index, ...) are auto-populated by ``add_frame`` and skipped here.
"""
video_keys = dataset.meta.video_keys
depth_keys = dataset.meta.depth_keys
# Smooth gradient base reused per (H, W) to keep depth frames cheap to
# encode (HEVC Main 12 hates white noise).
_depth_base_cache: dict[tuple[int, int], np.ndarray] = {}
for i in range(num_frames):
frame: dict = {"task": "test"}
for key, ft in dataset.meta.features.items():
if key in DEFAULT_FEATURES:
continue
shape = ft["shape"]
if key in depth_keys:
h, w, _ = shape
base = _depth_base_cache.setdefault(
(h, w),
np.linspace(100.0, 10_000.0, h * w, dtype=np.float32).reshape(h, w, 1),
)
frame[key] = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
elif key in video_keys:
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
else:
frame[key] = np.zeros(shape, dtype=np.float32)
dataset.add_frame(frame)
class LeRobotDatasetFactory(Protocol): class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ... def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
@@ -485,10 +518,14 @@ def lerobot_dataset_factory(
hf_dataset: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None,
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB, data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
chunks_size: int = DEFAULT_CHUNK_SIZE, chunks_size: int = DEFAULT_CHUNK_SIZE,
camera_features: dict | None = None,
**kwargs, **kwargs,
) -> LeRobotDataset: ) -> LeRobotDataset:
# Instantiate objects # Instantiate objects
if info is None: if info is None:
info_kwargs = {}
if camera_features is not None:
info_kwargs["camera_features"] = camera_features
info = info_factory( info = info_factory(
total_episodes=total_episodes, total_episodes=total_episodes,
total_frames=total_frames, total_frames=total_frames,
@@ -496,6 +533,7 @@ def lerobot_dataset_factory(
use_videos=use_videos, use_videos=use_videos,
data_files_size_in_mb=data_files_size_in_mb, data_files_size_in_mb=data_files_size_in_mb,
chunks_size=chunks_size, chunks_size=chunks_size,
**info_kwargs,
) )
if stats is None: if stats is None:
stats = stats_factory(features=info.features) stats = stats_factory(features=info.features)
+21 -3
View File
@@ -2370,14 +2370,32 @@ def test_aggregate_images_when_use_videos_false():
out = aggregate_pipeline_dataset_features( out = aggregate_pipeline_dataset_features(
pipeline=rp, pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=False, # expect "image" dtype use_videos=False, # images kept, stored as "image" dtype
patterns=None, patterns=None,
) )
key = f"{OBS_IMAGES}.back" key = f"{OBS_IMAGES}.back"
key_front = f"{OBS_IMAGES}.front" key_front = f"{OBS_IMAGES}.front"
assert key not in out assert key in out
assert key_front not in out assert key_front in out
assert out[key]["dtype"] == "image"
assert out[key_front]["dtype"] == "image"
assert out[key]["shape"] == initial["back"]
def test_aggregate_images_excluded():
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
initial = {"back": (480, 640, 3)}
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
exclude_images=True,
patterns=None,
)
assert f"{OBS_IMAGES}.back" not in out
assert f"{OBS_IMAGES}.front" not in out
def test_aggregate_images_when_use_videos_true(): def test_aggregate_images_when_use_videos_true():
@@ -27,6 +27,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
MergeConfig, MergeConfig,
ModifyTasksConfig, ModifyTasksConfig,
OperationConfig, OperationConfig,
ReencodeVideosConfig,
RemoveFeatureConfig, RemoveFeatureConfig,
SplitConfig, SplitConfig,
_validate_config, _validate_config,
@@ -103,3 +104,47 @@ class TestOperationTypeParsing:
) )
resolved_name = OperationConfig.get_choice_name(type(cfg.operation)) resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
assert resolved_name == type_name assert resolved_name == type_name
class TestDepthEncoderParsing:
"""Test that the depth encoder is exposed and parsed for video operations."""
def test_reencode_has_default_depth_encoder(self):
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", "reencode_videos"])
assert isinstance(cfg.operation, ReencodeVideosConfig)
# A depth encoder is configured by default so depth videos are re-encoded too.
assert cfg.operation.depth_encoder is not None
assert hasattr(cfg.operation.depth_encoder, "depth_min")
def test_reencode_parses_depth_encoder_overrides(self):
cfg = parse_cfg(
[
"--repo_id",
"test/repo",
"--operation.type",
"reencode_videos",
"--operation.depth_encoder.vcodec",
"ffv1",
"--operation.depth_encoder.depth_max",
"12.0",
"--operation.depth_encoder.use_log",
"false",
]
)
assert cfg.operation.depth_encoder.vcodec == "ffv1"
assert cfg.operation.depth_encoder.depth_max == 12.0
assert cfg.operation.depth_encoder.use_log is False
def test_convert_image_to_video_parses_depth_encoder_overrides(self):
cfg = parse_cfg(
[
"--repo_id",
"test/repo",
"--operation.type",
"convert_image_to_video",
"--operation.depth_encoder.depth_min",
"0.05",
]
)
assert isinstance(cfg.operation, ConvertImageToVideoConfig)
assert cfg.operation.depth_encoder.depth_min == 0.05
+3 -3
View File
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
from lerobot.teleoperators.rebot_102_leader import ( from lerobot.teleoperators.rebot_102_leader import (
RebotArm102Leader, RebotArm102Leader,
RebotArm102LeaderConfig, RebotArm102LeaderConfig,
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
def test_bimanual_prefixes_features(): def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None): with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebotArm102LeaderConfig( cfg = BiRebot102LeaderConfig(
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"), left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"), right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
) )
teleop = BiRebotArm102Leader(cfg) teleop = BiRebot102Leader(cfg)
assert any(k.startswith("left_") for k in teleop.action_features) assert any(k.startswith("left_") for k in teleop.action_features)
assert any(k.startswith("right_") for k in teleop.action_features) assert any(k.startswith("right_") for k in teleop.action_features)
assert "left_gripper.pos" in teleop.action_features assert "left_gripper.pos" in teleop.action_features
+8 -1
View File
@@ -43,6 +43,11 @@ def mock_rerun(monkeypatch):
def __init__(self, arr): def __init__(self, arr):
self.arr = arr self.arr = arr
class DummyDepthImage:
def __init__(self, arr, colormap=None):
self.arr = arr
self.colormap = colormap
def dummy_log(key, obj=None, **kwargs): def dummy_log(key, obj=None, **kwargs):
# Accept either positional `obj` or keyword `entity` and record remaining kwargs. # Accept either positional `obj` or keyword `entity` and record remaining kwargs.
if obj is None and "entity" in kwargs: if obj is None and "entity" in kwargs:
@@ -55,6 +60,8 @@ def mock_rerun(monkeypatch):
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None), __spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
Scalars=DummyScalar, Scalars=DummyScalar,
Image=DummyImage, Image=DummyImage,
DepthImage=DummyDepthImage,
components=SimpleNamespace(Colormap=SimpleNamespace(Viridis="viridis")),
log=dummy_log, log=dummy_log,
init=lambda *a, **k: None, init=lambda *a, **k: None,
spawn=lambda *a, **k: None, spawn=lambda *a, **k: None,
@@ -225,7 +232,7 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
assert temp.value == pytest.approx(10.0) assert temp.value == pytest.approx(10.0)
img = _obj_for(calls, "observation.gray") img = _obj_for(calls, "observation.gray")
assert type(img).__name__ == "DummyImage" assert type(img).__name__ == "DummyDepthImage" # single-channel -> DepthImage
assert img.arr.shape == (8, 8, 1) # remains HWC assert img.arr.shape == (8, 8, 1) # remains HWC
assert _kwargs_for(calls, "observation.gray").get("static", False) is True assert _kwargs_for(calls, "observation.gray").get("static", False) is True
Generated
+949 -900
View File
File diff suppressed because it is too large Load Diff