mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-13 23:59:43 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8e6952150f | |||
| 017ff73fbf | |||
| f90db58c15 | |||
| e64fa667c3 | |||
| d9ec3a6fa2 | |||
| d90e4bcfd3 | |||
| 9d3b62aa61 |
@@ -19,6 +19,8 @@
|
||||
title: Multi GPU training
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
|
||||
@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
|
||||
@@ -204,22 +204,26 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name
|
||||
|
||||
Your dataset includes:
|
||||
|
||||
**Your Actions (2 things)**:
|
||||
**Your Actions (2 features)**:
|
||||
|
||||
- How much you moved forward/backward
|
||||
- How much you turned left/right
|
||||
- `linear_velocity`: How much you moved forward/backward
|
||||
- `angular_velocity`: How much you turned left/right
|
||||
|
||||
**Robot Observations (12 things)**:
|
||||
**Robot Observations (24 features)**:
|
||||
|
||||
- Front camera video
|
||||
- Rear camera video
|
||||
- Current speed
|
||||
- Battery level
|
||||
- Which way the robot is facing
|
||||
- GPS location (latitude, longitude, signal strength)
|
||||
- Orientation
|
||||
- GPS (latitude, longitude, signal strength)
|
||||
- Network signal strength
|
||||
- Vibration level
|
||||
- Lamp status (on/off)
|
||||
- Lamp state (on/off)
|
||||
- Accelerometer (x, y, z)
|
||||
- Gyroscope (x, y, z)
|
||||
- Magnetometer (x, y, z)
|
||||
- Wheel RPMs (4 wheels)
|
||||
|
||||
### Where Your Data Goes
|
||||
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
# Rename Map and Empty Cameras
|
||||
|
||||
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
|
||||
|
||||
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
|
||||
|
||||
## Why observation keys don't always match
|
||||
|
||||
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
|
||||
|
||||
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
|
||||
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
|
||||
|
||||
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
|
||||
|
||||
## Using the rename map
|
||||
|
||||
Pass the mapping as a JSON string on the command line. The convention is always:
|
||||
|
||||
```
|
||||
--rename_map='{"source_key": "policy_key", ...}'
|
||||
```
|
||||
|
||||
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
|
||||
|
||||
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
|
||||
|
||||
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
|
||||
|
||||
### Training
|
||||
|
||||
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/xvla_training \
|
||||
--job_name=xvla_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.action_mode=auto \
|
||||
--steps=20000 \
|
||||
--policy.device=cuda \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true \
|
||||
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0fast-libero \
|
||||
--env.type=libero \
|
||||
... \
|
||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||
```
|
||||
|
||||
### Recording
|
||||
|
||||
`lerobot-record` also supports rename maps, nested under the dataset config:
|
||||
|
||||
```bash
|
||||
lerobot-record \ # When running inference
|
||||
--policy.path="<user>/smolVLA_finetuned" \
|
||||
... \
|
||||
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
||||
```
|
||||
|
||||
## Alternative: edit the policy config directly
|
||||
|
||||
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||
|
||||
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
|
||||
|
||||
## Empty cameras: fewer views than the policy expects
|
||||
|
||||
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
|
||||
|
||||
### How it works
|
||||
|
||||
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
|
||||
|
||||
```
|
||||
observation.images.empty_camera_0
|
||||
observation.images.empty_camera_1
|
||||
...
|
||||
```
|
||||
|
||||
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
|
||||
|
||||
### Example
|
||||
|
||||
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
|
||||
|
||||
1. Set `--policy.empty_cameras=1`.
|
||||
2. The config adds a third key: `observation.images.empty_camera_0`.
|
||||
3. Use the rename map for your two real cameras as usual.
|
||||
4. The third slot is masked out — no fake images needed in your dataset.
|
||||
|
||||
## Quick reference
|
||||
|
||||
| Goal | What to do |
|
||||
| ----------------------------------------- | --------------------------------------------------------------------------- |
|
||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
|
||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||
@@ -32,7 +32,8 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
|
||||
@@ -22,7 +22,8 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||
|
||||
DROID_SHARDS = 2048
|
||||
|
||||
@@ -26,7 +26,7 @@ from huggingface_hub import HfApi
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -155,7 +155,7 @@ class UploadDataset(PipelineStep):
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
from huggingface_hub import CommitOperationAdd, preupload_lfs_files
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
@@ -113,8 +113,9 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
|
||||
@@ -82,7 +82,7 @@ from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
|
||||
@@ -19,8 +19,9 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -20,9 +20,9 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -5,8 +5,9 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
|
||||
@@ -5,8 +5,9 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
|
||||
@@ -6,8 +6,8 @@ from queue import Empty, Full
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
|
||||
@@ -0,0 +1,956 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
SONIC planner with full mode control.
|
||||
|
||||
Keyboard controls:
|
||||
N / P - next / previous motion set
|
||||
1-8 - select mode within current set
|
||||
WASD - movement direction
|
||||
Q / E - rotate facing left / right
|
||||
9 / 0 - decrease / increase speed
|
||||
- / = - decrease / increase height
|
||||
R - force replan
|
||||
Space - emergency stop -> IDLE
|
||||
Esc - quit
|
||||
|
||||
Gamepad controls (Unitree wireless controller):
|
||||
Left stick Y - speed (forward = fast, back = stop)
|
||||
Left stick X - movement direction (offset from facing)
|
||||
Right stick X - facing direction (incremental rotation)
|
||||
Right stick Y - height (up = tall 0.8m, down = low 0.1m)
|
||||
Buttons - unused (mode selection is keyboard-only)
|
||||
"""
|
||||
|
||||
import argparse, gc, math, select, sys, termios, tty
|
||||
import multiprocessing as mp
|
||||
import threading, time
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
|
||||
# ── Constants ────────────────────────────────────────────────────────────────
|
||||
|
||||
DEFAULT_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
|
||||
], dtype=np.float32)
|
||||
|
||||
NATURAL_FREQ = 10.0 * 2.0 * np.pi
|
||||
ARMATURE = {"5020": 0.003609725, "7520_14": 0.010177520, "7520_22": 0.025101925, "4010": 0.00425}
|
||||
EFFORT = {"5020": 25.0, "7520_14": 88.0, "7520_22": 139.0, "4010": 5.0}
|
||||
|
||||
def _action_scale(k):
|
||||
return 0.25 * EFFORT[k] / (ARMATURE[k] * NATURAL_FREQ**2)
|
||||
|
||||
_J = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
|
||||
["7520_14","5020","5020"] + \
|
||||
["5020","5020","5020","5020","5020","4010","4010"] * 2
|
||||
ACTION_SCALE = np.array([_action_scale(k) for k in _J], dtype=np.float32)
|
||||
|
||||
CONTROL_DT = 0.02
|
||||
DEFAULT_HEIGHT = 0.788740
|
||||
TOKEN_DIM = 64
|
||||
ENCODER_UPDATE_EVERY = 5
|
||||
DEBUG_PRINT_EVERY = 100
|
||||
MOTION_LOOK_AHEAD_STEPS = 2
|
||||
INITIAL_RANDOM_SEED = 1234
|
||||
MIN_TOKENS, MAX_TOKENS = 6, 16
|
||||
K = MAX_TOKENS - MIN_TOKENS + 1
|
||||
DEADZONE = 0.05
|
||||
BLEND_FRAMES = 8
|
||||
|
||||
REPLAN_INTERVAL = {
|
||||
"running": 0.1, "crawling": 0.2, "boxing": 1.0, "default": 1.0
|
||||
}
|
||||
|
||||
ISAACLAB_TO_MUJOCO = np.array([
|
||||
0, 3, 6, 9, 13, 17, 1, 4, 7, 10, 14, 18, 2, 5, 8,
|
||||
11, 15, 19, 21, 23, 25, 27, 12, 16, 20, 22, 24, 26, 28
|
||||
], dtype=np.int32)
|
||||
|
||||
MUJOCO_TO_ISAACLAB = np.array([
|
||||
0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 22, 4, 10,
|
||||
16, 23, 5, 11, 17, 24, 18, 25, 19, 26, 20, 27, 21, 28
|
||||
], dtype=np.int32)
|
||||
|
||||
def _to_mujoco(a): return a[MUJOCO_TO_ISAACLAB]
|
||||
def _to_runtime(a): r = np.zeros(29, np.float32); r[MUJOCO_TO_ISAACLAB] = a; return r
|
||||
|
||||
DEFAULT_ANGLES_MUJOCO = _to_mujoco(DEFAULT_ANGLES)
|
||||
ENCODER_STANDING_REF = DEFAULT_ANGLES.copy()
|
||||
|
||||
LOWER_BODY_IL = np.array([0,3,6,9,13,17,1,4,7,10,14,18], dtype=np.int32)
|
||||
WRIST_IL = np.array([23,24,25,26,27,28], dtype=np.int32)
|
||||
VR_TARGET_DEF = np.zeros(9, dtype=np.float32)
|
||||
VR_ORN_DEF = np.array([1,0,0,0,1,0,0,0,1,0,0,0], dtype=np.float32)
|
||||
SMPL_DEF = np.zeros(720, dtype=np.float32)
|
||||
|
||||
# ── PD gains ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def _kp_kd():
|
||||
s = lambda k: ARMATURE[k] * NATURAL_FREQ**2
|
||||
d = lambda k: 2.0 * 2.0 * ARMATURE[k] * NATURAL_FREQ
|
||||
_kp_keys = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
|
||||
["7520_14","5020","5020"] + \
|
||||
["5020","5020","5020","5020","5020","4010","4010"] * 2
|
||||
_kd_keys = _kp_keys
|
||||
_double = {4,5,10,11,13,14} # ankle + waist indices with factor 2
|
||||
kp = np.array([2*s(k) if i in _double else s(k) for i,k in enumerate(_kp_keys)], dtype=np.float32)
|
||||
kd = np.array([2*d(k) if i in _double else d(k) for i,k in enumerate(_kd_keys)], dtype=np.float32)
|
||||
return kp, kd
|
||||
|
||||
# ── Quaternion helpers ────────────────────────────────────────────────────────
|
||||
|
||||
def quat_conj(q):
|
||||
return np.array([q[0], -q[1], -q[2], -q[3]], dtype=np.float32)
|
||||
|
||||
def quat_mul(q1, q2):
|
||||
w1,x1,y1,z1 = q1; w2,x2,y2,z2 = q2
|
||||
return np.array([
|
||||
w1*w2 - x1*x2 - y1*y2 - z1*z2,
|
||||
w1*x2 + x1*w2 + y1*z2 - z1*y2,
|
||||
w1*y2 - x1*z2 + y1*w2 + z1*x2,
|
||||
w1*z2 + x1*y2 - y1*x2 + z1*w2,
|
||||
], dtype=np.float32)
|
||||
|
||||
def gravity_dir(q):
|
||||
q = q / (np.linalg.norm(q) + 1e-8)
|
||||
qv = np.array([0, 0, 0, -1], dtype=np.float32)
|
||||
return quat_mul(quat_mul(quat_conj(q), qv), q)[1:]
|
||||
|
||||
def quat_to_6d(q):
|
||||
w,x,y,z = q
|
||||
return np.array([
|
||||
1-2*(y*y+z*z), 2*(x*y-z*w),
|
||||
2*(x*y+z*w), 1-2*(x*x+z*z),
|
||||
2*(x*z-y*w), 2*(y*z+x*w),
|
||||
], dtype=np.float32)
|
||||
|
||||
def calc_heading(q):
|
||||
w,x,y,z = q
|
||||
return float(np.arctan2(2*(x*y + w*z), 1-2*(y*y+z*z)))
|
||||
|
||||
def heading_quat(q, sign=1.0):
|
||||
a = sign * calc_heading(q) / 2.0
|
||||
return np.array([np.cos(a), 0, 0, np.sin(a)], dtype=np.float64)
|
||||
|
||||
heading_quat_inv = lambda q: heading_quat(q, -1.0)
|
||||
|
||||
def quat_slerp(q0, q1, t):
|
||||
q0 = q0 / (np.linalg.norm(q0)+1e-12); q1 = q1 / (np.linalg.norm(q1)+1e-12)
|
||||
dot = float(np.dot(q0, q1))
|
||||
if dot < 0: q1, dot = -q1, -dot
|
||||
dot = min(dot, 1.0)
|
||||
if dot > 0.9995:
|
||||
r = q0 + t*(q1-q0); return r/(np.linalg.norm(r)+1e-12)
|
||||
th = np.arccos(dot); st = np.sin(th)
|
||||
return (np.sin((1-t)*th)/st)*q0 + (np.sin(t*th)/st)*q1
|
||||
|
||||
def quat_slerp_batch(q0, q1, t):
|
||||
q0 = q0 / (np.linalg.norm(q0,axis=1,keepdims=True)+1e-12)
|
||||
q1 = q1 / (np.linalg.norm(q1,axis=1,keepdims=True)+1e-12)
|
||||
dot = np.sum(q0*q1, axis=1); neg = dot<0
|
||||
q1=q1.copy(); q1[neg]=-q1[neg]; dot[neg]=-dot[neg]; dot=np.clip(dot,-1,1)
|
||||
lin = dot>0.9995; th=np.arccos(dot); st=np.where(np.sin(th)==0,1,np.sin(th))
|
||||
c0=np.sin((1-t)*th)/st; c1=np.sin(t*th)/st
|
||||
c0[lin]=1-t[lin]; c1[lin]=t[lin]
|
||||
r = c0[:,None]*q0 + c1[:,None]*q1
|
||||
return r / (np.linalg.norm(r,axis=1,keepdims=True)+1e-12)
|
||||
|
||||
# ── Locomotion modes ──────────────────────────────────────────────────────────
|
||||
|
||||
class LocomotionMode(IntEnum):
|
||||
IDLE=0; SLOW_WALK=1; WALK=2; RUN=3; SQUAT=4; KNEEL_TWO_LEGS=5; KNEEL=6
|
||||
LYING_FACE_DOWN=7; CRAWLING=8; IDLE_BOXING=9; WALK_BOXING=10
|
||||
LEFT_PUNCH=11; RIGHT_PUNCH=12; RANDOM_PUNCH=13; ELBOW_CRAWLING=14
|
||||
LEFT_HOOK=15; RIGHT_HOOK=16; FORWARD_JUMP=17; STEALTH_WALK=18
|
||||
INJURED_WALK=19; LEDGE_WALKING=20; OBJECT_CARRYING=21; STEALTH_WALK_2=22
|
||||
HAPPY_DANCE_WALK=23; ZOMBIE_WALK=24; GUN_WALK=25; SCARE_WALK=26
|
||||
|
||||
LM = LocomotionMode
|
||||
|
||||
MOTION_SETS = [
|
||||
("Standing", [LM.SLOW_WALK, LM.WALK, LM.RUN, LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK]),
|
||||
("Squat / Low", [LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.CRAWLING, LM.ELBOW_CRAWLING]),
|
||||
("Boxing", [LM.IDLE_BOXING, LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
|
||||
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK]),
|
||||
("Styled Walks", [LM.LEDGE_WALKING, LM.OBJECT_CARRYING, LM.STEALTH_WALK_2,
|
||||
LM.HAPPY_DANCE_WALK, LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK]),
|
||||
]
|
||||
|
||||
STATIC_MODES = {LM.IDLE, LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.LYING_FACE_DOWN, LM.IDLE_BOXING}
|
||||
STANDING_MODES = {LM.IDLE, LM.SLOW_WALK, LM.WALK, LM.RUN, LM.IDLE_BOXING, LM.WALK_BOXING,
|
||||
LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK,
|
||||
LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK, LM.LEDGE_WALKING,
|
||||
LM.OBJECT_CARRYING, LM.STEALTH_WALK_2, LM.HAPPY_DANCE_WALK,
|
||||
LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK}
|
||||
BOXING_MODES = {LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
|
||||
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}
|
||||
SPEED_RANGES = {LM.SLOW_WALK:(0.2,0.8), LM.WALK:(0.8,1.5), LM.RUN:(1.5,3.0),
|
||||
LM.CRAWLING:(0.4,1.0), LM.ELBOW_CRAWLING:(0.7,1.0)}
|
||||
|
||||
def clamp_mode_params(ms):
|
||||
m = LM(ms.mode)
|
||||
ms.height = -1.0 if m in STANDING_MODES else max(0.1, min(0.8, ms.height if ms.height>=0 else 0.2))
|
||||
if m in STATIC_MODES:
|
||||
ms.speed = -1.0
|
||||
elif m in SPEED_RANGES:
|
||||
lo, hi = SPEED_RANGES[m]
|
||||
ms.speed = max(lo, min(hi, ms.speed if ms.speed>=0 else lo))
|
||||
elif m in BOXING_MODES:
|
||||
ms.speed = max(0.7, min(1.5, ms.speed if ms.speed>=0 else 0.7))
|
||||
else:
|
||||
ms.speed = -1.0
|
||||
|
||||
def replan_interval(mode):
|
||||
m = LM(mode)
|
||||
if m == LM.RUN: return REPLAN_INTERVAL["running"]
|
||||
if m == LM.CRAWLING: return REPLAN_INTERVAL["crawling"]
|
||||
if m in {LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}:
|
||||
return REPLAN_INTERVAL["boxing"]
|
||||
return REPLAN_INTERVAL["default"]
|
||||
|
||||
# ── Movement state ────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class MovementState:
|
||||
mode: int = 0
|
||||
speed: float = -1.0
|
||||
height: float = -1.0
|
||||
facing_angle: float = 0.0
|
||||
movement_angle: float = 0.0
|
||||
has_movement: bool = False
|
||||
motion_set_idx: int = 0
|
||||
needs_replan: bool = False
|
||||
|
||||
@property
|
||||
def movement_direction(self):
|
||||
if not self.has_movement: return (0.0, 0.0, 0.0)
|
||||
return (math.cos(self.movement_angle), math.sin(self.movement_angle), 0.0)
|
||||
|
||||
@property
|
||||
def facing_direction(self):
|
||||
return (math.cos(self.facing_angle), math.sin(self.facing_angle), 0.0)
|
||||
|
||||
def status_line(self):
|
||||
return (f"[{MOTION_SETS[self.motion_set_idx][0]}] mode={self.mode}({LM(self.mode).name}) "
|
||||
f"spd={'default' if self.speed<0 else f'{self.speed:.1f}'} "
|
||||
f"hgt={'default' if self.height<0 else f'{self.height:.2f}'} "
|
||||
f"facing={math.degrees(self.facing_angle):.0f}° "
|
||||
f"{'moving' if self.has_movement else 'still'}")
|
||||
|
||||
# ── Encoder / Decoder ─────────────────────────────────────────────────────────
|
||||
|
||||
class StandingEncoderDecoder:
|
||||
def __init__(self, encoder, decoder):
|
||||
self.encoder, self.decoder = encoder, decoder
|
||||
self.encoder_input = encoder.get_inputs()[0].name
|
||||
self.decoder_input = decoder.get_inputs()[0].name
|
||||
enc_dim = int(encoder.get_inputs()[0].shape[1])
|
||||
dec_dim = int(decoder.get_inputs()[0].shape[1])
|
||||
if enc_dim != 1762 or dec_dim != 994:
|
||||
raise RuntimeError(f"Unexpected dims encoder={enc_dim}, decoder={dec_dim}")
|
||||
self.token = np.zeros(TOKEN_DIM, np.float32)
|
||||
self.last_action_mj = np.zeros(29, np.float32)
|
||||
self.h_q_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_dq_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_ang = [np.zeros(3, np.float32)] * 10
|
||||
self.h_act_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_quat = [np.array([1,0,0,0], np.float32)] * 10
|
||||
self.init_base_quat = np.array([1,0,0,0], np.float32)
|
||||
self.init_ref_quat = np.array([1,0,0,0], np.float32)
|
||||
self._heading_init = False
|
||||
self.encode_mode = 0
|
||||
self.vr_3point_local_target = VR_TARGET_DEF.copy()
|
||||
self.vr_3point_local_orn_target = VR_ORN_DEF.copy()
|
||||
self.smpl_joints_10frame_step1 = SMPL_DEF.copy()
|
||||
self.set_zero_reference()
|
||||
|
||||
def update_history(self, q, dq, ang, quat):
|
||||
quat = quat / (np.linalg.norm(quat)+1e-8)
|
||||
q_mj = _to_mujoco(q); dq_mj = _to_mujoco(dq)
|
||||
self.h_q_mj = [q_mj - DEFAULT_ANGLES_MUJOCO] + self.h_q_mj[:-1]
|
||||
self.h_dq_mj = [dq_mj] + self.h_dq_mj[:-1]
|
||||
self.h_ang = [ang.copy()] + self.h_ang[:-1]
|
||||
self.h_act_mj = [self.last_action_mj.copy()] + self.h_act_mj[:-1]
|
||||
self.h_quat = [quat.copy()] + self.h_quat[:-1]
|
||||
if not self._heading_init:
|
||||
self.init_base_quat = quat.copy(); self._heading_init = True
|
||||
|
||||
def _heading_quat(self, q):
|
||||
h = calc_heading(q) / 2.0
|
||||
return np.array([np.cos(h), 0, 0, np.sin(h)], np.float32)
|
||||
|
||||
def _heading_quat_inv(self, q):
|
||||
h = calc_heading(q) / 2.0
|
||||
return np.array([np.cos(-h), 0, 0, np.sin(-h)], np.float32)
|
||||
|
||||
def _anchor_6d(self, base_quat, ref_quat=None):
|
||||
if ref_quat is None: ref_quat = self.init_ref_quat
|
||||
delta = quat_mul(self._heading_quat(self.init_base_quat), self._heading_quat_inv(self.init_ref_quat))
|
||||
new_ref = quat_mul(delta, ref_quat)
|
||||
return quat_to_6d(quat_mul(quat_conj(base_quat), new_ref))
|
||||
|
||||
def set_zero_reference(self):
|
||||
self.motion_joint_positions = [ENCODER_STANDING_REF.copy()]
|
||||
self.motion_joint_velocities = [np.zeros(29, np.float32)]
|
||||
self.motion_body_quats = [np.array([1,0,0,0], np.float32)]
|
||||
self.motion_body_z = [DEFAULT_HEIGHT]
|
||||
self.motion_timesteps = 1
|
||||
self.freeze_ref_frame = 0
|
||||
self.init_ref_quat = self.motion_body_quats[0].copy()
|
||||
|
||||
def build_encoder_obs(self):
|
||||
obs = np.zeros(1762, np.float32)
|
||||
obs[0] = float(self.encode_mode)
|
||||
rf = min(self.freeze_ref_frame, self.motion_timesteps - 1)
|
||||
ref_pos, ref_quat = self.motion_joint_positions[rf], self.motion_body_quats[rf]
|
||||
if self.encode_mode == 0:
|
||||
for f in range(10):
|
||||
obs[4+29*f:4+29*(f+1)] = ref_pos
|
||||
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
elif self.encode_mode == 1:
|
||||
ref_lower = ref_pos[LOWER_BODY_IL]
|
||||
for f in range(10):
|
||||
obs[661+12*f:661+12*(f+1)] = ref_lower
|
||||
obs[901:910] = self.vr_3point_local_target
|
||||
obs[910:922] = self.vr_3point_local_orn_target
|
||||
obs[595:601] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
elif self.encode_mode == 2:
|
||||
obs[922:1642] = self.smpl_joints_10frame_step1
|
||||
for f in range(10):
|
||||
obs[1642+6*f:1642+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
obs[1702+6*f:1702+6*(f+1)] = ref_pos[WRIST_IL]
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported encoder mode: {self.encode_mode}")
|
||||
return obs
|
||||
|
||||
def build_decoder_obs(self):
|
||||
obs = np.zeros(994, np.float32); off = 0
|
||||
obs[off:off+64] = self.token; off += 64
|
||||
for h, sz in [(list(reversed(self.h_ang)),3), (list(reversed(self.h_q_mj)),29),
|
||||
(list(reversed(self.h_dq_mj)),29), (list(reversed(self.h_act_mj)),29)]:
|
||||
for f in range(10): obs[off:off+sz] = h[f]; off += sz
|
||||
for q in reversed(self.h_quat):
|
||||
obs[off:off+3] = gravity_dir(q); off += 3
|
||||
assert off == 994, f"Decoder obs mismatch: {off}"
|
||||
return obs
|
||||
|
||||
def run_encoder(self):
|
||||
return self.encoder.run(None, {self.encoder_input: self.build_encoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
|
||||
|
||||
def step(self, robot_obs, update_encoder, debug=False):
|
||||
jnames = [m.name for m in G1_29_JointIndex]
|
||||
q = np.array([robot_obs.get(f"{n}.q", DEFAULT_ANGLES[m.value]) for m,n in zip(G1_29_JointIndex,jnames)], np.float32)
|
||||
dq = np.array([robot_obs.get(f"{n}.dq", 0.0) for n in jnames], np.float32)
|
||||
quat = np.array([robot_obs.get("imu.quat.w",1), robot_obs.get("imu.quat.x",0),
|
||||
robot_obs.get("imu.quat.y",0), robot_obs.get("imu.quat.z",0)], np.float32)
|
||||
ang = np.array([robot_obs.get(f"imu.gyro.{a}",0) for a in "xyz"], np.float32)
|
||||
self.update_history(q, dq, ang, quat)
|
||||
if update_encoder: self.token = self.run_encoder()
|
||||
action_mj = self.decoder.run(None, {self.decoder_input: self.build_decoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
|
||||
self.last_action_mj = action_mj.copy()
|
||||
target = DEFAULT_ANGLES + action_mj[ISAACLAB_TO_MUJOCO] * ACTION_SCALE
|
||||
if debug:
|
||||
delta = target - q
|
||||
print(f"token_norm={np.linalg.norm(self.token):.4f} action_norm={np.linalg.norm(action_mj):.4f} "
|
||||
f"delta_max={np.max(np.abs(delta)):.4f} delta_rms={np.sqrt(np.mean(delta**2)):.4f}")
|
||||
return {f"{m.name}.q": float(target[m.value]) for m in G1_29_JointIndex}
|
||||
|
||||
def print_input_diagnostics(self):
|
||||
print("\n[Diag] Standing reference checks")
|
||||
names = {0:"g1", 1:"teleop", 2:"smpl"}
|
||||
print(f" encoder mode: {self.encode_mode} ({names.get(self.encode_mode,'unknown')})")
|
||||
print(f" DEFAULT_ANGLES range: [{DEFAULT_ANGLES.min():+.4f}, {DEFAULT_ANGLES.max():+.4f}]")
|
||||
print(f" anchor_6d(identity): {self._anchor_6d(np.array([1,0,0,0],np.float32), np.array([1,0,0,0],np.float32))}")
|
||||
print(f" gravity(identity): {gravity_dir(np.array([1,0,0,0],np.float32))} (expect [0,0,-1])")
|
||||
dec0 = self.build_decoder_obs()
|
||||
print(f" decoder q-delta max: {np.max(np.abs(dec0[94:384])):.6f}")
|
||||
print(f" decoder dq max: {np.max(np.abs(dec0[384:674])):.6f}")
|
||||
|
||||
# ── Planner motion buffer ─────────────────────────────────────────────────────
|
||||
|
||||
class PlannerMotion:
|
||||
def __init__(self, max_frames=1500):
|
||||
self.timesteps = 0
|
||||
self.joint_positions = np.zeros((max_frames, 29), np.float64)
|
||||
self.joint_velocities = np.zeros((max_frames, 29), np.float64)
|
||||
self.body_positions = np.zeros((max_frames, 3), np.float64)
|
||||
self.body_quaternions = np.zeros((max_frames, 4), np.float64)
|
||||
self.body_quaternions[:, 0] = 1.0
|
||||
|
||||
# ── Subprocess planner ────────────────────────────────────────────────────────
|
||||
|
||||
def _resample_30_to_50(qpos, n30):
|
||||
t50 = int(np.floor(n30 / 30.0 * 50))
|
||||
f30 = np.arange(t50) / 50.0 * 30.0
|
||||
f0 = np.floor(f30).astype(int)
|
||||
f1 = np.minimum(f0+1, n30-1)
|
||||
frac, w0 = (f30-f0).astype(np.float64), None
|
||||
w0 = 1.0 - frac
|
||||
jp = (w0[:,None]*qpos[f0,7:36] + frac[:,None]*qpos[f1,7:36])[:,MUJOCO_TO_ISAACLAB]
|
||||
jv = np.zeros_like(jp)
|
||||
if t50 >= 2: jv[:t50-1] = (jp[1:] - jp[:-1]) * 50.0; jv[-1] = jv[-2]
|
||||
return {
|
||||
"timesteps": t50,
|
||||
"joint_positions": jp,
|
||||
"joint_velocities": jv,
|
||||
"body_positions": w0[:,None]*qpos[f0,:3] + frac[:,None]*qpos[f1,:3],
|
||||
"body_quaternions": quat_slerp_batch(qpos[f0,3:7], qpos[f1,3:7], frac),
|
||||
}
|
||||
|
||||
def _build_planner_inputs(ctx, ms_dict, version, seed):
|
||||
inp = {
|
||||
"context_mujoco_qpos": ctx.astype(np.float32).reshape(1,4,36),
|
||||
"target_vel": np.array([ms_dict["speed"]], np.float32),
|
||||
"mode": np.array([ms_dict["mode"]], np.int64),
|
||||
"movement_direction": np.array(ms_dict["movement_direction"], np.float32).reshape(1,3),
|
||||
"facing_direction": np.array(ms_dict["facing_direction"], np.float32).reshape(1,3),
|
||||
"random_seed": np.array([seed], np.int64),
|
||||
}
|
||||
if version >= 1:
|
||||
allowed = np.zeros((1,K), np.int64); allowed[0,:6] = 1
|
||||
inp.update({
|
||||
"height": np.array([ms_dict["height"]], np.float32),
|
||||
"has_specific_target": np.array([[0]], np.int64),
|
||||
"specific_target_positions": np.zeros((1,4,3), np.float32),
|
||||
"specific_target_headings": np.zeros((1,4), np.float32),
|
||||
"allowed_pred_num_tokens": allowed,
|
||||
})
|
||||
return inp
|
||||
|
||||
def _planner_worker(path, req_q, res_q, stop_evt, version, seed):
|
||||
so = ort.SessionOptions(); so.log_severity_level = 3
|
||||
sess = ort.InferenceSession(path, sess_options=so, providers=["CPUExecutionProvider"])
|
||||
while not stop_evt.is_set():
|
||||
try: ctx, gf, ms_dict = req_q.get(timeout=0.05)
|
||||
except Exception: continue
|
||||
try:
|
||||
inp = _build_planner_inputs(ctx, ms_dict, version, seed)
|
||||
t0 = time.time()
|
||||
qpos_out, num_pred = sess.run(None, inp)
|
||||
t_inf = time.time()
|
||||
n = int(num_pred.flat[0])
|
||||
qpos = qpos_out[0,:n]
|
||||
if np.any(np.isnan(qpos)): continue
|
||||
motion = _resample_30_to_50(qpos, n)
|
||||
motion["gen_frame"] = gf
|
||||
print(f"[Planner] inf={1000*(t_inf-t0):.1f}ms total={1000*(time.time()-t0):.1f}ms frames={n}", flush=True)
|
||||
while not res_q.empty():
|
||||
try: res_q.get_nowait()
|
||||
except Exception: break
|
||||
res_q.put(motion)
|
||||
except Exception as e:
|
||||
print(f"[Planner] Error: {e}", flush=True)
|
||||
|
||||
# ── SonicPlanner ──────────────────────────────────────────────────────────────
|
||||
|
||||
class SonicPlanner:
|
||||
def __init__(self, session, planner_path):
|
||||
self.session = session
|
||||
self.planner_path = planner_path
|
||||
self.gen_frame = 0
|
||||
self.random_seed = INITIAL_RANDOM_SEED
|
||||
self.version = 1 if len(session.get_inputs()) >= 11 else 0
|
||||
self.motion_50hz = PlannerMotion()
|
||||
self._snapshot = PlannerMotion()
|
||||
self._req_q = self._res_q = self._stop_evt = self._proc = None
|
||||
self._ctrl = None
|
||||
|
||||
def _build_inputs(self, ctx, ms):
|
||||
return _build_planner_inputs(
|
||||
ctx,
|
||||
{"mode": ms.mode, "speed": ms.speed, "height": ms.height,
|
||||
"movement_direction": list(ms.movement_direction),
|
||||
"facing_direction": list(ms.facing_direction)},
|
||||
self.version, self.random_seed)
|
||||
|
||||
@staticmethod
|
||||
def build_initial_context(joint_positions):
|
||||
ctx = np.zeros((4,36), np.float32)
|
||||
for n in range(4):
|
||||
ctx[n,2] = DEFAULT_HEIGHT; ctx[n,3] = 1.0
|
||||
ctx[n,7:36] = joint_positions.astype(np.float32)
|
||||
return ctx
|
||||
|
||||
def _context_from_controller(self, current_frame):
|
||||
ctrl = self._ctrl
|
||||
gen_frame = current_frame + MOTION_LOOK_AHEAD_STEPS
|
||||
t_arr = gen_frame/50.0 + np.arange(4)/30.0
|
||||
f50 = t_arr * 50.0
|
||||
with ctrl.motion_lock:
|
||||
ts = ctrl.motion_timesteps
|
||||
bp = ctrl.motion_body_pos[:ts].copy()
|
||||
bq = ctrl.motion_body_quats[:ts].copy()
|
||||
jp = ctrl.motion_joint_positions[:ts].copy()
|
||||
f0 = np.minimum(np.floor(f50).astype(int), ts-1)
|
||||
f1 = np.minimum(f0+1, ts-1)
|
||||
frac, w0 = f50-f0, None; w0 = 1.0-frac
|
||||
ctx = np.zeros((4,36), np.float32)
|
||||
ctx[:,0:3] = w0[:,None]*bp[f0] + frac[:,None]*bp[f1]
|
||||
ctx[:,3:7] = quat_slerp_batch(bq[f0], bq[f1], frac)
|
||||
ij = w0[:,None]*jp[f0] + frac[:,None]*jp[f1]
|
||||
ctx[:,7:36] = ij[:,ISAACLAB_TO_MUJOCO]
|
||||
self.gen_frame = gen_frame
|
||||
return ctx
|
||||
|
||||
def _load_motion_in_place(self, qpos, n30, target=None):
|
||||
if target is None: target = self.motion_50hz
|
||||
r = _resample_30_to_50(qpos, n30)
|
||||
n = r["timesteps"]; target.timesteps = n
|
||||
target.joint_positions[:n] = r["joint_positions"]
|
||||
target.joint_velocities[:n] = r["joint_velocities"]
|
||||
target.body_positions[:n] = r["body_positions"]
|
||||
target.body_quaternions[:n] = r["body_quaternions"]
|
||||
return target
|
||||
|
||||
def initialize(self, joint_positions, ms):
|
||||
ctx = self.build_initial_context(joint_positions)
|
||||
qpos_out, num_pred = self.session.run(None, self._build_inputs(ctx, ms))
|
||||
n = int(num_pred.flat[0]); qpos = qpos_out[0,:n]
|
||||
if np.any(np.isnan(qpos)): raise RuntimeError("Planner initial output contains NaN")
|
||||
print(f"[Planner] Init: {n} frames @ 30 Hz")
|
||||
self._load_motion_in_place(qpos, n)
|
||||
print(f"[Planner] Resampled to {self.motion_50hz.timesteps} frames @ 50 Hz")
|
||||
return self.motion_50hz
|
||||
|
||||
def request_replan(self, cursor, ms):
|
||||
if self._req_q is None: return
|
||||
ctx = self._context_from_controller(cursor)
|
||||
ms_dict = {"mode": ms.mode, "speed": ms.speed, "height": ms.height,
|
||||
"movement_direction": list(ms.movement_direction),
|
||||
"facing_direction": list(ms.facing_direction)}
|
||||
while not self._req_q.empty():
|
||||
try: self._req_q.get_nowait()
|
||||
except Exception: break
|
||||
self._req_q.put((ctx, self.gen_frame, ms_dict))
|
||||
|
||||
def try_get_new_motion(self):
|
||||
if self._res_q is None: return None
|
||||
result = None
|
||||
while not self._res_q.empty():
|
||||
try: result = self._res_q.get_nowait()
|
||||
except Exception: break
|
||||
if result is None: return None
|
||||
n, gf = result["timesteps"], result["gen_frame"]
|
||||
s = self._snapshot; s.timesteps = n
|
||||
s.joint_positions[:n] = result["joint_positions"]
|
||||
s.joint_velocities[:n] = result["joint_velocities"]
|
||||
s.body_positions[:n] = result["body_positions"]
|
||||
s.body_quaternions[:n] = result["body_quaternions"]
|
||||
return s, gf
|
||||
|
||||
def start_subprocess(self, controller):
|
||||
self._ctrl = controller
|
||||
self._req_q, self._res_q, self._stop_evt = mp.Queue(), mp.Queue(), mp.Event()
|
||||
self._proc = mp.Process(
|
||||
target=_planner_worker,
|
||||
args=(self.planner_path, self._req_q, self._res_q,
|
||||
self._stop_evt, self.version, self.random_seed),
|
||||
daemon=True)
|
||||
self._proc.start()
|
||||
print(f"[Planner] Background process started (PID={self._proc.pid})")
|
||||
|
||||
def stop_subprocess(self):
|
||||
if self._stop_evt: self._stop_evt.set()
|
||||
if self._proc:
|
||||
self._proc.join(timeout=3.0)
|
||||
if self._proc.is_alive(): self._proc.terminate()
|
||||
print("[Planner] Background process stopped")
|
||||
for q in (self._req_q, self._res_q):
|
||||
if q: q.close()
|
||||
|
||||
# ── PlannerController ─────────────────────────────────────────────────────────
|
||||
|
||||
class PlannerController(StandingEncoderDecoder):
|
||||
def __init__(self, planner, encoder, decoder):
|
||||
super().__init__(encoder, decoder)
|
||||
self.planner = planner
|
||||
self.ref_cursor = 0
|
||||
self.motion_timesteps = 0
|
||||
self.motion_joint_positions = np.zeros((1500,29), np.float64)
|
||||
self.motion_joint_velocities = np.zeros((1500,29), np.float64)
|
||||
self.motion_body_quats = np.zeros((1500,4), np.float64); self.motion_body_quats[:,0] = 1.0
|
||||
self.motion_body_pos = np.zeros((1500,3), np.float64)
|
||||
self.init_ref_quat = np.array([1,0,0,0], np.float64)
|
||||
self.heading_init_base_quat = np.array([1,0,0,0], np.float64)
|
||||
self.delta_heading = 0.0
|
||||
self.reinit_heading = False
|
||||
self.playing = self.first_motion = False
|
||||
self.motion_lock = threading.Lock()
|
||||
|
||||
def load_initial_motion(self, motion):
|
||||
with self.motion_lock:
|
||||
n = motion.timesteps
|
||||
self.motion_timesteps = n
|
||||
self.motion_joint_positions[:n] = motion.joint_positions[:n]
|
||||
self.motion_joint_velocities[:n] = motion.joint_velocities[:n]
|
||||
self.motion_body_quats[:n] = motion.body_quaternions[:n]
|
||||
self.motion_body_pos[:n] = motion.body_positions[:n]
|
||||
self.init_ref_quat = motion.body_quaternions[0].copy()
|
||||
self.ref_cursor = 0; self.first_motion = True
|
||||
self.playing = True; self.delta_heading = 0.0
|
||||
|
||||
def blend_new_motion(self, new_motion, gen_frame):
|
||||
with self.motion_lock:
|
||||
cur = self.ref_cursor
|
||||
new_len = gen_frame - cur + new_motion.timesteps
|
||||
if new_len <= 0: return
|
||||
f_arr = np.arange(new_len)
|
||||
f_old = np.minimum(f_arr + cur, self.motion_timesteps - 1)
|
||||
f_new = np.clip(f_arr + cur - gen_frame, 0, new_motion.timesteps - 1)
|
||||
blend_start = max(0, gen_frame - cur)
|
||||
w_new = np.clip((f_arr - blend_start) / BLEND_FRAMES if BLEND_FRAMES > 0
|
||||
else np.ones(new_len), 0.0, 1.0)
|
||||
w_old = 1.0 - w_new
|
||||
self.motion_joint_positions[:new_len] = w_old[:,None]*self.motion_joint_positions[f_old] + w_new[:,None]*new_motion.joint_positions[f_new]
|
||||
self.motion_joint_velocities[:new_len] = w_old[:,None]*self.motion_joint_velocities[f_old] + w_new[:,None]*new_motion.joint_velocities[f_new]
|
||||
self.motion_body_pos[:new_len] = w_old[:,None]*self.motion_body_pos[f_old] + w_new[:,None]*new_motion.body_positions[f_new]
|
||||
self.motion_body_quats[:new_len] = quat_slerp_batch(self.motion_body_quats[f_old], new_motion.body_quaternions[f_new], w_new)
|
||||
self.motion_timesteps = new_len; self.first_motion = False; self.ref_cursor = 0
|
||||
self.init_ref_quat = self.motion_body_quats[0].copy()
|
||||
|
||||
def _heading_apply_delta(self):
|
||||
delta = quat_mul(heading_quat(self.heading_init_base_quat).astype(np.float32),
|
||||
heading_quat_inv(self.init_ref_quat).astype(np.float32))
|
||||
if self.delta_heading:
|
||||
h = self.delta_heading / 2.0
|
||||
delta = quat_mul(np.array([np.cos(h),0,0,np.sin(h)], np.float32), delta)
|
||||
return delta
|
||||
|
||||
def _anchor_6d(self, base_quat, ref_quat=None):
|
||||
if ref_quat is None: ref_quat = self.init_ref_quat
|
||||
new_ref = quat_mul(self._heading_apply_delta(), ref_quat.astype(np.float32))
|
||||
return quat_to_6d(quat_mul(quat_conj(base_quat.astype(np.float32)), new_ref))
|
||||
|
||||
def build_encoder_obs(self):
|
||||
obs = np.zeros(1762, np.float32); obs[0] = float(self.encode_mode)
|
||||
with self.motion_lock:
|
||||
for f in range(10):
|
||||
tf = min(self.ref_cursor + f*5 if self.playing else self.ref_cursor,
|
||||
self.motion_timesteps - 1)
|
||||
obs[4+29*f:4+29*(f+1)] = self.motion_joint_positions[tf].astype(np.float32)
|
||||
if self.playing:
|
||||
obs[294+29*f:294+29*(f+1)] = self.motion_joint_velocities[tf].astype(np.float32)
|
||||
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(
|
||||
self.h_quat[0], self.motion_body_quats[tf].astype(np.float32))
|
||||
return obs
|
||||
|
||||
def step(self, robot_obs, update_encoder, debug=False):
|
||||
if robot_obs and (self.first_motion or self.reinit_heading):
|
||||
q = robot_obs.get("imu.quaternion")
|
||||
if q is not None:
|
||||
self.heading_init_base_quat = np.array(q, np.float64)
|
||||
with self.motion_lock:
|
||||
rf = min(self.ref_cursor, self.motion_timesteps - 1)
|
||||
self.init_ref_quat = self.motion_body_quats[rf].copy()
|
||||
self.delta_heading = 0.0
|
||||
self.first_motion = False
|
||||
self.reinit_heading = False
|
||||
print(f"[Heading] init quat: {self.heading_init_base_quat}")
|
||||
return super().step(robot_obs, update_encoder=update_encoder, debug=debug)
|
||||
|
||||
def advance_cursor(self, wall_dt):
|
||||
if not self.playing: return
|
||||
frames = max(1, round(wall_dt / CONTROL_DT))
|
||||
with self.motion_lock:
|
||||
self.ref_cursor = min(self.ref_cursor + frames, self.motion_timesteps - 1)
|
||||
|
||||
# ── Keyboard ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class RawKeyboard:
|
||||
def __init__(self):
|
||||
self.fd = sys.stdin.fileno()
|
||||
self.old = termios.tcgetattr(self.fd)
|
||||
def __enter__(self): tty.setcbreak(self.fd); return self
|
||||
def __exit__(self, *_): termios.tcsetattr(self.fd, termios.TCSADRAIN, self.old)
|
||||
def get_key(self):
|
||||
return sys.stdin.read(1) if select.select([sys.stdin],[],[],0)[0] else None
|
||||
|
||||
def process_keyboard(key, ms, controller=None):
|
||||
if key is None: return False
|
||||
if key == '\x1b': return True
|
||||
if key == ' ':
|
||||
ms.mode = LM.IDLE; ms.speed = ms.height = -1.0
|
||||
ms.has_movement = False; ms.needs_replan = True
|
||||
if controller: controller.playing = False; controller.reinit_heading = True
|
||||
print("\n >> EMERGENCY STOP -> IDLE"); return False
|
||||
if key in ('r','R'):
|
||||
ms.needs_replan = True; print("\n >> Manual replan"); return False
|
||||
if key in ('n','N','p','P'):
|
||||
ms.motion_set_idx = (ms.motion_set_idx + (1 if key in ('n','N') else -1)) % len(MOTION_SETS)
|
||||
name, modes = MOTION_SETS[ms.motion_set_idx]
|
||||
print(f"\n >> Motion set: {name}")
|
||||
[print(f" {i+1}: {m.name}") for i,m in enumerate(modes)]
|
||||
return False
|
||||
if key.isdigit() and key not in ('9','0'):
|
||||
idx = int(key) - 1; modes = MOTION_SETS[ms.motion_set_idx][1]
|
||||
if 0 <= idx < len(modes):
|
||||
ms.mode = modes[idx]; ms.needs_replan = True
|
||||
if controller: controller.playing = True; controller.reinit_heading = True
|
||||
print(f"\n >> Mode: {LM(ms.mode).name} ({ms.mode}) [replanning...]")
|
||||
return False
|
||||
if key == '9':
|
||||
ms.speed = max(0.0, (ms.speed if ms.speed>=0 else 1.0) - 0.1)
|
||||
print(f"\n >> Speed: {ms.speed:.1f}"); return False
|
||||
if key == '0':
|
||||
ms.speed = min(5.0, (ms.speed if ms.speed>=0 else 1.0) + 0.1)
|
||||
print(f"\n >> Speed: {ms.speed:.1f}"); return False
|
||||
if key == '-':
|
||||
ms.height = max(0.2, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) - 0.02)
|
||||
print(f"\n >> Height: {ms.height:.2f}"); return False
|
||||
if key == '=':
|
||||
ms.height = min(1.0, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) + 0.02)
|
||||
print(f"\n >> Height: {ms.height:.2f}"); return False
|
||||
if key.lower() == 'w': ms.movement_angle = ms.facing_angle
|
||||
elif key.lower() == 's': ms.movement_angle = ms.facing_angle + math.pi
|
||||
elif key.lower() == 'a': ms.movement_angle = ms.facing_angle + math.pi/2
|
||||
elif key.lower() == 'd': ms.movement_angle = ms.facing_angle - math.pi/2
|
||||
if key.lower() in ('w','s','a','d'):
|
||||
ms.has_movement = ms.needs_replan = True
|
||||
elif key.lower() == 'q':
|
||||
ms.facing_angle += 0.1
|
||||
if controller: controller.delta_heading += 0.1
|
||||
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
|
||||
elif key.lower() == 'e':
|
||||
ms.facing_angle -= 0.1
|
||||
if controller: controller.delta_heading -= 0.1
|
||||
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
|
||||
return False
|
||||
|
||||
_joy_prev_active = False
|
||||
|
||||
|
||||
def _parse_wireless(wr):
|
||||
"""Parse wireless_remote (bytes or int-array) into (lx, ly, rx, ry)."""
|
||||
import struct as _st
|
||||
if not isinstance(wr, (bytes, bytearray)):
|
||||
wr = bytes(wr)
|
||||
if len(wr) < 24:
|
||||
return None
|
||||
lx = _st.unpack("f", wr[4:8])[0]
|
||||
rx = _st.unpack("f", wr[8:12])[0]
|
||||
ry = _st.unpack("f", wr[12:16])[0]
|
||||
ly = _st.unpack("f", wr[20:24])[0]
|
||||
return lx, ly, rx, ry
|
||||
|
||||
|
||||
def process_joystick(obs, ms, controller=None):
|
||||
"""Joystick mirrors keyboard: left stick=WASD, right stick X=Q/E, right stick Y=height."""
|
||||
global _joy_prev_active
|
||||
wr = obs.get("wireless_remote")
|
||||
if wr is None:
|
||||
return
|
||||
parsed = _parse_wireless(wr)
|
||||
if parsed is None:
|
||||
return
|
||||
lx, ly, rx, ry = parsed
|
||||
|
||||
# Dead zone + negate both Y axes (bridge already flips them once)
|
||||
lx = 0.0 if abs(lx) < DEADZONE else lx
|
||||
ly = 0.0 if abs(ly) < DEADZONE else -ly
|
||||
rx = 0.0 if abs(rx) < DEADZONE else rx
|
||||
ry = 0.0 if abs(ry) < DEADZONE else -ry
|
||||
|
||||
left_active = abs(lx) > 0 or abs(ly) > 0
|
||||
|
||||
# Left stick → WASD (movement direction relative to facing)
|
||||
if left_active:
|
||||
ms.movement_angle = ms.facing_angle + math.atan2(-lx, -ly)
|
||||
ms.has_movement = True
|
||||
if not _joy_prev_active:
|
||||
ms.needs_replan = True
|
||||
_joy_prev_active = True
|
||||
elif _joy_prev_active and not (abs(rx) > 0 or abs(ry) > 0):
|
||||
_joy_prev_active = False
|
||||
ms.has_movement = False
|
||||
|
||||
# Right stick X → Q/E (facing rotation, ~1 rad/s at full deflection)
|
||||
if abs(rx) > 0:
|
||||
delta = -0.02 * rx
|
||||
ms.facing_angle += delta
|
||||
if controller:
|
||||
controller.delta_heading += delta
|
||||
|
||||
# Right stick Y → -/= (height adjustment, ~0.25/s at full deflection)
|
||||
if abs(ry) > 0:
|
||||
step = -0.005 * ry
|
||||
ms.height = max(0.1, min(1.0, (ms.height if ms.height >= 0 else DEFAULT_HEIGHT) + step))
|
||||
|
||||
# ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="SONIC planner with keyboard + gamepad control")
|
||||
parser.add_argument("--ip", type=str, default=None,
|
||||
help="Robot IP for real hardware (e.g. 192.168.123.164). "
|
||||
"Omit for simulation.")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 60)
|
||||
print("SONIC planner - full mode control")
|
||||
print(" N/P cycle sets | 1-8 select mode | WASD move")
|
||||
print(" Q/E rotate | 9/0 speed | -/= height")
|
||||
print(" R replan | Space IDLE | Esc quit")
|
||||
if args.ip:
|
||||
print(f" Robot IP: {args.ip}")
|
||||
else:
|
||||
print(" Mode: simulation")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
planner_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="planner_sonic.onnx")
|
||||
encoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_encoder.onnx")
|
||||
decoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_decoder.onnx")
|
||||
|
||||
providers = ort.get_available_providers()
|
||||
use_gpu = "CUDAExecutionProvider" in providers
|
||||
gpu_ep = (["CUDAExecutionProvider","CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"])
|
||||
so = ort.SessionOptions(); so.log_severity_level = 3
|
||||
|
||||
print(f"[ONNX] enc/dec={'GPU' if use_gpu else 'CPU'}, planner=CPU")
|
||||
planner_sess = ort.InferenceSession(planner_path, sess_options=so, providers=["CPUExecutionProvider"])
|
||||
encoder_sess = ort.InferenceSession(encoder_path, sess_options=so, providers=gpu_ep)
|
||||
decoder_sess = ort.InferenceSession(decoder_path, sess_options=so, providers=gpu_ep)
|
||||
print(f"[Planner] version={'v1+' if len(planner_sess.get_inputs())>=11 else 'v0'}")
|
||||
|
||||
cfg = UnitreeG1Config()
|
||||
if args.ip:
|
||||
cfg.is_simulation = False
|
||||
cfg.robot_ip = args.ip
|
||||
robot = UnitreeG1(cfg); robot.connect()
|
||||
kp, kd = _kp_kd(); robot.kp = kp.copy(); robot.kd = kd.copy()
|
||||
|
||||
ms = MovementState()
|
||||
planner = SonicPlanner(planner_sess, planner_path)
|
||||
controller = PlannerController(planner, encoder_sess, decoder_sess)
|
||||
|
||||
motion = planner.initialize(DEFAULT_ANGLES, ms)
|
||||
controller.load_initial_motion(motion)
|
||||
controller.print_input_diagnostics()
|
||||
planner.start_subprocess(controller)
|
||||
|
||||
print(f"\nStarting: {MOTION_SETS[0][0]}")
|
||||
[print(f" {i+1}: {m.name}") for i,m in enumerate(MOTION_SETS[0][1])]
|
||||
|
||||
with RawKeyboard() as kb:
|
||||
try:
|
||||
gc.disable(); gc_timer = 0.0
|
||||
robot.reset(CONTROL_DT, DEFAULT_ANGLES); time.sleep(1.0)
|
||||
|
||||
step = 0; last_status = replan_timer = 0.0
|
||||
loop_t = enc_t = dec_t = obs_t = act_t = []
|
||||
slow_n = blend_n = 0; stall_src = ""; did_blend = False
|
||||
prev_end = time.time(); t_start = time.time()
|
||||
|
||||
log_path = "/tmp/sonic_pose_log.csv"
|
||||
jnames = [m.name for m in G1_29_JointIndex]
|
||||
with open(log_path, "w") as log_f:
|
||||
log_f.write("t,step,cursor,ts,blend,mode," +
|
||||
",".join(f"q{i}" for i in range(29)) + "," +
|
||||
",".join(f"ref{i}" for i in range(29)) + "," +
|
||||
",".join(f"act{i}" for i in range(29)) +
|
||||
",delta_max,action_norm,token_norm\n")
|
||||
|
||||
while not robot._shutdown_event.is_set():
|
||||
t0 = time.time()
|
||||
if process_keyboard(kb.get_key(), ms, controller): break
|
||||
|
||||
obs = robot.get_observation(); t_obs = time.time()
|
||||
obs_t.append(1000*(t_obs - t0))
|
||||
if not obs:
|
||||
step += 1; prev_end = time.time()
|
||||
time.sleep(max(0.0, CONTROL_DT-(time.time()-t0))); continue
|
||||
|
||||
process_joystick(obs, ms, controller)
|
||||
clamp_mode_params(ms)
|
||||
|
||||
is_static = LM(ms.mode) in STATIC_MODES
|
||||
do_req = ms.needs_replan and step > 0
|
||||
if do_req: ms.needs_replan = False; replan_timer = 0.0
|
||||
elif not is_static and step > 0 and ms.speed != 0:
|
||||
replan_timer += CONTROL_DT
|
||||
if replan_timer >= replan_interval(ms.mode):
|
||||
do_req = True; replan_timer = 0.0
|
||||
if do_req: planner.request_replan(controller.ref_cursor, ms)
|
||||
|
||||
do_enc = (step % ENCODER_UPDATE_EVERY == 0)
|
||||
t_step = time.time()
|
||||
action = controller.step(obs, update_encoder=do_enc, debug=(step % DEBUG_PRINT_EVERY == 0))
|
||||
step_ms = 1000*(time.time()-t_step)
|
||||
(enc_t if do_enc else dec_t).append(step_ms)
|
||||
|
||||
t_act = time.time()
|
||||
robot.send_action(action)
|
||||
act_t.append(1000*(time.time()-t_act))
|
||||
|
||||
result = planner.try_get_new_motion()
|
||||
t_blend = time.time()
|
||||
if result:
|
||||
controller.blend_new_motion(*result)
|
||||
blend_ms = 1000*(time.time()-t_blend)
|
||||
blend_n += 1; did_blend = True
|
||||
else:
|
||||
blend_ms = 0.0
|
||||
|
||||
if step % 5 == 0:
|
||||
t_rel = time.time() - t_start
|
||||
q_r = np.array([obs.get(f"{n}.q", 0) for n in jnames])
|
||||
a_v = np.array([action.get(f"{n}.q", 0) for n in jnames])
|
||||
cur, ts = controller.ref_cursor, controller.motion_timesteps
|
||||
q_ref = controller.motion_joint_positions[min(cur,ts-1)] if ts > 0 else np.zeros(29)
|
||||
log_f.write(f"{t_rel:.4f},{step},{cur},{ts},{int(did_blend)},{ms.mode}," +
|
||||
",".join(f"{v:.6f}" for v in q_r) + "," +
|
||||
",".join(f"{v:.6f}" for v in q_ref) + "," +
|
||||
",".join(f"{v:.6f}" for v in a_v) + "," +
|
||||
f"{np.max(np.abs(a_v-q_r)):.6f},"
|
||||
f"{np.linalg.norm(a_v):.6f},"
|
||||
f"{np.linalg.norm(controller.token):.6f}\n")
|
||||
did_blend = False
|
||||
|
||||
now = time.time(); loop_ms = 1000*(now-t0)
|
||||
wall_dt = now - prev_end; loop_t.append(loop_ms)
|
||||
if loop_ms > 50:
|
||||
stall_src = (f"[STALL] {loop_ms:.0f}ms: "
|
||||
f"obs={obs_t[-1]:.0f} blend={blend_ms:.0f} step={step_ms:.0f} act={act_t[-1]:.0f}")
|
||||
if loop_ms > CONTROL_DT*1500: slow_n += 1
|
||||
|
||||
controller.advance_cursor(wall_dt)
|
||||
|
||||
if now - last_status > 2.0:
|
||||
def _avg(l): return sum(l)/len(l) if l else 0
|
||||
hz = 1000/_avg(loop_t) if _avg(loop_t) else 0
|
||||
print(f"\r {ms.status_line()} step={step} ref={controller.ref_cursor}/{controller.motion_timesteps} "
|
||||
f"loop={_avg(loop_t):.1f}ms(max={max(loop_t,default=0):.1f}) hz={hz:.0f} "
|
||||
f"enc={_avg(enc_t):.1f} dec={_avg(dec_t):.1f} obs={_avg(obs_t):.1f} "
|
||||
f"slow={slow_n} blends={blend_n}", end="", flush=True)
|
||||
if stall_src: print(f"\n {stall_src}"); stall_src = ""
|
||||
last_status = now
|
||||
loop_t=enc_t=dec_t=obs_t=act_t=[]; slow_n=blend_n=0
|
||||
|
||||
prev_end = time.time()
|
||||
gc_timer += CONTROL_DT
|
||||
if gc_timer >= 10.0: gc.collect(); gc_timer = 0.0
|
||||
step += 1
|
||||
time.sleep(max(0.0, CONTROL_DT-(time.time()-t0)))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
gc.enable()
|
||||
print(f"\n[Log] Saved to {log_path}")
|
||||
planner.stop_subprocess()
|
||||
print("\nStopping...")
|
||||
if robot.is_connected: robot.disconnect()
|
||||
print("Done.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -23,7 +23,7 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
|
||||
@@ -36,6 +36,16 @@ class DatasetConfig:
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
if any(ep < 0 for ep in self.episodes):
|
||||
raise ValueError(
|
||||
f"Episode indices must be non-negative, got: {[ep for ep in self.episodes if ep < 0]}"
|
||||
)
|
||||
if len(self.episodes) != len(set(self.episodes)):
|
||||
duplicates = sorted({ep for ep in self.episodes if self.episodes.count(ep) > 1})
|
||||
raise ValueError(f"Episode indices contain duplicates: {duplicates}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WandBConfig:
|
||||
|
||||
@@ -746,7 +746,8 @@ def save_annotations_to_dataset(
|
||||
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
|
||||
):
|
||||
"""Save annotations to LeRobot dataset parquet format."""
|
||||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
|
||||
from lerobot.datasets.io_utils import load_episodes
|
||||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
|
||||
|
||||
episodes_dataset = load_episodes(dataset_path)
|
||||
if not episodes_dataset or len(episodes_dataset) == 0:
|
||||
@@ -840,7 +841,7 @@ def generate_auto_sparse_annotations(
|
||||
|
||||
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
|
||||
"""Load annotations from LeRobot dataset parquet files."""
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.datasets.io_utils import load_episodes
|
||||
|
||||
episodes_dataset = load_episodes(dataset_path)
|
||||
if not episodes_dataset or len(episodes_dataset) == 0:
|
||||
|
||||
@@ -24,7 +24,16 @@ import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -32,14 +41,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import packaging.version
|
||||
|
||||
V30_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v3.0 which is not backward compatible with v2.1.
|
||||
Please, update your dataset to the new format using this command:
|
||||
```
|
||||
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you already have a converted version uploaded to the hub, then this error might be because of
|
||||
an older version in your local cache. Consider deleting the cached version and retrying.
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
FUTURE_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||
"""
|
||||
|
||||
|
||||
class CompatibilityError(Exception): ...
|
||||
|
||||
|
||||
class BackwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
if version.major == 2 and version.minor == 1:
|
||||
message = V30_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ForwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
from lerobot.datasets.io_utils import load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
|
||||
@@ -0,0 +1,517 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_file_size_in_mb,
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
INFO_PATH,
|
||||
check_version_compatibility,
|
||||
flatten_dict,
|
||||
get_safe_version,
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
metadata_buffer_size: int = 10,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self.metadata_buffer: list[dict] = []
|
||||
self.metadata_buffer_size = metadata_buffer_size
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
self.load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def _flush_metadata_buffer(self) -> None:
|
||||
"""Write all buffered episode metadata to parquet file."""
|
||||
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
|
||||
return
|
||||
|
||||
combined_dict = {}
|
||||
for episode_dict in self.metadata_buffer:
|
||||
for key, value in episode_dict.items():
|
||||
if key not in combined_dict:
|
||||
combined_dict[key] = []
|
||||
# Extract value and serialize numpy arrays
|
||||
# because PyArrow's from_pydict function doesn't support numpy arrays
|
||||
val = value[0] if isinstance(value, list) else value
|
||||
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||
|
||||
first_ep = self.metadata_buffer[0]
|
||||
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||
file_idx = first_ep["meta/episodes/file_index"][0]
|
||||
|
||||
table = pa.Table.from_pydict(combined_dict)
|
||||
|
||||
if not self.writer:
|
||||
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
|
||||
self.writer.write_table(table)
|
||||
|
||||
self.latest_episode = self.metadata_buffer[-1]
|
||||
self.metadata_buffer.clear()
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return f"hf://datasets/{self.repo_id}"
|
||||
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep["data/chunk_index"]
|
||||
file_idx = ep["data/file_index"]
|
||||
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
||||
file_idx = ep[f"videos/{vid_key}/file_index"]
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def video_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info["robot_type"]
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
"""All features contained in the dataset."""
|
||||
return self.info["features"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
|
||||
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
return {key: ft["names"] for key, ft in self.features.items()}
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
|
||||
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info["data_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise return None.
|
||||
"""
|
||||
if task in self.tasks.index:
|
||||
return int(self.tasks.loc[task].task_index)
|
||||
else:
|
||||
return None
|
||||
|
||||
def save_episode_tasks(self, tasks: list[str]):
|
||||
if len(set(tasks)) != len(tasks):
|
||||
raise ValueError(f"Tasks are not unique: {tasks}")
|
||||
|
||||
if self.tasks is None:
|
||||
new_tasks = tasks
|
||||
task_indices = range(len(tasks))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
||||
else:
|
||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
|
||||
self.tasks.loc[task] = task_idx
|
||||
|
||||
if len(new_tasks) > 0:
|
||||
# Update on disk
|
||||
write_tasks(self.tasks, self.root)
|
||||
|
||||
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
||||
"""Buffer episode metadata and write to parquet in batches for efficiency.
|
||||
|
||||
This function accumulates episode metadata in a buffer and flushes it when the buffer
|
||||
reaches the configured size. This reduces I/O overhead by writing multiple episodes
|
||||
at once instead of one row at a time.
|
||||
|
||||
Notes: We both need to update parquet files and HF dataset:
|
||||
- `pandas` loads parquet file in RAM
|
||||
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||
or loads directly from pyarrow cache.
|
||||
"""
|
||||
# Convert to list format for each value
|
||||
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||
num_frames = episode_dict["length"][0]
|
||||
|
||||
if self.latest_episode is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self.episodes is not None and len(self.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
|
||||
file_idx = self.episodes[-1]["meta/episodes/file_index"]
|
||||
latest_num_frames = self.episodes[-1]["dataset_to_index"]
|
||||
episode_dict["dataset_from_index"] = [latest_num_frames]
|
||||
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
|
||||
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
else:
|
||||
episode_dict["dataset_from_index"] = [0]
|
||||
episode_dict["dataset_to_index"] = [num_frames]
|
||||
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
else:
|
||||
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
|
||||
file_idx = self.latest_episode["meta/episodes/file_index"][0]
|
||||
|
||||
latest_path = (
|
||||
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
if self.writer is None
|
||||
else self.writer.where
|
||||
)
|
||||
|
||||
if Path(latest_path).exists():
|
||||
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
|
||||
latest_num_frames = self.latest_episode["episode_index"][0]
|
||||
|
||||
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
|
||||
|
||||
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, flush buffer and prepare new parquet file
|
||||
self._flush_metadata_buffer()
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
self._close_writer()
|
||||
|
||||
# Update the existing pandas dataframe with new row
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
|
||||
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
|
||||
|
||||
# Add to buffer
|
||||
self.metadata_buffer.append(episode_dict)
|
||||
self.latest_episode = episode_dict
|
||||
|
||||
if len(self.metadata_buffer) >= self.metadata_buffer_size:
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_index: int,
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
episode_metadata: dict,
|
||||
) -> None:
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
episode_dict.update(episode_metadata)
|
||||
episode_dict.update(flatten_dict({"stats": episode_stats}))
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> None:
|
||||
"""Update chunk and file size settings after dataset creation.
|
||||
|
||||
This allows users to customize storage organization without modifying the constructor.
|
||||
These settings control how episodes are chunked and how large files can grow before
|
||||
creating new ones.
|
||||
|
||||
Args:
|
||||
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
||||
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
||||
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
||||
"""
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||
self.info["chunks_size"] = chunks_size
|
||||
|
||||
if data_files_size_in_mb is not None:
|
||||
if data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
||||
|
||||
if video_files_size_in_mb is not None:
|
||||
if video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
|
||||
def get_chunk_settings(self) -> dict[str, int]:
|
||||
"""Get current chunk and file size settings.
|
||||
|
||||
Returns:
|
||||
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
||||
"""
|
||||
return {
|
||||
"chunks_size": self.chunks_size,
|
||||
"data_files_size_in_mb": self.data_files_size_in_mb,
|
||||
"video_files_size_in_mb": self.video_files_size_in_mb,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
f"{self.__class__.__name__}({{\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Total episodes: '{self.total_episodes}',\n"
|
||||
f" Total frames: '{self.total_frames}',\n"
|
||||
f" Features: '{feature_keys}',\n"
|
||||
"})',\n"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
robot_type: str | None = None,
|
||||
root: str | Path | None = None,
|
||||
use_videos: bool = True,
|
||||
metadata_buffer_size: int = 10,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION,
|
||||
fps,
|
||||
features,
|
||||
use_videos,
|
||||
robot_type,
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError(
|
||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj.metadata_buffer = []
|
||||
obj.metadata_buffer_size = metadata_buffer_size
|
||||
return obj
|
||||
@@ -38,19 +38,22 @@ from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
@@ -915,7 +918,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
|
||||
|
||||
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
|
||||
"""
|
||||
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import embed_images
|
||||
|
||||
hf_features = get_hf_features_from_features(meta.features)
|
||||
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
|
||||
|
||||
@@ -20,11 +20,9 @@ import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
|
||||
@@ -0,0 +1,552 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""Convert a LeRobot features dictionary to a `datasets.Features` object.
|
||||
|
||||
Args:
|
||||
features (dict): A LeRobot-style feature dictionary.
|
||||
|
||||
Returns:
|
||||
datasets.Features: The corresponding Hugging Face `datasets.Features` object.
|
||||
|
||||
Raises:
|
||||
ValueError: If a feature has an unsupported shape.
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1:
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
elif len(ft["shape"]) == 2:
|
||||
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 3:
|
||||
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 4:
|
||||
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 5:
|
||||
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
else:
|
||||
raise ValueError(f"Corresponding feature is not valid: {ft}")
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
|
||||
def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
"""Validate that feature names do not contain invalid characters.
|
||||
|
||||
Args:
|
||||
features (dict): The LeRobot features dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If any feature name contains '/'.
|
||||
"""
|
||||
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
||||
if invalid_features:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
||||
|
||||
|
||||
def hw_to_dataset_features(
|
||||
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
|
||||
) -> dict[str, dict]:
|
||||
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
|
||||
|
||||
This function takes a dictionary describing hardware outputs (like joint states
|
||||
or camera image shapes) and formats it into the standard LeRobot feature
|
||||
specification.
|
||||
|
||||
Args:
|
||||
hw_features (dict): Dictionary mapping feature names to their type (float for
|
||||
joints) or shape (tuple for images).
|
||||
prefix (str): The prefix to add to the feature keys (e.g., "observation"
|
||||
or "action").
|
||||
use_video (bool): If True, image features are marked as "video", otherwise "image".
|
||||
|
||||
Returns:
|
||||
dict: A LeRobot features dictionary.
|
||||
"""
|
||||
features = {}
|
||||
joint_fts = {
|
||||
key: ftype
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == ACTION:
|
||||
features[prefix] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
if joint_fts and prefix == OBS_STR:
|
||||
features[f"{prefix}.state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
|
||||
|
||||
def build_dataset_frame(
|
||||
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Construct a single data frame from raw values based on dataset features.
|
||||
|
||||
A "frame" is a dictionary containing all the data for a single timestep,
|
||||
formatted as numpy arrays according to the feature specification.
|
||||
|
||||
Args:
|
||||
ds_features (dict): The LeRobot dataset features dictionary.
|
||||
values (dict): A dictionary of raw values from the hardware/environment.
|
||||
prefix (str): The prefix to filter features by (e.g., "observation"
|
||||
or "action").
|
||||
|
||||
Returns:
|
||||
dict: A dictionary representing a single frame of data.
|
||||
"""
|
||||
frame = {}
|
||||
for key, ft in ds_features.items():
|
||||
if key in DEFAULT_FEATURES or not key.startswith(prefix):
|
||||
continue
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""Convert dataset features to policy features.
|
||||
|
||||
This function transforms the dataset's feature specification into a format
|
||||
that a policy can use, classifying features by type (e.g., visual, state,
|
||||
action) and ensuring correct shapes (e.g., channel-first for images).
|
||||
|
||||
Args:
|
||||
features (dict): The LeRobot dataset features dictionary.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If an image feature does not have a 3D shape.
|
||||
"""
|
||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||
policy_features = {}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
if ft["dtype"] in ["image", "video"]:
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
type = FeatureType.STATE
|
||||
elif key.startswith(ACTION):
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
|
||||
policy_features[key] = PolicyFeature(
|
||||
type=type,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
return policy_features
|
||||
|
||||
|
||||
def combine_feature_dicts(*dicts: dict) -> dict:
|
||||
"""Merge LeRobot grouped feature dicts.
|
||||
|
||||
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
||||
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
|
||||
|
||||
Args:
|
||||
*dicts: A variable number of LeRobot feature dictionaries to merge.
|
||||
|
||||
Returns:
|
||||
dict: A single merged feature dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If there's a dtype mismatch for a feature being merged.
|
||||
"""
|
||||
out: dict = {}
|
||||
for d in dicts:
|
||||
for key, value in d.items():
|
||||
if not isinstance(value, dict):
|
||||
out[key] = value
|
||||
continue
|
||||
|
||||
dtype = value.get("dtype")
|
||||
shape = value.get("shape")
|
||||
is_vector = (
|
||||
dtype not in ("image", "video", "string")
|
||||
and isinstance(shape, tuple)
|
||||
and len(shape) == 1
|
||||
and "names" in value
|
||||
)
|
||||
|
||||
if is_vector:
|
||||
# Initialize or retrieve the accumulating dict for this feature key
|
||||
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
|
||||
# Ensure consistent data types across merged entries
|
||||
if "dtype" in target and dtype != target["dtype"]:
|
||||
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
|
||||
|
||||
# Merge feature names: append only new ones to preserve order without duplicates
|
||||
seen = set(target["names"])
|
||||
for n in value["names"]:
|
||||
if n not in seen:
|
||||
target["names"].append(n)
|
||||
seen.add(n)
|
||||
# Recompute the shape to reflect the updated number of features
|
||||
target["shape"] = (len(target["names"]),)
|
||||
else:
|
||||
# For images/videos and non-1D entries: override with the latest definition
|
||||
out[key] = value
|
||||
return out
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
use_videos: bool,
|
||||
robot_type: str | None = None,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
Args:
|
||||
codebase_version (str): The version of the LeRobot codebase.
|
||||
fps (int): The frames per second of the data.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
"""
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
"""Check if delta timestamps are multiples of 1/fps +/- tolerance.
|
||||
|
||||
This ensures that adding these delta timestamps to any existing timestamp in
|
||||
the dataset will result in a value that aligns with the dataset's frame rate.
|
||||
|
||||
Args:
|
||||
delta_timestamps (dict): A dictionary where values are lists of time
|
||||
deltas in seconds.
|
||||
fps (int): The frames per second of the dataset.
|
||||
tolerance_s (float): The allowed tolerance in seconds.
|
||||
raise_value_error (bool): If True, raises an error on failure.
|
||||
|
||||
Returns:
|
||||
bool: True if all deltas are valid, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
]
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""
|
||||
The following delta_timestamps are found outside of tolerance range.
|
||||
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
|
||||
their values accordingly.
|
||||
\n{pformat(outside_tolerance)}
|
||||
"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
"""Convert delta timestamps in seconds to delta indices in frames.
|
||||
|
||||
Args:
|
||||
delta_timestamps (dict): A dictionary of time deltas in seconds.
|
||||
fps (int): The frames per second of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of frame delta indices.
|
||||
"""
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
|
||||
return delta_indices
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict) -> None:
|
||||
expected_features = set(features) - set(DEFAULT_FEATURES)
|
||||
actual_features = set(frame)
|
||||
|
||||
# task is a special required field that's not part of regular features
|
||||
if "task" not in actual_features:
|
||||
raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
|
||||
|
||||
# Remove task from actual_features for regular feature validation
|
||||
actual_features_for_validation = actual_features - {"task"}
|
||||
|
||||
error_message = validate_features_presence(actual_features_for_validation, expected_features)
|
||||
|
||||
common_features = actual_features_for_validation & expected_features
|
||||
for name in common_features:
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str:
|
||||
"""Check for missing or extra features in a frame.
|
||||
|
||||
Args:
|
||||
actual_features (set[str]): The set of feature names present in the frame.
|
||||
expected_features (set[str]): The set of feature names expected in the frame.
|
||||
|
||||
Returns:
|
||||
str: An error message string if there's a mismatch, otherwise an empty string.
|
||||
"""
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - expected_features
|
||||
|
||||
if missing_features or extra_features:
|
||||
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||
if missing_features:
|
||||
error_message += f"Missing features: {missing_features}\n"
|
||||
if extra_features:
|
||||
error_message += f"Extra features: {extra_features}\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(
|
||||
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
|
||||
) -> str:
|
||||
"""Validate the dtype and shape of a single feature's value.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
feature (dict): The feature specification from the LeRobot features dictionary.
|
||||
value: The value of the feature to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the feature dtype is not supported for validation.
|
||||
"""
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
) -> str:
|
||||
"""Validate a feature that is expected to be a numpy array.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_dtype (str): The expected numpy dtype as a string.
|
||||
expected_shape (list[int]): The expected shape.
|
||||
value (np.ndarray): The numpy array to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
actual_shape = value.shape
|
||||
|
||||
if actual_dtype != np.dtype(expected_dtype):
|
||||
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
|
||||
|
||||
if actual_shape != expected_shape:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_image_or_video(
|
||||
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
|
||||
) -> str:
|
||||
"""Validate a feature that is expected to be an image or video frame.
|
||||
|
||||
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str) -> str:
|
||||
"""Validate a feature that is expected to be a string.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value (str): The value to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
Ensures the buffer has the required keys, contains at least one frame, and
|
||||
has features consistent with the dataset's specification.
|
||||
|
||||
Args:
|
||||
episode_buffer (dict): The buffer containing data for a single episode.
|
||||
total_episodes (int): The current total number of episodes in the dataset.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the buffer is invalid.
|
||||
NotImplementedError: If the episode index is manually set and doesn't match.
|
||||
"""
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
if "task" not in episode_buffer:
|
||||
raise ValueError("task key not found in episode_buffer")
|
||||
|
||||
if episode_buffer["episode_index"] != total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `features`."
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
@@ -22,6 +23,8 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -31,7 +34,7 @@ def safe_stop_image_writer(func):
|
||||
dataset = kwargs.get("dataset")
|
||||
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
logger.warning("Waiting for image writer to terminate...")
|
||||
image_writer.stop()
|
||||
raise e
|
||||
|
||||
@@ -89,8 +92,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
PIL.Image.Image object.
|
||||
|
||||
Side Effects:
|
||||
Prints an error message to the console if the image writing process
|
||||
fails for any reason.
|
||||
Logs an error message if the image writing process fails for any reason.
|
||||
"""
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
@@ -101,7 +103,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
def worker_thread_loop(queue: queue.Queue):
|
||||
|
||||
@@ -0,0 +1,342 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow.dataset as pa_ds
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from datasets.table import embed_table_storage
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
flatten_dict,
|
||||
serialize_dict,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.utils.utils import SuppressProgressBars
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
total_uncompressed_size = 0
|
||||
for row_group in range(metadata.num_row_groups):
|
||||
rg_metadata = metadata.row_group(row_group)
|
||||
for column in range(rg_metadata.num_columns):
|
||||
col_metadata = rg_metadata.column(column)
|
||||
total_uncompressed_size += col_metadata.total_uncompressed_size
|
||||
return total_uncompressed_size / (1024**2)
|
||||
|
||||
|
||||
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
||||
return hf_ds.data.nbytes // (1024**2)
|
||||
|
||||
|
||||
def load_nested_dataset(
|
||||
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
|
||||
) -> Dataset:
|
||||
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||
Concatenate all pyarrow references to return HF Dataset format
|
||||
|
||||
Args:
|
||||
pq_dir: Directory containing parquet files
|
||||
features: Optional features schema to ensure consistent loading of complex types like images
|
||||
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
|
||||
"""
|
||||
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||
if len(paths) == 0:
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
with SuppressProgressBars():
|
||||
# We use .from_parquet() memory-mapped loading for efficiency
|
||||
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
|
||||
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
return metadata.num_rows
|
||||
|
||||
|
||||
def get_file_size_in_mb(file_path: Path) -> float:
|
||||
"""Get file size on disk in megabytes.
|
||||
|
||||
Args:
|
||||
file_path (Path): Path to the file.
|
||||
"""
|
||||
file_size_bytes = file_path.stat().st_size
|
||||
return file_size_bytes / (1024**2)
|
||||
|
||||
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""Embed image bytes into the dataset table before saving to Parquet.
|
||||
|
||||
This function prepares a Hugging Face dataset for serialization by converting
|
||||
image objects into an embedded format that can be stored in Arrow/Parquet.
|
||||
|
||||
Args:
|
||||
dataset (datasets.Dataset): The input dataset, possibly containing image features.
|
||||
|
||||
Returns:
|
||||
datasets.Dataset: The dataset with images embedded in the table storage.
|
||||
"""
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
return dataset
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
"""Load data from a JSON file.
|
||||
|
||||
Args:
|
||||
fpath (Path): Path to the JSON file.
|
||||
|
||||
Returns:
|
||||
Any: The data loaded from the JSON file.
|
||||
"""
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
"""Write data to a JSON file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The dictionary to write.
|
||||
fpath (Path): The path to the output JSON file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path) -> None:
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
"""Load dataset info metadata from its standard file path.
|
||||
|
||||
Also converts shape lists to tuples for consistency.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: The dataset information dictionary.
|
||||
"""
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path) -> None:
|
||||
"""Serialize and write dataset statistics to their standard file path.
|
||||
|
||||
Args:
|
||||
stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Recursively cast numerical values in a stats dictionary to numpy arrays.
|
||||
|
||||
Args:
|
||||
stats (dict): The statistics dictionary.
|
||||
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
|
||||
"""Load dataset statistics and cast numerical values to numpy arrays.
|
||||
|
||||
Returns None if the stats file doesn't exist.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
A dictionary of statistics or None if the file is not found.
|
||||
"""
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
return cast_stats_to_numpy(stats)
|
||||
|
||||
|
||||
def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
|
||||
path = local_dir / DEFAULT_TASKS_PATH
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tasks.to_parquet(path)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
tasks.index.name = "task"
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures.
|
||||
|
||||
Args:
|
||||
episodes: HuggingFace Dataset containing episode metadata
|
||||
local_dir: Root directory where the dataset will be stored
|
||||
"""
|
||||
episode_size_mb = get_hf_dataset_size_in_mb(episodes)
|
||||
if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError(
|
||||
f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. "
|
||||
f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. "
|
||||
"This function only supports single-file episode metadata. "
|
||||
)
|
||||
|
||||
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes.to_parquet(fpath)
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> datasets.Dataset:
|
||||
episodes = load_nested_dataset(local_dir / EPISODES_DIR)
|
||||
# Select episode features/columns containing references to episode data and videos
|
||||
# (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
|
||||
# This is to speedup access to these data, instead of having to load episode stats.
|
||||
episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
|
||||
return episodes
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
"""Load an image from a file into a numpy array.
|
||||
|
||||
Args:
|
||||
fpath (str | Path): Path to the image file.
|
||||
dtype (np.dtype): The desired data type of the output array. If floating,
|
||||
pixels are scaled to [0, 1].
|
||||
channel_first (bool): If True, converts the image to (C, H, W) format.
|
||||
Otherwise, it remains in (H, W, C) format.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The image as a numpy array.
|
||||
"""
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||
types are converted to torch.tensor.
|
||||
|
||||
Args:
|
||||
items_dict (dict): A dictionary representing a batch of data from a
|
||||
Hugging Face dataset.
|
||||
|
||||
Returns:
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
"""Convert all items in a dictionary to PyTorch tensors where appropriate.
|
||||
|
||||
This function is used to convert an item from a streaming dataset to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
item (dict): Dictionary of items from a dataset.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
@@ -23,526 +23,52 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
INFO_PATH,
|
||||
_validate_feature_names,
|
||||
from lerobot.datasets.compute_stats import compute_episode_stats
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import (
|
||||
check_delta_timestamps,
|
||||
check_version_compatibility,
|
||||
create_empty_dataset_info,
|
||||
create_lerobot_dataset_card,
|
||||
embed_images,
|
||||
flatten_dict,
|
||||
get_delta_indices,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_safe_version,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
)
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.io_utils import (
|
||||
embed_images,
|
||||
get_file_size_in_mb,
|
||||
hf_transform_to_torch,
|
||||
load_episodes,
|
||||
load_nested_dataset,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
create_lerobot_dataset_card,
|
||||
get_safe_version,
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_safe_default_codec,
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
metadata_buffer_size: int = 10,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self.metadata_buffer: list[dict] = []
|
||||
self.metadata_buffer_size = metadata_buffer_size
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
self.load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def _flush_metadata_buffer(self) -> None:
|
||||
"""Write all buffered episode metadata to parquet file."""
|
||||
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
|
||||
return
|
||||
|
||||
combined_dict = {}
|
||||
for episode_dict in self.metadata_buffer:
|
||||
for key, value in episode_dict.items():
|
||||
if key not in combined_dict:
|
||||
combined_dict[key] = []
|
||||
# Extract value and serialize numpy arrays
|
||||
# because PyArrow's from_pydict function doesn't support numpy arrays
|
||||
val = value[0] if isinstance(value, list) else value
|
||||
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||
|
||||
first_ep = self.metadata_buffer[0]
|
||||
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||
file_idx = first_ep["meta/episodes/file_index"][0]
|
||||
|
||||
table = pa.Table.from_pydict(combined_dict)
|
||||
|
||||
if not self.writer:
|
||||
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
|
||||
self.writer.write_table(table)
|
||||
|
||||
self.latest_episode = self.metadata_buffer[-1]
|
||||
self.metadata_buffer.clear()
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return f"hf://datasets/{self.repo_id}"
|
||||
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep["data/chunk_index"]
|
||||
file_idx = ep["data/file_index"]
|
||||
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
||||
file_idx = ep[f"videos/{vid_key}/file_index"]
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def video_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info["robot_type"]
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
"""All features contained in the dataset."""
|
||||
return self.info["features"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
|
||||
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
return {key: ft["names"] for key, ft in self.features.items()}
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
|
||||
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info["data_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise return None.
|
||||
"""
|
||||
if task in self.tasks.index:
|
||||
return int(self.tasks.loc[task].task_index)
|
||||
else:
|
||||
return None
|
||||
|
||||
def save_episode_tasks(self, tasks: list[str]):
|
||||
if len(set(tasks)) != len(tasks):
|
||||
raise ValueError(f"Tasks are not unique: {tasks}")
|
||||
|
||||
if self.tasks is None:
|
||||
new_tasks = tasks
|
||||
task_indices = range(len(tasks))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
||||
else:
|
||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
|
||||
self.tasks.loc[task] = task_idx
|
||||
|
||||
if len(new_tasks) > 0:
|
||||
# Update on disk
|
||||
write_tasks(self.tasks, self.root)
|
||||
|
||||
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
||||
"""Buffer episode metadata and write to parquet in batches for efficiency.
|
||||
|
||||
This function accumulates episode metadata in a buffer and flushes it when the buffer
|
||||
reaches the configured size. This reduces I/O overhead by writing multiple episodes
|
||||
at once instead of one row at a time.
|
||||
|
||||
Notes: We both need to update parquet files and HF dataset:
|
||||
- `pandas` loads parquet file in RAM
|
||||
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||
or loads directly from pyarrow cache.
|
||||
"""
|
||||
# Convert to list format for each value
|
||||
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||
num_frames = episode_dict["length"][0]
|
||||
|
||||
if self.latest_episode is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self.episodes is not None and len(self.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
|
||||
file_idx = self.episodes[-1]["meta/episodes/file_index"]
|
||||
latest_num_frames = self.episodes[-1]["dataset_to_index"]
|
||||
episode_dict["dataset_from_index"] = [latest_num_frames]
|
||||
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
|
||||
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
else:
|
||||
episode_dict["dataset_from_index"] = [0]
|
||||
episode_dict["dataset_to_index"] = [num_frames]
|
||||
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
else:
|
||||
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
|
||||
file_idx = self.latest_episode["meta/episodes/file_index"][0]
|
||||
|
||||
latest_path = (
|
||||
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
if self.writer is None
|
||||
else self.writer.where
|
||||
)
|
||||
|
||||
if Path(latest_path).exists():
|
||||
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
|
||||
latest_num_frames = self.latest_episode["episode_index"][0]
|
||||
|
||||
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
|
||||
|
||||
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, flush buffer and prepare new parquet file
|
||||
self._flush_metadata_buffer()
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
self._close_writer()
|
||||
|
||||
# Update the existing pandas dataframe with new row
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
|
||||
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
|
||||
|
||||
# Add to buffer
|
||||
self.metadata_buffer.append(episode_dict)
|
||||
self.latest_episode = episode_dict
|
||||
|
||||
if len(self.metadata_buffer) >= self.metadata_buffer_size:
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_index: int,
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
episode_metadata: dict,
|
||||
) -> None:
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
episode_dict.update(episode_metadata)
|
||||
episode_dict.update(flatten_dict({"stats": episode_stats}))
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> None:
|
||||
"""Update chunk and file size settings after dataset creation.
|
||||
|
||||
This allows users to customize storage organization without modifying the constructor.
|
||||
These settings control how episodes are chunked and how large files can grow before
|
||||
creating new ones.
|
||||
|
||||
Args:
|
||||
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
||||
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
||||
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
||||
"""
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||
self.info["chunks_size"] = chunks_size
|
||||
|
||||
if data_files_size_in_mb is not None:
|
||||
if data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
||||
|
||||
if video_files_size_in_mb is not None:
|
||||
if video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
|
||||
def get_chunk_settings(self) -> dict[str, int]:
|
||||
"""Get current chunk and file size settings.
|
||||
|
||||
Returns:
|
||||
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
||||
"""
|
||||
return {
|
||||
"chunks_size": self.chunks_size,
|
||||
"data_files_size_in_mb": self.data_files_size_in_mb,
|
||||
"video_files_size_in_mb": self.video_files_size_in_mb,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
f"{self.__class__.__name__}({{\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Total episodes: '{self.total_episodes}',\n"
|
||||
f" Total frames: '{self.total_frames}',\n"
|
||||
f" Features: '{feature_keys}',\n"
|
||||
"})',\n"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
robot_type: str | None = None,
|
||||
root: str | Path | None = None,
|
||||
use_videos: bool = True,
|
||||
metadata_buffer_size: int = 10,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION,
|
||||
fps,
|
||||
features,
|
||||
use_videos,
|
||||
robot_type,
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj.metadata_buffer = []
|
||||
obj.metadata_buffer_size = metadata_buffer_size
|
||||
return obj
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _encode_video_worker(
|
||||
@@ -1326,7 +852,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temp_path = future.result()
|
||||
results[video_key] = temp_path
|
||||
except Exception as exc:
|
||||
logging.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
logger.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
raise exc
|
||||
|
||||
for video_key in self.meta.video_keys:
|
||||
@@ -1365,7 +891,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if end_episode is None:
|
||||
end_episode = self.num_episodes
|
||||
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
|
||||
)
|
||||
|
||||
@@ -1375,7 +901,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
for ep_idx in range(start_episode, end_episode):
|
||||
logging.info(f"Encoding videos for episode {ep_idx}")
|
||||
logger.info(f"Encoding videos for episode {ep_idx}")
|
||||
|
||||
if (
|
||||
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
|
||||
@@ -1605,7 +1131,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
||||
if isinstance(self.image_writer, AsyncImageWriter):
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
|
||||
)
|
||||
|
||||
@@ -1716,184 +1242,3 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._streaming_encoder = None
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
|
||||
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
||||
structure of `LeRobotDataset`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
root: str | Path | None = None,
|
||||
episodes: dict | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes[repo_id] if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
self.disabled_features = set()
|
||||
intersection_features = set(self._datasets[0].features)
|
||||
for ds in self._datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
if extra_keys:
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||
|
||||
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
|
||||
"""
|
||||
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
"""Returns True if this dataset loads video frames from mp4 files.
|
||||
|
||||
Returns False if it only loads images from png files.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
||||
return features
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image | VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
@property
|
||||
def video_frame_keys(self) -> list[str]:
|
||||
"""Keys to access video frames that requires to be decoded into images.
|
||||
|
||||
Note: It is empty if the dataset contains images only,
|
||||
or equal to `self.cameras` if the dataset contains videos only,
|
||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
||||
"""
|
||||
video_frame_keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, VideoFrame):
|
||||
video_frame_keys.append(key)
|
||||
return video_frame_keys
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
return sum(d.num_frames for d in self._datasets)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes."""
|
||||
return sum(d.num_episodes for d in self._datasets)
|
||||
|
||||
@property
|
||||
def tolerance_s(self) -> float:
|
||||
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
||||
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
||||
is provided or when loading video frames from mp4 files.
|
||||
"""
|
||||
# 1e-4 to account for possible numerical error
|
||||
return 1 / self.fps - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Number of Samples: {self.num_frames},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.utils
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import VideoFrame
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
|
||||
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
||||
structure of `LeRobotDataset`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
root: str | Path | None = None,
|
||||
episodes: dict | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes[repo_id] if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
self.disabled_features = set()
|
||||
intersection_features = set(self._datasets[0].features)
|
||||
for ds in self._datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
if extra_keys:
|
||||
logger.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||
|
||||
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
|
||||
"""
|
||||
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
"""Returns True if this dataset loads video frames from mp4 files.
|
||||
|
||||
Returns False if it only loads images from png files.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
||||
return features
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image | VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
@property
|
||||
def video_frame_keys(self) -> list[str]:
|
||||
"""Keys to access video frames that requires to be decoded into images.
|
||||
|
||||
Note: It is empty if the dataset contains images only,
|
||||
or equal to `self.cameras` if the dataset contains videos only,
|
||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
||||
"""
|
||||
video_frame_keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, VideoFrame):
|
||||
video_frame_keys.append(key)
|
||||
return video_frame_keys
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
return sum(d.num_frames for d in self._datasets)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes."""
|
||||
return sum(d.num_episodes for d in self._datasets)
|
||||
|
||||
@property
|
||||
def tolerance_s(self) -> float:
|
||||
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
||||
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
||||
is provided or when loading video frames from mp4 files.
|
||||
"""
|
||||
# 1e-4 to account for possible numerical error
|
||||
return 1 / self.fps - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Number of Samples: {self.num_frames},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
@@ -17,7 +17,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
@@ -44,11 +44,11 @@ def create_initial_features(
|
||||
return features
|
||||
|
||||
|
||||
# Helper to filter state/action keys based on regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[str]) -> bool:
|
||||
# Helper to filter state/action keys based on compiled regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
return any(pat.search(key) for pat in patterns)
|
||||
|
||||
|
||||
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
|
||||
@@ -89,6 +89,8 @@ def aggregate_pipeline_dataset_features(
|
||||
Returns:
|
||||
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
|
||||
"""
|
||||
compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None
|
||||
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Intermediate storage for categorized and filtered features.
|
||||
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not should_keep(key, patterns):
|
||||
if not is_image and not should_keep(key, compiled_patterns):
|
||||
continue
|
||||
|
||||
# 3. Add the feature to the appropriate group with a clean name.
|
||||
|
||||
@@ -13,10 +13,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
def __init__(
|
||||
@@ -39,13 +42,35 @@ class EpisodeAwareSampler:
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable, Generator, Iterator
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -21,16 +22,13 @@ import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import get_delta_indices
|
||||
from lerobot.datasets.io_utils import item_to_torch
|
||||
from lerobot.datasets.utils import (
|
||||
Backtrackable,
|
||||
LookAheadError,
|
||||
LookBackError,
|
||||
check_version_compatibility,
|
||||
find_float_index,
|
||||
get_delta_indices,
|
||||
is_float_in_list,
|
||||
item_to_torch,
|
||||
safe_shard,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
@@ -40,6 +38,164 @@ from lerobot.datasets.video_utils import (
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look back in the history of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LookAheadError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look ahead in the future of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable[T]:
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
This is useful for streaming datasets where you need to access previous and future items
|
||||
but can't load the entire dataset into memory.
|
||||
|
||||
Example:
|
||||
-------
|
||||
```python
|
||||
ds = load_dataset("c4", "en", streaming=True, split="train")
|
||||
rev = Backtrackable(ds, history=3, lookahead=2)
|
||||
|
||||
x0 = next(rev) # forward
|
||||
x1 = next(rev)
|
||||
x2 = next(rev)
|
||||
|
||||
# Look ahead
|
||||
x3_peek = rev.peek_ahead(1) # next item without moving cursor
|
||||
x4_peek = rev.peek_ahead(2) # two items ahead
|
||||
|
||||
# Look back
|
||||
x1_again = rev.peek_back(1) # previous item without moving cursor
|
||||
x0_again = rev.peek_back(2) # two items back
|
||||
|
||||
# Move backward
|
||||
x1_back = rev.prev() # back one step
|
||||
next(rev) # returns x2, continues forward from where we were
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
|
||||
|
||||
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
|
||||
if history < 1:
|
||||
raise ValueError("history must be >= 1")
|
||||
if lookahead <= 0:
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
|
||||
def __iter__(self) -> "Backtrackable[T]":
|
||||
return self
|
||||
|
||||
def __next__(self) -> T:
|
||||
# If we've stepped back, consume from back buffer first
|
||||
if self._cursor < 0: # -1 means "last item", etc.
|
||||
self._cursor += 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
# If we have items in the ahead buffer, use them first
|
||||
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
|
||||
|
||||
# Add current item to back buffer and reset cursor
|
||||
self._back_buf.append(item)
|
||||
self._cursor = 0
|
||||
return item
|
||||
|
||||
def prev(self) -> T:
|
||||
"""
|
||||
Step one item back in history and return it.
|
||||
Raises IndexError if already at the oldest buffered item.
|
||||
"""
|
||||
if len(self._back_buf) + self._cursor <= 1:
|
||||
raise LookBackError("At start of history")
|
||||
|
||||
self._cursor -= 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
def peek_back(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items back (n=1 == previous item) without moving the cursor.
|
||||
"""
|
||||
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
|
||||
raise LookBackError("peek_back distance out of range")
|
||||
|
||||
return self._back_buf[self._cursor - (n + 1)]
|
||||
|
||||
def peek_ahead(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items ahead (n=1 == next item) without moving the cursor.
|
||||
Fills the ahead buffer if necessary.
|
||||
"""
|
||||
if n < 1:
|
||||
raise LookAheadError("peek_ahead distance must be 1 or more")
|
||||
elif n > self._lookahead:
|
||||
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
|
||||
|
||||
# Fill ahead buffer if we don't have enough items
|
||||
while len(self._ahead_buf) < n:
|
||||
try:
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
|
||||
except StopIteration as err:
|
||||
raise LookAheadError("peek_ahead: not enough items in source") from err
|
||||
|
||||
return self._ahead_buf[n - 1]
|
||||
|
||||
def history(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the buffered history (most recent last).
|
||||
The list length ≤ `history` argument passed at construction.
|
||||
"""
|
||||
if self._cursor == 0:
|
||||
return list(self._back_buf)
|
||||
|
||||
# When cursor<0, slice so the order remains chronological
|
||||
return list(self._back_buf)[: self._cursor or None]
|
||||
|
||||
def can_peek_back(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can go back `steps` items without raising an IndexError.
|
||||
"""
|
||||
return steps <= len(self._back_buf) + self._cursor
|
||||
|
||||
def can_peek_ahead(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can peek ahead `steps` items.
|
||||
This may involve trying to fill the ahead buffer.
|
||||
"""
|
||||
if self._lookahead > 0 and steps > self._lookahead:
|
||||
return False
|
||||
|
||||
# Try to fill ahead buffer to check if we can peek that far
|
||||
try:
|
||||
while len(self._ahead_buf) < steps:
|
||||
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
|
||||
return False
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
return True
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""LeRobotDataset with streaming capabilities.
|
||||
|
||||
|
||||
+42
-995
File diff suppressed because it is too large
Load Diff
@@ -37,6 +37,8 @@ import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
@@ -94,7 +96,7 @@ def detect_available_hw_encoders() -> list[str]:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
pass # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
@@ -103,14 +105,14 @@ def resolve_vcodec(vcodec: str) -> str:
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logging.info(f"Using video codec: {vcodec}")
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logging.info(f"Auto-selected video codec: {encoder}")
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
@@ -118,7 +120,7 @@ def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||
)
|
||||
return "pyav"
|
||||
@@ -208,7 +210,7 @@ def decode_video_frames_torchvision(
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
@@ -244,7 +246,7 @@ def decode_video_frames_torchvision(
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
logger.info(f"{closest_ts=}")
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
@@ -348,7 +350,7 @@ def decode_video_frames_torchcodec(
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
logger.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
@@ -374,7 +376,7 @@ def decode_video_frames_torchcodec(
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
logger.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range
|
||||
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||
@@ -408,14 +410,14 @@ def encode_video_frames(
|
||||
imgs_dir = Path(imgs_dir)
|
||||
|
||||
if video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
|
||||
logger.warning(f"Video file already exists: {video_path}. Skipping encoding.")
|
||||
return
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
@@ -508,7 +510,7 @@ def concatenate_video_files(
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -693,7 +695,7 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self.result_queue.put(("ok", None))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Encoder thread error: {e}")
|
||||
logger.error(f"Encoder thread error: {e}")
|
||||
if container is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
container.close()
|
||||
@@ -819,7 +821,7 @@ class StreamingVideoEncoder:
|
||||
count = self._dropped_frames[video_key]
|
||||
# Log periodically to avoid spam (1st, then every 10th)
|
||||
if count == 1 or count % 10 == 0:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
|
||||
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
|
||||
)
|
||||
@@ -841,7 +843,7 @@ class StreamingVideoEncoder:
|
||||
# Report dropped frames
|
||||
for video_key, count in self._dropped_frames.items():
|
||||
if count > 0:
|
||||
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
|
||||
# Send sentinel to all queues
|
||||
for video_key in self._frame_queues:
|
||||
@@ -851,7 +853,7 @@ class StreamingVideoEncoder:
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=120)
|
||||
if self._threads[video_key].is_alive():
|
||||
logging.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
logger.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
self._stop_events[video_key].set()
|
||||
self._threads[video_key].join(timeout=5)
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
@@ -863,7 +865,7 @@ class StreamingVideoEncoder:
|
||||
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
|
||||
results[video_key] = (self._video_paths[video_key], data)
|
||||
except queue.Empty:
|
||||
logging.error(f"No result from encoder thread for {video_key}")
|
||||
logger.error(f"No result from encoder thread for {video_key}")
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
|
||||
self._cleanup()
|
||||
@@ -1071,13 +1073,13 @@ class VideoEncodingManager:
|
||||
elif self.dataset.episodes_since_last_encoding > 0:
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if exc_type is not None:
|
||||
logging.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
logger.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
else:
|
||||
logging.info("Recording stopped. Encoding remaining episodes...")
|
||||
logger.info("Recording stopped. Encoding remaining episodes...")
|
||||
|
||||
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
|
||||
end_ep = self.dataset.num_episodes
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
@@ -1094,7 +1096,7 @@ class VideoEncodingManager:
|
||||
episode_index=interrupted_episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
if img_dir.exists():
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
@@ -1105,8 +1107,8 @@ class VideoEncodingManager:
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
shutil.rmtree(img_dir)
|
||||
logging.debug("Cleaned up empty images directory")
|
||||
logger.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -23,7 +23,8 @@ import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.datasets.io_utils import write_json
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict
|
||||
from lerobot.utils.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
|
||||
@@ -23,7 +23,7 @@ import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.datasets.utils import write_json
|
||||
from lerobot.datasets.io_utils import write_json
|
||||
from lerobot.utils.constants import SCHEDULER_STATE
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
|
||||
@@ -23,7 +23,7 @@ from torch import nn
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.utils import build_dataset_frame
|
||||
from lerobot.datasets.feature_utils import build_dataset_frame
|
||||
from lerobot.types import PolicyAction, RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
|
||||
|
||||
@@ -467,8 +467,8 @@ class VQBeTHead(nn.Module):
|
||||
self.vqvae_model.optimized_steps += 1
|
||||
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
|
||||
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
|
||||
self.vqvae_model.discretized = torch.tensor(True)
|
||||
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
|
||||
self.vqvae_model.discretized.fill_(True)
|
||||
self.vqvae_model.vq_layer.freeze_codebook.fill_(True)
|
||||
print("Finished discretizing action data!")
|
||||
self.vqvae_model.eval()
|
||||
for param in self.vqvae_model.vq_layer.parameters():
|
||||
|
||||
@@ -33,21 +33,40 @@ from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Action feature keys
|
||||
ACTION_LINEAR_VEL = "linear.vel"
|
||||
ACTION_ANGULAR_VEL = "angular.vel"
|
||||
ACTION_LINEAR_VEL = "linear_velocity"
|
||||
ACTION_ANGULAR_VEL = "angular_velocity"
|
||||
|
||||
# Observation feature keys
|
||||
# Observation feature keys — cameras
|
||||
OBS_FRONT = "front"
|
||||
OBS_REAR = "rear"
|
||||
OBS_LINEAR_VEL = "linear.vel"
|
||||
OBS_BATTERY_LEVEL = "battery.level"
|
||||
OBS_ORIENTATION_DEG = "orientation.deg"
|
||||
OBS_GPS_LATITUDE = "gps.latitude"
|
||||
OBS_GPS_LONGITUDE = "gps.longitude"
|
||||
OBS_GPS_SIGNAL = "gps.signal"
|
||||
OBS_SIGNAL_LEVEL = "signal.level"
|
||||
|
||||
# Observation feature keys — telemetry
|
||||
OBS_SPEED = "speed"
|
||||
OBS_BATTERY_LEVEL = "battery_level"
|
||||
OBS_ORIENTATION = "orientation"
|
||||
OBS_GPS_LATITUDE = "gps_latitude"
|
||||
OBS_GPS_LONGITUDE = "gps_longitude"
|
||||
OBS_GPS_SIGNAL = "gps_signal"
|
||||
OBS_SIGNAL_LEVEL = "signal_level"
|
||||
OBS_VIBRATION = "vibration"
|
||||
OBS_LAMP_STATE = "lamp.state"
|
||||
OBS_LAMP = "lamp"
|
||||
|
||||
# Observation feature keys — IMU sensors
|
||||
OBS_ACCELEROMETER_X = "accelerometer_x"
|
||||
OBS_ACCELEROMETER_Y = "accelerometer_y"
|
||||
OBS_ACCELEROMETER_Z = "accelerometer_z"
|
||||
OBS_GYROSCOPE_X = "gyroscope_x"
|
||||
OBS_GYROSCOPE_Y = "gyroscope_y"
|
||||
OBS_GYROSCOPE_Z = "gyroscope_z"
|
||||
OBS_MAGNETOMETER_X = "magnetometer_filtered_x"
|
||||
OBS_MAGNETOMETER_Y = "magnetometer_filtered_y"
|
||||
OBS_MAGNETOMETER_Z = "magnetometer_filtered_z"
|
||||
|
||||
# Observation feature keys — wheel RPMs
|
||||
OBS_WHEEL_RPM_0 = "wheel_rpm_0"
|
||||
OBS_WHEEL_RPM_1 = "wheel_rpm_1"
|
||||
OBS_WHEEL_RPM_2 = "wheel_rpm_2"
|
||||
OBS_WHEEL_RPM_3 = "wheel_rpm_3"
|
||||
|
||||
|
||||
class EarthRoverMiniPlus(Robot):
|
||||
@@ -154,33 +173,60 @@ class EarthRoverMiniPlus(Robot):
|
||||
dict: Observation features with types/shapes:
|
||||
- front: (480, 640, 3) - Front camera RGB image
|
||||
- rear: (480, 640, 3) - Rear camera RGB image
|
||||
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
|
||||
- battery.level: float - Battery level (0-1, normalized from 0-100)
|
||||
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
|
||||
- gps.latitude: float - GPS latitude coordinate
|
||||
- gps.longitude: float - GPS longitude coordinate
|
||||
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
|
||||
- signal.level: float - Network signal level (0-1, normalized from 0-5)
|
||||
- speed: float - Current speed (raw SDK value)
|
||||
- battery_level: float - Battery level (0-100)
|
||||
- orientation: float - Robot orientation in degrees
|
||||
- gps_latitude: float - GPS latitude coordinate
|
||||
- gps_longitude: float - GPS longitude coordinate
|
||||
- gps_signal: float - GPS signal strength (percentage)
|
||||
- signal_level: float - Network signal level (0-5)
|
||||
- vibration: float - Vibration sensor reading
|
||||
- lamp.state: float - Lamp state (0=off, 1=on)
|
||||
- lamp: float - Lamp state (0=off, 1=on)
|
||||
- accelerometer_x: float - Accelerometer X axis (raw SDK value)
|
||||
- accelerometer_y: float - Accelerometer Y axis (raw SDK value)
|
||||
- accelerometer_z: float - Accelerometer Z axis (raw SDK value)
|
||||
- gyroscope_x: float - Gyroscope X axis (raw SDK value)
|
||||
- gyroscope_y: float - Gyroscope Y axis (raw SDK value)
|
||||
- gyroscope_z: float - Gyroscope Z axis (raw SDK value)
|
||||
- magnetometer_filtered_x: float - Magnetometer X axis (raw SDK value)
|
||||
- magnetometer_filtered_y: float - Magnetometer Y axis (raw SDK value)
|
||||
- magnetometer_filtered_z: float - Magnetometer Z axis (raw SDK value)
|
||||
- wheel_rpm_0: float - Wheel 0 RPM
|
||||
- wheel_rpm_1: float - Wheel 1 RPM
|
||||
- wheel_rpm_2: float - Wheel 2 RPM
|
||||
- wheel_rpm_3: float - Wheel 3 RPM
|
||||
"""
|
||||
return {
|
||||
# Cameras (height, width, channels)
|
||||
OBS_FRONT: (480, 640, 3),
|
||||
OBS_REAR: (480, 640, 3),
|
||||
# Motion state
|
||||
OBS_LINEAR_VEL: float,
|
||||
# Robot state
|
||||
# Telemetry
|
||||
OBS_SPEED: float,
|
||||
OBS_BATTERY_LEVEL: float,
|
||||
OBS_ORIENTATION_DEG: float,
|
||||
# GPS
|
||||
OBS_ORIENTATION: float,
|
||||
OBS_GPS_LATITUDE: float,
|
||||
OBS_GPS_LONGITUDE: float,
|
||||
OBS_GPS_SIGNAL: float,
|
||||
# Sensors
|
||||
OBS_SIGNAL_LEVEL: float,
|
||||
OBS_VIBRATION: float,
|
||||
OBS_LAMP_STATE: float,
|
||||
OBS_LAMP: float,
|
||||
# IMU — accelerometer
|
||||
OBS_ACCELEROMETER_X: float,
|
||||
OBS_ACCELEROMETER_Y: float,
|
||||
OBS_ACCELEROMETER_Z: float,
|
||||
# IMU — gyroscope
|
||||
OBS_GYROSCOPE_X: float,
|
||||
OBS_GYROSCOPE_Y: float,
|
||||
OBS_GYROSCOPE_Z: float,
|
||||
# IMU — magnetometer
|
||||
OBS_MAGNETOMETER_X: float,
|
||||
OBS_MAGNETOMETER_Y: float,
|
||||
OBS_MAGNETOMETER_Z: float,
|
||||
# Wheel RPMs
|
||||
OBS_WHEEL_RPM_0: float,
|
||||
OBS_WHEEL_RPM_1: float,
|
||||
OBS_WHEEL_RPM_2: float,
|
||||
OBS_WHEEL_RPM_3: float,
|
||||
}
|
||||
|
||||
@cached_property
|
||||
@@ -189,8 +235,8 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
Returns:
|
||||
dict: Action features with types:
|
||||
- linear.vel: float - Target linear velocity
|
||||
- angular.vel: float - Target angular velocity
|
||||
- linear_velocity: float - Target linear velocity (-1 to 1)
|
||||
- angular_velocity: float - Target angular velocity (-1 to 1)
|
||||
"""
|
||||
return {
|
||||
ACTION_LINEAR_VEL: float,
|
||||
@@ -201,19 +247,29 @@ class EarthRoverMiniPlus(Robot):
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
|
||||
Frames are decoded from base64 and converted from BGR to RGB format.
|
||||
Robot telemetry is retrieved from /data endpoint.
|
||||
Sensor arrays (accels, gyros, mags, rpms) each contain entries of
|
||||
[values..., timestamp]; the latest reading from each array is used.
|
||||
|
||||
Returns:
|
||||
RobotObservation: Observation containing:
|
||||
- front: Front camera image (480, 640, 3) in RGB format
|
||||
- rear: Rear camera image (480, 640, 3) in RGB format
|
||||
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
|
||||
- battery.level: Battery level (0-1, normalized from 0-100)
|
||||
- orientation.deg: Robot orientation (0-1, normalized from raw value)
|
||||
- gps.latitude: GPS latitude coordinate
|
||||
- gps.longitude: GPS longitude coordinate
|
||||
- gps.signal: GPS signal strength (0-1, normalized from percentage)
|
||||
- signal.level: Network signal level (0-1, normalized from 0-5)
|
||||
- vibration: Vibration sensor reading
|
||||
- lamp.state: Lamp state (0=off, 1=on)
|
||||
- speed: float - Current speed (raw SDK value)
|
||||
- battery_level: float - Battery level (0-100)
|
||||
- orientation: float - Robot orientation in degrees
|
||||
- gps_latitude: float - GPS latitude coordinate
|
||||
- gps_longitude: float - GPS longitude coordinate
|
||||
- gps_signal: float - GPS signal strength (percentage)
|
||||
- signal_level: float - Network signal level (0-5)
|
||||
- vibration: float - Vibration sensor reading
|
||||
- lamp: float - Lamp state (0=off, 1=on)
|
||||
- accelerometer_x/y/z: float - Accelerometer axes (raw SDK value)
|
||||
- gyroscope_x/y/z: float - Gyroscope axes (raw SDK value)
|
||||
- magnetometer_filtered_x/y/z: float - Magnetometer axes (raw SDK value)
|
||||
- wheel_rpm_0/1/2/3: float - Wheel RPMs
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
@@ -235,22 +291,41 @@ class EarthRoverMiniPlus(Robot):
|
||||
# Get robot state from SDK
|
||||
robot_data = self._get_robot_data()
|
||||
|
||||
# Motion state
|
||||
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
|
||||
# Telemetry
|
||||
observation[OBS_SPEED] = float(robot_data["speed"])
|
||||
observation[OBS_BATTERY_LEVEL] = float(robot_data["battery"])
|
||||
observation[OBS_ORIENTATION] = float(robot_data["orientation"])
|
||||
observation[OBS_GPS_LATITUDE] = float(robot_data["latitude"])
|
||||
observation[OBS_GPS_LONGITUDE] = float(robot_data["longitude"])
|
||||
observation[OBS_GPS_SIGNAL] = float(robot_data["gps_signal"])
|
||||
observation[OBS_SIGNAL_LEVEL] = float(robot_data["signal_level"])
|
||||
observation[OBS_VIBRATION] = float(robot_data["vibration"])
|
||||
observation[OBS_LAMP] = float(robot_data["lamp"])
|
||||
|
||||
# Robot state
|
||||
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
|
||||
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
|
||||
# Accelerometer — latest reading from accels array [x, y, z, ts]
|
||||
accel = self._latest_sensor_reading(robot_data, "accels", n_values=3)
|
||||
observation[OBS_ACCELEROMETER_X] = accel[0]
|
||||
observation[OBS_ACCELEROMETER_Y] = accel[1]
|
||||
observation[OBS_ACCELEROMETER_Z] = accel[2]
|
||||
|
||||
# GPS data
|
||||
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
|
||||
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
|
||||
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
|
||||
# Gyroscope — latest reading from gyros array [x, y, z, ts]
|
||||
gyro = self._latest_sensor_reading(robot_data, "gyros", n_values=3)
|
||||
observation[OBS_GYROSCOPE_X] = gyro[0]
|
||||
observation[OBS_GYROSCOPE_Y] = gyro[1]
|
||||
observation[OBS_GYROSCOPE_Z] = gyro[2]
|
||||
|
||||
# Sensors
|
||||
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
|
||||
observation[OBS_VIBRATION] = robot_data["vibration"]
|
||||
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
|
||||
# Magnetometer — latest reading from mags array [x, y, z, ts]
|
||||
mag = self._latest_sensor_reading(robot_data, "mags", n_values=3)
|
||||
observation[OBS_MAGNETOMETER_X] = mag[0]
|
||||
observation[OBS_MAGNETOMETER_Y] = mag[1]
|
||||
observation[OBS_MAGNETOMETER_Z] = mag[2]
|
||||
|
||||
# Wheel RPMs — latest reading from rpms array [w0, w1, w2, w3, ts]
|
||||
rpm = self._latest_sensor_reading(robot_data, "rpms", n_values=4)
|
||||
observation[OBS_WHEEL_RPM_0] = rpm[0]
|
||||
observation[OBS_WHEEL_RPM_1] = rpm[1]
|
||||
observation[OBS_WHEEL_RPM_2] = rpm[2]
|
||||
observation[OBS_WHEEL_RPM_3] = rpm[3]
|
||||
|
||||
return observation
|
||||
|
||||
@@ -260,11 +335,12 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
Args:
|
||||
action: Action dict with keys:
|
||||
- linear.vel: Target linear velocity (-1 to 1)
|
||||
- angular.vel: Target angular velocity (-1 to 1)
|
||||
- linear_velocity: Target linear velocity (-1 to 1)
|
||||
- angular_velocity: Target angular velocity (-1 to 1)
|
||||
|
||||
Returns:
|
||||
RobotAction: The action that was sent (matches action_features keys)
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
|
||||
@@ -272,18 +348,14 @@ class EarthRoverMiniPlus(Robot):
|
||||
Actions are sent to SDK via POST /control endpoint.
|
||||
SDK expects commands in range [-1, 1].
|
||||
"""
|
||||
|
||||
# Extract action values and convert to float
|
||||
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
|
||||
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
|
||||
|
||||
# Send command to SDK
|
||||
try:
|
||||
self._send_command_to_sdk(linear, angular)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending action: {e}")
|
||||
|
||||
# Return action in format matching action_features
|
||||
return {
|
||||
ACTION_LINEAR_VEL: linear,
|
||||
ACTION_ANGULAR_VEL: angular,
|
||||
@@ -394,11 +466,27 @@ class EarthRoverMiniPlus(Robot):
|
||||
logger.error(f"Error decoding image: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]:
|
||||
"""Extract the latest sensor reading from an SDK sensor array.
|
||||
|
||||
The SDK returns sensor arrays like ``accels``, ``gyros``, ``mags``,
|
||||
``rpms`` where each entry is ``[value_0, ..., value_n, timestamp]``.
|
||||
This helper returns the *n_values* leading floats from the last entry,
|
||||
falling back to zeros when the key is missing or the array is empty.
|
||||
"""
|
||||
readings = robot_data.get(key)
|
||||
if readings and len(readings) > 0:
|
||||
latest = readings[-1]
|
||||
return [float(v) for v in latest[:n_values]]
|
||||
return [0.0] * n_values
|
||||
|
||||
def _get_robot_data(self) -> dict:
|
||||
"""Get robot telemetry data from SDK.
|
||||
|
||||
Returns:
|
||||
dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
|
||||
dict: Robot telemetry data including battery, speed, orientation, GPS,
|
||||
and sensor arrays (accels, gyros, mags, rpms):
|
||||
- Current data (if request succeeds)
|
||||
- Cached data (if request fails but cache exists)
|
||||
- Default values (if request fails and no cache exists yet)
|
||||
@@ -420,19 +508,23 @@ class EarthRoverMiniPlus(Robot):
|
||||
# Fallback: use cache or default values
|
||||
if self._last_robot_data is not None:
|
||||
return self._last_robot_data
|
||||
else:
|
||||
# Return dict with default values (used only on first failure before any cache exists)
|
||||
return {
|
||||
"speed": 0,
|
||||
"battery": 0,
|
||||
"orientation": 0,
|
||||
"latitude": 0.0,
|
||||
"longitude": 0.0,
|
||||
"gps_signal": 0,
|
||||
"signal_level": 0,
|
||||
"vibration": 0.0,
|
||||
"lamp": 0,
|
||||
}
|
||||
|
||||
# Return dict with default values (used only on first failure before any cache exists)
|
||||
return {
|
||||
"speed": 0,
|
||||
"battery": 0,
|
||||
"orientation": 0,
|
||||
"latitude": 0.0,
|
||||
"longitude": 0.0,
|
||||
"gps_signal": 0,
|
||||
"signal_level": 0,
|
||||
"vibration": 0.0,
|
||||
"lamp": 0,
|
||||
"accels": [],
|
||||
"gyros": [],
|
||||
"mags": [],
|
||||
"rpms": [],
|
||||
}
|
||||
|
||||
def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
|
||||
"""Send control command to SDK.
|
||||
|
||||
@@ -45,8 +45,9 @@ from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import write_stats
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
|
||||
from lerobot.datasets.io_utils import write_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,19 @@ from huggingface_hub import HfApi, snapshot_download
|
||||
from requests import HTTPError
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
|
||||
from lerobot.datasets.io_utils import (
|
||||
cast_stats_to_numpy,
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
load_info,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -70,17 +82,8 @@ from lerobot.datasets.utils import (
|
||||
LEGACY_EPISODES_PATH,
|
||||
LEGACY_EPISODES_STATS_PATH,
|
||||
LEGACY_TASKS_PATH,
|
||||
cast_stats_to_numpy,
|
||||
flatten_dict,
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
load_info,
|
||||
update_chunk_file_indices,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
@@ -83,10 +83,10 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
|
||||
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
@@ -341,8 +341,8 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
def action_features(self) -> dict:
|
||||
"""Return action format for rover (linear and angular velocities)."""
|
||||
return {
|
||||
"linear.vel": float,
|
||||
"angular.vel": float,
|
||||
"linear_velocity": float,
|
||||
"angular_velocity": float,
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -366,7 +366,7 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
Get the current action based on pressed keys.
|
||||
|
||||
Returns:
|
||||
RobotAction with 'linear.vel' and 'angular.vel' keys
|
||||
RobotAction with 'linear_velocity' and 'angular_velocity' keys.
|
||||
"""
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
@@ -427,6 +427,6 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
return {
|
||||
"linear.vel": linear_velocity,
|
||||
"angular.vel": angular_velocity,
|
||||
"linear_velocity": linear_velocity,
|
||||
"angular_velocity": angular_velocity,
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.utils import load_json, write_json
|
||||
from lerobot.datasets.io_utils import load_json, write_json
|
||||
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
|
||||
|
||||
def test_dataset_config_valid():
|
||||
DatasetConfig(repo_id="user/repo", episodes=[0, 1, 2])
|
||||
|
||||
|
||||
def test_dataset_config_negative_episodes():
|
||||
with pytest.raises(ValueError, match="non-negative"):
|
||||
DatasetConfig(repo_id="user/repo", episodes=[0, -1, 2])
|
||||
|
||||
|
||||
def test_dataset_config_duplicate_episodes():
|
||||
with pytest.raises(ValueError, match="duplicates"):
|
||||
DatasetConfig(repo_id="user/repo", episodes=[0, 1, 1, 2])
|
||||
|
||||
|
||||
def test_dataset_config_none_episodes_ok():
|
||||
DatasetConfig(repo_id="user/repo", episodes=None)
|
||||
|
||||
|
||||
def test_dataset_config_empty_episodes_ok():
|
||||
DatasetConfig(repo_id="user/repo", episodes=[])
|
||||
@@ -260,8 +260,8 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
|
||||
# Mock the revision to prevent Hub calls during dataset loading
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "test_aggr")
|
||||
@@ -311,8 +311,8 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
|
||||
# Mock the revision to prevent Hub calls during dataset loading
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "small_aggr")
|
||||
@@ -367,8 +367,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "regression_aggr")
|
||||
@@ -492,8 +492,8 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
|
||||
# Load the aggregated dataset
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
|
||||
@@ -562,8 +562,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "ds_ab")
|
||||
@@ -590,8 +590,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "ds_abc")
|
||||
|
||||
@@ -67,8 +67,8 @@ def test_delete_single_episode(sample_dataset, tmp_path):
|
||||
output_dir = tmp_path / "filtered"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
@@ -93,8 +93,8 @@ def test_delete_multiple_episodes(sample_dataset, tmp_path):
|
||||
output_dir = tmp_path / "filtered"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
@@ -150,8 +150,8 @@ def test_split_by_episodes(sample_dataset, tmp_path):
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
@@ -193,8 +193,8 @@ def test_split_by_fractions(sample_dataset, tmp_path):
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
@@ -270,8 +270,8 @@ def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fact
|
||||
dataset2.finalize()
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
|
||||
@@ -310,8 +310,8 @@ def test_add_features_with_values(sample_dataset, tmp_path):
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
|
||||
@@ -346,8 +346,8 @@ def test_add_features_with_callable(sample_dataset, tmp_path):
|
||||
"reward": (compute_reward, feature_info),
|
||||
}
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
|
||||
@@ -401,8 +401,8 @@ def test_modify_features_add_and_remove(sample_dataset, tmp_path):
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "modified")
|
||||
@@ -434,8 +434,8 @@ def test_modify_features_only_add(sample_dataset, tmp_path):
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "modified")
|
||||
@@ -457,8 +457,8 @@ def test_modify_features_only_remove(sample_dataset, tmp_path):
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
@@ -494,8 +494,8 @@ def test_remove_single_feature(sample_dataset, tmp_path):
|
||||
"reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
|
||||
}
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
@@ -521,8 +521,8 @@ def test_remove_single_feature(sample_dataset, tmp_path):
|
||||
def test_remove_multiple_features(sample_dataset, tmp_path):
|
||||
"""Test removing multiple features at once."""
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
@@ -576,8 +576,8 @@ def test_remove_camera_feature(sample_dataset, tmp_path):
|
||||
camera_to_remove = camera_keys[0]
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "without_camera")
|
||||
@@ -598,8 +598,8 @@ def test_remove_camera_feature(sample_dataset, tmp_path):
|
||||
def test_complex_workflow_integration(sample_dataset, tmp_path):
|
||||
"""Test a complex workflow combining multiple operations."""
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
@@ -647,8 +647,8 @@ def test_delete_episodes_preserves_stats(sample_dataset, tmp_path):
|
||||
output_dir = tmp_path / "filtered"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
@@ -671,8 +671,8 @@ def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path):
|
||||
output_dir = tmp_path / "filtered"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
@@ -699,8 +699,8 @@ def test_split_three_ways(sample_dataset, tmp_path):
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
@@ -732,8 +732,8 @@ def test_split_preserves_stats(sample_dataset, tmp_path):
|
||||
splits = {"train": [0, 1, 2], "val": [3, 4]}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
@@ -790,8 +790,8 @@ def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fa
|
||||
datasets.append(dataset)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
|
||||
@@ -832,8 +832,8 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f
|
||||
dataset2.finalize()
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
|
||||
@@ -866,8 +866,8 @@ def test_add_features_preserves_existing_stats(sample_dataset, tmp_path):
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
|
||||
@@ -890,8 +890,8 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path):
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
@@ -919,8 +919,8 @@ def test_delete_consecutive_episodes(sample_dataset, tmp_path):
|
||||
output_dir = tmp_path / "filtered"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
@@ -943,8 +943,8 @@ def test_delete_first_and_last_episodes(sample_dataset, tmp_path):
|
||||
output_dir = tmp_path / "filtered"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
@@ -971,8 +971,8 @@ def test_split_all_episodes_assigned(sample_dataset, tmp_path):
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
@@ -999,8 +999,8 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
@@ -1020,7 +1020,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
|
||||
# Get original chunk/file indices from first episode
|
||||
if train_dataset.meta.episodes is None:
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.datasets.io_utils import load_episodes
|
||||
|
||||
train_dataset.meta.episodes = load_episodes(train_dataset.meta.root)
|
||||
original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes]
|
||||
@@ -1040,7 +1040,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
|
||||
# Check that chunk/file indices are preserved
|
||||
if modified_dataset.meta.episodes is None:
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.datasets.io_utils import load_episodes
|
||||
|
||||
modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root)
|
||||
new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes]
|
||||
@@ -1194,7 +1194,7 @@ def test_modify_tasks_in_place(sample_dataset):
|
||||
|
||||
def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset):
|
||||
"""Test that original tasks are kept when using episode_tasks without new_task."""
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.datasets.io_utils import load_episodes
|
||||
|
||||
# Ensure episodes metadata is loaded
|
||||
if sample_dataset.meta.episodes is None:
|
||||
@@ -1229,8 +1229,8 @@ def test_convert_image_to_video_dataset(tmp_path):
|
||||
output_dir = tmp_path / "pusht_video"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
@@ -1292,8 +1292,8 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
|
||||
output_dir = tmp_path / "pusht_video_subset"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
@@ -19,7 +19,9 @@ import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
|
||||
|
||||
@@ -29,20 +29,19 @@ import lerobot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
_encode_video_worker,
|
||||
)
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
create_branch,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
hw_to_dataset_features,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
|
||||
from lerobot.envs.factory import make_env_config
|
||||
@@ -1329,7 +1328,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.datasets.io_utils import load_episodes
|
||||
|
||||
dataset.meta.episodes = load_episodes(dataset.root)
|
||||
assert dataset.meta.episodes is not None
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
from lerobot.datasets.feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
)
|
||||
|
||||
@@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory):
|
||||
def test_write_image_exception(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
with patch("builtins.print") as mock_print:
|
||||
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
|
||||
write_image(image_array, fpath)
|
||||
mock_print.assert_called()
|
||||
mock_logger.error.assert_called()
|
||||
assert not fpath.exists()
|
||||
|
||||
|
||||
@@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with patch("builtins.print") as mock_print:
|
||||
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
|
||||
writer.save_image(image_array, fpath)
|
||||
writer.wait_until_done()
|
||||
mock_print.assert_called()
|
||||
mock_logger.error.assert_called()
|
||||
assert not fpath.exists()
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
@@ -13,13 +13,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import (
|
||||
from lerobot.datasets.io_utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
||||
@@ -106,3 +109,28 @@ def test_shuffle():
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
|
||||
def test_negative_drop_last_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
|
||||
|
||||
|
||||
def test_all_episodes_dropped_raises():
|
||||
# All episodes have 1 frame, drop_n_first_frames=1 removes all
|
||||
with pytest.raises(ValueError, match="No valid frames remain"):
|
||||
EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
|
||||
|
||||
|
||||
def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept)
|
||||
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
|
||||
sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1)
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
Vendored
+6
-5
@@ -26,7 +26,10 @@ import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -35,8 +38,6 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
flatten_dict,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames
|
||||
from tests.fixtures.constants import (
|
||||
@@ -453,8 +454,8 @@ def lerobot_dataset_metadata_factory(
|
||||
episodes=episodes,
|
||||
)
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
Vendored
+7
-5
@@ -20,17 +20,19 @@ import pandas as pd
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_hf_dataset_size_in_mb,
|
||||
update_chunk_file_indices,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
|
||||
|
||||
def write_hf_dataset(
|
||||
|
||||
@@ -28,7 +28,8 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
@@ -41,6 +42,8 @@ from lerobot.policies.factory import (
|
||||
make_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
@@ -459,3 +462,45 @@ def test_act_temporal_ensembler():
|
||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_vqbet_discretize_keeps_buffers_on_device():
|
||||
"""Regression test: VQBeTHead.discretize() must not move registered buffers off the model device.
|
||||
|
||||
Previously, `self.vqvae_model.discretized = torch.tensor(True)` replaced the
|
||||
registered buffer with a new CPU tensor, causing DDP to crash with:
|
||||
RuntimeError: No backend type associated with device type cpu
|
||||
The fix uses `.fill_(True)` to update in-place, preserving device placement.
|
||||
"""
|
||||
config = VQBeTConfig()
|
||||
config.input_features = {
|
||||
OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
|
||||
}
|
||||
# Tiny sizes for fast CPU/GPU execution.
|
||||
config.n_vqvae_training_steps = 3
|
||||
config.vqvae_n_embed = 8
|
||||
config.vqvae_embedding_dim = 32
|
||||
config.vqvae_enc_hidden_dim = 32
|
||||
config.action_chunk_size = 2
|
||||
config.crop_shape = (84, 84)
|
||||
|
||||
head = VQBeTHead(config).to(DEVICE)
|
||||
vqvae = head.vqvae_model
|
||||
|
||||
dummy_actions = torch.randn(4, config.action_chunk_size, config.action_feature.shape[0], device=DEVICE)
|
||||
n_steps = config.n_vqvae_training_steps
|
||||
for _ in range(n_steps):
|
||||
head.discretize(n_steps, dummy_actions)
|
||||
|
||||
assert vqvae.discretized.device.type == torch.device(DEVICE).type, (
|
||||
"vqvae_model.discretized was moved off the model device after discretize(). "
|
||||
"Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
|
||||
)
|
||||
assert vqvae.vq_layer.freeze_codebook.device.type == torch.device(DEVICE).type, (
|
||||
"vq_layer.freeze_codebook was moved off the model device after discretize(). "
|
||||
"Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
|
||||
)
|
||||
|
||||
@@ -71,8 +71,8 @@ def test_record_and_resume(tmp_path):
|
||||
cfg.resume = True
|
||||
# Mock the revision to prevent Hub calls during resume
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "record")
|
||||
@@ -115,8 +115,8 @@ def test_record_and_replay(tmp_path):
|
||||
|
||||
# Mock the revision to prevent Hub calls during replay
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "record_and_replay")
|
||||
|
||||
Reference in New Issue
Block a user