Compare commits

..

11 Commits

Author SHA1 Message Date
Khalil Meftah d734b0afa3 docs: reframe HIL guide around general intervention workflow 2026-02-26 11:21:21 +01:00
Khalil Meftah 2aac055fee docs: make hil data collection tutorial general-purpose 2026-02-25 17:58:04 +01:00
Steven Palma 5c6a97cc4b Merge branch 'main' into feat/RaC 2026-02-24 13:25:56 +01:00
Steven Palma c0b9621d8a Merge branch 'main' into feat/RaC 2026-02-23 15:23:24 +01:00
Steven Palma d3fb69e85b chore(examples): remove pedal 2026-02-23 11:37:29 +01:00
Steven Palma e011fcbad4 chore(style): fix pre-commit 2026-02-23 11:31:57 +01:00
Steven Palma fee1934af4 Merge branch 'main' into feat/RaC
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-23 11:17:37 +01:00
Pepijn 1bddfd2cc3 Merge branch 'main' into feat/RaC 2026-01-26 17:27:46 +01:00
Pepijn 5ac8fe3243 Rename to hil, use two seperta scripts one rtc one synch 2026-01-23 12:47:54 +01:00
Pepijn 467981eaef Add changes from openarms experiments 2026-01-21 16:39:53 +01:00
Pepijn 27eeff7535 Add RaC doc and example 2025-12-30 09:57:40 +01:00
27 changed files with 1562 additions and 887 deletions
-2
View File
@@ -173,8 +173,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Fix ptxas permissions
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
-1
View File
@@ -1,3 +1,2 @@
include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json
-2
View File
@@ -85,8 +85,6 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
RUN uv pip install --no-cache ".[all]"
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
# Copy the rest of the application source code
# Make sure to have the git-LFS files for testing
COPY --chown=user_lerobot:user_lerobot . .
+2
View File
@@ -17,6 +17,8 @@
title: Train RL in Simulation
- local: multi_gpu_training
title: Multi GPU training
- local: hil_collection
title: Human In the Loop Data Collection
- local: peft_training
title: Training with PEFT (e.g., LoRA)
title: "Tutorials"
+237
View File
@@ -0,0 +1,237 @@
# Human-In-the-Loop Data Collection
Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data — recovery movements and corrections — is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.
---
## Why Human-In-the-Loop?
Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:
- Running the trained policy on the real robot
- Having a human intervene when the robot is about to fail
- Recording the human's recovery and correction as training data
- Fine-tuning the policy on the combined dataset
This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.
---
## How It Works
During a HIL session, the human operator follows this loop within each episode:
1. **Watch** the policy run autonomously
2. **Pause** when failure is imminent — the robot holds its position
3. **Take control** — teleoperate the robot back to a good state (recovery), then correct the behavior
4. **Return control to the policy** — the policy resumes autonomous execution
5. Repeat steps 24 as many times as needed during the episode
6. **End the episode** when the task is complete, save and move on to the next rollout
Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset required just because intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.
This process can be repeated iteratively: deploy, collect, fine-tune, repeat — each round targeting the current policy's failure modes.
```
┌─────────────────────────────────────────────────────────────────────────┐
│ Policy v0 (trained on demos) │
│ ↓ │
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1 │
│ ↓ │
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2 │
│ ↓ │
│ ... (repeat until satisfactory performance) │
└─────────────────────────────────────────────────────────────────────────┘
```
---
## Hardware Requirements
### Teleoperator Requirements
The HIL data collection scripts require **teleoperators with active motors** that can:
- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators:**
- `so101_leader` - SO-101 Leader Arm
- `openarms_mini` - OpenArms Mini (via third-party plugin)
---
## Scripts
Two scripts are provided depending on your policy's inference speed:
| Script | Use Case | Models |
| ---------------------------- | ------------------------------------------ | --------------------- |
| `hil_data_collection.py` | Standard synchronous inference | ACT, Diffusion Policy |
| `hil_data_collection_rtc.py` | Real-Time Chunking for high-latency models | Pi0, Pi0.5, SmolVLA |
---
## Step-by-Step Guide
### Step 1: Pre-train a Base Policy
First, train a policy on your demonstration dataset:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/demo-dataset \
--policy.type=pi0 \
--output_dir=outputs/pretrain \
--batch_size=32 \
--steps=50000
```
### Step 2: Collect HIL Data
**Standard inference (ACT, Diffusion Policy):**
```bash
python examples/rac/hil_data_collection.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-dataset \
--dataset.single_task="Pick up the cube and place it in the bowl" \
--dataset.num_episodes=50
```
**With RTC for large models (Pi0, Pi0.5, SmolVLA):**
For models with high inference latency, use the RTC script for smooth execution:
```bash
python examples/rac/hil_data_collection_rtc.py \
--robot.type=so100_follower \
--teleop.type=so100_leader \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-rtc-dataset \
--dataset.single_task="Pick up the cube" \
--rtc.execution_horizon=20 \
--interpolation=true
```
**Controls (Conceptual):**
The interaction model is:
- **Pause input**: pause autonomous policy execution
- **Takeover input**: transfer control to the human operator and record intervention data
- **Return-to-policy input**: hand control back to the policy and continue the same episode
- **Episode control inputs**: save/re-record/stop/reset as needed
Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.
**The HIL Protocol:**
1. Watch the policy run autonomously (teleop is idle/free)
2. When you see imminent failure, trigger the **pause input**
- Policy stops
- Teleoperator moves to match robot position (torque enabled)
- No frames recorded during pause
3. Trigger the **takeover input** to take control
- Teleoperator torque disabled, free to move
- **Recovery**: Teleoperate the robot back to a good state
- **Correction**: Correct the behavior
- All movements are recorded
4. Trigger the **return-to-policy input**
- Policy resumes autonomous execution from the current state
- You can intervene again at any time (repeat steps 24)
5. End and save the episode when the task is complete (or episode time limit is reached)
6. **Reset**: Teleop moves to robot position, you can move the robot to the starting position
7. Start the next episode
**Foot Pedal Setup (Linux):**
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
```bash
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
```
### Step 3: Fine-tune the Policy
Fine-tune on the combined demonstration + HIL data:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/hil-dataset \
--policy.type=pi0 \
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
--output_dir=outputs/hil_finetune \
--steps=20000
```
Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.
---
## Tips for Effective HIL Collection
### When to Intervene
Intervene when you see:
- Robot about to make an irreversible mistake
- Robot hesitating or showing uncertain behavior
- Robot deviating from the expected trajectory
### Recovery: Teleoperating Back to a Good State
During recovery, teleoperate the robot back to a state where:
- The robot is in a familiar, in-distribution configuration
- The current subtask can still be completed
- The recovery trajectory itself is informative training data
### Quality of Corrections
During correction:
- Provide **confident, clean** trajectories
- Complete the current subtask fully
- Don't overcorrect or add unnecessary movements
---
## Related Work
This HIL data collection approach builds on ideas from interactive imitation learning, including DAgger (Ross et al., 2011), HG-DAgger (Kelly et al., 2019), RaC (Hu et al., 2025), and RECAP (Physical Intelligence, 2025). See those works for a deeper treatment of the theory behind human-in-the-loop policy improvement.
```bibtex
@article{ross2011dagger,
title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
year={2011}
}
@article{kelly2019hgdagger,
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
journal={arXiv preprint arXiv:1810.02890},
year={2019}
}
@article{hu2025rac,
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
journal={arXiv preprint arXiv:2509.07953},
year={2025}
}
@article{pi2025recap,
title={π0.6: a VLA That Learns From Experience},
author={Physical Intelligence},
year={2025}
}
```
+351
View File
@@ -0,0 +1,351 @@
#!/usr/bin/env python
"""
Human-in-the-Loop (HIL) Data Collection with Policy Rollout.
Implements the RaC paradigm (Hu et al., 2025) for LeRobot with standard synchronous
inference. For large models with high inference latency, use hil_data_collection_rtc.py.
The workflow:
1. Policy runs autonomously
2. Press SPACE to pause - robot holds position
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Keyboard Controls:
SPACE - Pause policy (robot holds position, no recording)
c - Take control (start correction, recording resumes)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage:
python examples/rac/hil_data_collection.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/hil_dataset \
--dataset.single_task="Pick up the cube"
"""
import logging
import time
from dataclasses import dataclass
from pprint import pformat
from typing import Any
import torch
from hil_utils import (
HILDatasetConfig,
init_keyboard_listener,
make_identity_processors,
print_controls,
reset_loop,
teleop_disable_torque,
teleop_smooth_move_to,
)
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless, predict_action
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
logger = logging.getLogger(__name__)
@dataclass
class HILConfig:
robot: RobotConfig
teleop: TeleoperatorConfig
dataset: HILDatasetConfig
policy: PreTrainedConfig | None = None
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
display_data: bool = True
play_sounds: bool = True
resume: bool = False
device: str = "cuda"
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("policy.path is required")
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
@safe_stop_image_writer
def rollout_loop(
robot: Robot,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
dataset: LeRobotDataset,
events: dict,
cfg: HILConfig,
):
"""Rollout loop with standard synchronous inference."""
fps = cfg.dataset.fps
device = get_safe_torch_device(cfg.device)
policy.reset()
preprocessor.reset()
postprocessor.reset()
frame_buffer = []
teleop_disable_torque(teleop)
was_paused = False
waiting_for_takeover = False
last_action: dict[str, Any] | None = None
robot_action: dict[str, Any] = {}
action_keys = sorted(robot.action_features.keys())
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(fps)
timestamp = 0
start_t = time.perf_counter()
while timestamp < cfg.dataset.episode_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# Transition to paused state
if events["policy_paused"] and not was_paused:
obs = robot.get_observation()
robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
interpolator.reset()
# Takeover
if waiting_for_takeover and events["start_next_episode"]:
teleop_disable_torque(teleop)
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
if events["correction_active"]:
robot_action = teleop.get_action()
robot.send_action(robot_action)
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
elif waiting_for_takeover or events["policy_paused"]:
if last_action:
robot.send_action(last_action)
else:
# Policy execution with optional interpolation
if interpolator.needs_new_action():
action_values = predict_action(
observation=obs_frame,
policy=policy,
device=device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=cfg.dataset.single_task,
robot_type=robot.robot_type,
)
robot_action = make_robot_action(action_values, dataset.features)
action_tensor = torch.tensor([robot_action[k] for k in action_keys])
interpolator.add(action_tensor)
interp_action = interpolator.get()
if interp_action is not None:
robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
robot.send_action(robot_action)
last_action = robot_action
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
if cfg.display_data and robot_action:
log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start
if (sleep_time := control_interval - dt) > 0:
precise_sleep(sleep_time)
timestamp = time.perf_counter() - start_t
teleop_disable_torque(teleop)
for frame in frame_buffer:
dataset.add_frame(frame)
@parser.wrap()
def hil_collect(cfg: HILConfig) -> LeRobotDataset:
"""Main HIL data collection function."""
init_logging()
logger.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="hil_collection")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot, "cameras") and robot.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
teleop.connect()
listener, events = init_keyboard_listener()
print_controls(rtc=False)
print(f" Policy: {cfg.policy.pretrained_path}")
print(f" Task: {cfg.dataset.single_task}")
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
cfg=cfg,
)
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(robot, teleop, events, cfg.dataset.fps)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
if dataset:
dataset.finalize()
if robot.is_connected:
robot.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
hil_collect()
if __name__ == "__main__":
main()
+513
View File
@@ -0,0 +1,513 @@
#!/usr/bin/env python
"""
Human-in-the-Loop (HIL) Data Collection with Real-Time Chunking (RTC).
Implements the RaC paradigm (Hu et al., 2025) with RTC for large flow-matching models
(Pi0, Pi0.5, SmolVLA) that have high inference latency. RTC generates action chunks
asynchronously in a background thread for smooth robot control.
For fast models (ACT, Diffusion), use hil_data_collection.py instead.
The workflow:
1. Policy runs autonomously with RTC
2. Press SPACE to pause - robot holds position
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Keyboard Controls:
SPACE - Pause policy (robot holds position, no recording)
c - Take control (start correction, recording resumes)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage:
python examples/rac/hil_data_collection_rtc.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/train/pi0_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/hil_rtc_dataset \
--dataset.single_task="Pick up the cube" \
--rtc.execution_horizon=20
"""
import logging
import math
import time
from dataclasses import dataclass, field
from pprint import pformat
from threading import Event, Lock, Thread
from typing import Any
import torch
from hil_utils import (
HILDatasetConfig,
init_keyboard_listener,
make_identity_processors,
print_controls,
reset_loop,
teleop_disable_torque,
teleop_smooth_move_to,
)
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
logger = logging.getLogger(__name__)
@dataclass
class HILRTCConfig:
robot: RobotConfig
teleop: TeleoperatorConfig
dataset: HILDatasetConfig
policy: PreTrainedConfig | None = None
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(enabled=True, execution_horizon=20))
interpolation_multiplier: int = 2 # Control rate multiplier (1=off, 2=2x, 3=3x)
display_data: bool = True
play_sounds: bool = True
resume: bool = False
device: str = "cuda"
use_torch_compile: bool = False # First compile takes minutes, disable for real-time
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("policy.path is required")
self.rtc.enabled = True
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
class ThreadSafeRobot:
"""Thread-safe wrapper for robot operations."""
def __init__(self, robot: Robot):
self._robot = robot
self._lock = Lock()
def get_observation(self) -> dict[str, Any]:
with self._lock:
return self._robot.get_observation()
def send_action(self, action: dict) -> None:
with self._lock:
self._robot.send_action(action)
@property
def observation_features(self) -> dict:
return self._robot.observation_features
@property
def action_features(self) -> dict:
return self._robot.action_features
@property
def name(self) -> str:
return self._robot.name
@property
def robot_type(self) -> str:
return self._robot.robot_type
@property
def cameras(self):
return getattr(self._robot, "cameras", {})
def rtc_inference_thread(
policy: PreTrainedPolicy,
obs_holder: dict,
hw_features: dict,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
queue_holder: dict,
shutdown_event: Event,
policy_active: Event,
cfg: HILRTCConfig,
):
"""Background thread for RTC action chunk generation."""
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / cfg.dataset.fps
threshold = 30
while not shutdown_event.is_set():
if not policy_active.is_set():
time.sleep(0.01)
continue
queue = queue_holder.get("queue")
obs = obs_holder.get("obs")
if queue is None or obs is None:
time.sleep(0.01)
continue
if queue.qsize() <= threshold:
try:
current_time = time.perf_counter()
idx_before = queue.get_action_index()
prev_actions = queue.get_left_over()
latency = latency_tracker.max()
delay = math.ceil(latency / time_per_chunk) if latency else 0
obs_batch = build_dataset_frame(hw_features, obs, prefix="observation")
for name in obs_batch:
obs_batch[name] = torch.from_numpy(obs_batch[name])
if "image" in name:
obs_batch[name] = obs_batch[name].float() / 255
obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
obs_batch[name] = obs_batch[name].unsqueeze(0).to(cfg.device)
obs_batch["task"] = [cfg.dataset.single_task]
obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown")
preprocessed = preprocessor(obs_batch)
actions = policy.predict_action_chunk(
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
)
original = actions.squeeze(0).clone()
processed = postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
queue.merge(original, processed, new_delay, idx_before)
logger.debug(f"[RTC] Inference latency={new_latency:.2f}s, queue={queue.qsize()}")
except Exception as e:
logger.error(f"[RTC] Error: {e}")
time.sleep(0.5)
else:
time.sleep(0.01)
@safe_stop_image_writer
def rollout_loop(
robot: ThreadSafeRobot,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
dataset: LeRobotDataset,
events: dict,
cfg: HILRTCConfig,
queue_holder: dict,
obs_holder: dict,
policy_active: Event,
hw_features: dict,
):
"""Rollout loop with RTC for asynchronous inference."""
fps = cfg.dataset.fps
policy.reset()
preprocessor.reset()
postprocessor.reset()
frame_buffer = []
teleop_disable_torque(teleop)
was_paused = False
waiting_for_takeover = False
last_action: dict[str, Any] | None = None
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(fps)
robot_action: dict[str, Any] = {}
timestamp = 0
start_t = time.perf_counter()
while timestamp < cfg.dataset.episode_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# Transition to paused state
if events["policy_paused"] and not was_paused:
policy_active.clear()
obs = robot.get_observation()
robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
interpolator.reset()
# Takeover
if waiting_for_takeover and events["start_next_episode"]:
teleop_disable_torque(teleop)
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
obs_holder["obs"] = obs_filtered
if events["correction_active"]:
robot_action = teleop.get_action()
robot.send_action(robot_action)
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
elif waiting_for_takeover or events["policy_paused"]:
if last_action:
robot.send_action(last_action)
else:
# Policy execution with RTC
if not policy_active.is_set():
policy_active.set()
queue = queue_holder["queue"]
if interpolator.needs_new_action():
new_action = queue.get() if queue else None
if new_action is not None:
interpolator.add(new_action.cpu())
action_tensor = interpolator.get()
if action_tensor is not None:
robot_action = {
k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)
}
robot.send_action(robot_action)
last_action = robot_action
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
if cfg.display_data and robot_action:
log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start
if (sleep_time := control_interval - dt) > 0:
precise_sleep(sleep_time)
timestamp = time.perf_counter() - start_t
policy_active.clear()
teleop_disable_torque(teleop)
for frame in frame_buffer:
dataset.add_frame(frame)
@parser.wrap()
def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
"""Main HIL data collection function with RTC."""
init_logging()
logger.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="hil_rtc_collection")
robot_raw = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot_raw.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot_raw.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
shutdown_event = Event()
policy_active = Event()
rtc_thread = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot_raw.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
# Load policy with RTC
policy_class = get_policy_class(cfg.policy.type)
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if hasattr(policy_config, "compile_model"):
policy_config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
policy.config.rtc_config = cfg.rtc
if hasattr(policy, "init_rtc_processor"):
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot_raw.connect()
robot = ThreadSafeRobot(robot_raw)
teleop.connect()
listener, events = init_keyboard_listener()
queue_holder = {"queue": ActionQueue(cfg.rtc)}
obs_holder = {"obs": None, "robot_type": robot.robot_type}
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
rtc_thread = Thread(
target=rtc_inference_thread,
args=(
policy,
obs_holder,
hw_features,
preprocessor,
postprocessor,
queue_holder,
shutdown_event,
policy_active,
cfg,
),
daemon=True,
)
rtc_thread.start()
print_controls(rtc=True)
print(f" Policy: {cfg.policy.pretrained_path}")
print(f" Task: {cfg.dataset.single_task}")
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
queue_holder["queue"] = ActionQueue(cfg.rtc)
rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
cfg=cfg,
queue_holder=queue_holder,
obs_holder=obs_holder,
policy_active=policy_active,
hw_features=hw_features,
)
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(robot, teleop, events, cfg.dataset.fps)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
shutdown_event.set()
policy_active.clear()
if rtc_thread and rtc_thread.is_alive():
rtc_thread.join(timeout=2.0)
if dataset:
dataset.finalize()
if robot_raw.is_connected:
robot_raw.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
hil_rtc_collect()
if __name__ == "__main__":
main()
+206
View File
@@ -0,0 +1,206 @@
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return hasattr(teleop, "bus") and hasattr(teleop.bus, "disable_torque")
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if teleop_has_motor_control(teleop):
teleop.bus.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if teleop_has_motor_control(teleop):
teleop.bus.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = int(duration_s * fps)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.bus.sync_write("Goal_Position", {k.replace(".pos", ""): v for k, v in interp.items()})
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
print("\n[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
print("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
print("\n[HIL] ⏸ PAUSED - Press 'c' to take control")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
print("\n[HIL] ▶ Taking control...")
events["start_next_episode"] = True
elif key == keyboard.Key.right:
print("[HIL] → End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("[HIL] ← Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
print("\n" + "=" * 60)
print(" [HIL] RESET")
print("=" * 60)
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
print(" Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
print(" Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
print("\n" + "=" * 60)
print(" Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else ""))
print("=" * 60)
print()
print(" Controls:")
print(" SPACE - Pause policy")
print(" c - Take control")
print(" → - End episode")
print(" ESC - Stop and push to hub")
print("=" * 60 + "\n")
+9 -8
View File
@@ -83,9 +83,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
@@ -151,6 +149,7 @@ class RTCDemoConfig(HubMixin):
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
@@ -351,20 +350,22 @@ def actor_control(
logger.info("[ACTOR] Starting actor thread")
action_count = 0
action_interval = 1.0 / cfg.fps
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
while not shutdown_event.is_set():
start_time = time.perf_counter()
# Try to get an action from the queue with timeout
action = action_queue.get()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
-3
View File
@@ -214,9 +214,6 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
[tool.setuptools.packages.find]
where = ["src"]
-7
View File
@@ -7,13 +7,6 @@
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
{% if repo_id is defined and repo_id %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}
## Dataset Description
{{ dataset_description | default("", true) }}
+18 -317
View File
@@ -47,7 +47,6 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
get_parquet_file_size_in_mb,
load_episodes,
load_info,
update_chunk_file_indices,
write_info,
write_stats,
@@ -568,22 +567,20 @@ def _copy_and_reindex_data(
def _keep_episodes_from_video_with_av(
input_path: Path,
output_path: Path,
episodes_to_keep: list[tuple[int, int]],
episodes_to_keep: list[tuple[float, float]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> None:
"""Keep only specified episodes from a video file using PyAV.
This function decodes frames from specified frame ranges and re-encodes them with
This function decodes frames from specified time ranges and re-encodes them with
properly reset timestamps to ensure monotonic progression.
Args:
input_path: Source video file path.
output_path: Destination video file path.
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
@@ -625,10 +622,9 @@ def _keep_episodes_from_video_with_av(
# Create set of (start, end) ranges for fast lookup.
# Convert to a sorted list for efficient checking.
frame_ranges = sorted(episodes_to_keep)
time_ranges = sorted(episodes_to_keep)
# Track frame index for setting PTS and current range being processed.
src_frame_count = 0
frame_count = 0
range_idx = 0
@@ -638,20 +634,21 @@ def _keep_episodes_from_video_with_av(
if frame is None:
continue
# Check if frame is in any of our desired frame ranges.
# Get frame timestamp.
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
# Check if frame is in any of our desired time ranges.
# Skip ranges that have already passed.
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
range_idx += 1
# If we've passed all ranges, stop processing.
if range_idx >= len(frame_ranges):
if range_idx >= len(time_ranges):
break
# Check if frame is in current range.
start_frame = frame_ranges[range_idx][0]
if src_frame_count < start_frame:
src_frame_count += 1
start_ts, end_ts = time_ranges[range_idx]
if frame_time < start_ts:
continue
# Frame is in range - create a new frame with reset timestamps.
@@ -664,7 +661,6 @@ def _keep_episodes_from_video_with_av(
for pkt in v_out.encode(new_frame):
out.mux(pkt)
src_frame_count += 1
frame_count += 1
# Flush encoder.
@@ -753,17 +749,15 @@ def _copy_and_reindex_videos(
f"videos/{video_key}/to_timestamp"
]
else:
# Build list of frame ranges to keep, in sorted order.
# Build list of time ranges to keep, in sorted order.
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
episodes_to_keep_ranges: list[tuple[int, int]] = []
episodes_to_keep_ranges: list[tuple[float, float]] = []
for old_idx in sorted_keep_episodes:
src_ep = src_dataset.meta.episodes[old_idx]
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
assert src_ep["length"] == to_frame - from_frame, (
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
)
episodes_to_keep_ranges.append((from_frame, to_frame))
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
episodes_to_keep_ranges.append((from_ts, to_ts))
# Use PyAV filters to efficiently re-encode only the desired segments.
assert src_dataset.meta.video_path is not None
@@ -1775,296 +1769,3 @@ def convert_image_to_video_dataset(
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def trim_episodes_by_frames(
dataset: LeRobotDataset,
episode_frames_to_keep: dict[int, list[int]],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim multiple episodes to keep only specific frames.
This function creates a new dataset where the specified episodes contain only
the frames at the given indices. All other episodes are copied as-is.
Args:
dataset: The source LeRobotDataset.
episode_frames_to_keep: Dict mapping episode indices to lists of global frame indices to keep.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
Returns:
A new LeRobotDataset with the trimmed episodes.
"""
if not episode_frames_to_keep:
raise ValueError("No episodes to trim")
for ep_idx in episode_frames_to_keep:
if ep_idx >= dataset.meta.total_episodes:
raise ValueError(f"Episode {ep_idx} does not exist")
if not episode_frames_to_keep[ep_idx]:
raise ValueError(f"No frames to keep for episode {ep_idx}")
if repo_id is None:
repo_id = f"{dataset.repo_id}_trimmed"
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
total_trimmed = sum(len(frames) for frames in episode_frames_to_keep.values())
logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
# Create new metadata
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=dataset.meta.features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
)
# Build set of all frames to keep (for episodes being trimmed)
# and compute new frame counts per episode
all_keep_frames: set[int] = set()
trimmed_frame_counts: dict[int, int] = {}
for ep_idx, frames in episode_frames_to_keep.items():
all_keep_frames.update(frames)
trimmed_frame_counts[ep_idx] = len(frames)
# Copy and filter data
_copy_and_reindex_data_with_multi_frame_filter(
dataset, new_meta, episode_frames_to_keep, all_keep_frames
)
# Handle videos if present
if dataset.meta.video_keys:
_copy_and_reindex_videos_with_multi_frame_filter(
dataset, new_meta, episode_frames_to_keep
)
# Copy episode metadata
_copy_and_reindex_episodes_metadata_for_multi_trim(
dataset, new_meta, trimmed_frame_counts
)
logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
# Return the metadata instead of trying to load as LeRobotDataset
# This avoids Hub validation issues when the repo doesn't exist yet
return new_meta
# Keep old function for backward compatibility
def trim_episode_by_frames(
dataset: LeRobotDataset,
episode_index: int,
keep_frame_indices: list[int],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim a single episode. Wrapper around trim_episodes_by_frames."""
return trim_episodes_by_frames(
dataset,
episode_frames_to_keep={episode_index: keep_frame_indices},
output_dir=output_dir,
repo_id=repo_id,
)
def _copy_and_reindex_data_with_multi_frame_filter(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_frames_to_keep: dict[int, list[int]],
all_keep_frames: set[int],
) -> None:
"""Copy data files with frame-level filtering for multiple episodes."""
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
# Copy tasks
if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
# Tasks are stored with task string as index
dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
# Get all parquet files
data_dir = src_dataset.root / "data"
parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
trim_episode_set = set(episode_frames_to_keep.keys())
global_index = 0
for parquet_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(parquet_path)
# Filter: keep all frames from non-trimmed episodes,
# and only specified frames from trimmed episodes
mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
df = df[mask].copy().reset_index(drop=True)
if len(df) == 0:
continue
# Reindex
df["index"] = range(global_index, global_index + len(df))
# Recalculate frame_index within each episode
for ep_idx in df["episode_index"].unique():
ep_mask = df["episode_index"] == ep_idx
df.loc[ep_mask, "frame_index"] = range(ep_mask.sum())
# Recalculate timestamps based on frame_index and fps
df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
# Determine output path (keep same structure)
rel_path = parquet_path.relative_to(src_dataset.root)
dst_path = dst_meta.root / rel_path
dst_path.parent.mkdir(parents=True, exist_ok=True)
_write_parquet(df, dst_path, dst_meta)
global_index += len(df)
def _copy_and_reindex_videos_with_multi_frame_filter(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_frames_to_keep: dict[int, list[int]],
) -> None:
"""Copy video files for trimmed dataset.
In v3.0 datasets, multiple episodes are concatenated into single video files.
Each episode has from_timestamp/to_timestamp indicating its portion of the video.
For trimming, we copy the original video files as-is and update the metadata
timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
"""
for video_key in src_dataset.meta.video_keys:
video_dir = src_dataset.root / "videos" / video_key
dst_video_dir = dst_meta.root / "videos" / video_key
if not video_dir.exists():
logging.warning(f"Video directory not found: {video_dir}")
continue
# Copy all video files (they contain concatenated episodes)
# The metadata timestamps will handle which portions to use
copied_files = set()
for chunk_dir in video_dir.glob("chunk-*"):
dst_chunk_dir = dst_video_dir / chunk_dir.name
dst_chunk_dir.mkdir(parents=True, exist_ok=True)
for video_file in chunk_dir.glob("*.mp4"):
if video_file.name not in copied_files:
dst_path = dst_chunk_dir / video_file.name
if not dst_path.exists():
shutil.copy(video_file, dst_path)
copied_files.add(video_file.name)
logging.info(f"Copied {len(copied_files)} video files for {video_key}")
def _trim_video_frames(
src_path: Path,
dst_path: Path,
keep_frame_indices: list[int],
fps: float,
episode_start_idx: int,
) -> None:
"""Trim a video to keep only specific frames using ffmpeg."""
import subprocess
# Convert global indices to local indices within the episode
local_indices = sorted([idx - episode_start_idx for idx in keep_frame_indices])
if not local_indices:
logging.warning(f"No frames to keep for video {src_path}")
return
# Calculate start and end times
start_frame = local_indices[0]
end_frame = local_indices[-1]
start_time = start_frame / fps
duration = (end_frame - start_frame + 1) / fps
# Use ffmpeg to trim
cmd = [
"ffmpeg", "-y",
"-ss", str(start_time),
"-i", str(src_path),
"-t", str(duration),
"-c", "copy", # Fast copy without re-encoding
str(dst_path)
]
try:
subprocess.run(cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
logging.error(f"Failed to trim video: {e.stderr.decode()}")
# Fallback: copy the whole video
shutil.copy(src_path, dst_path)
def _copy_and_reindex_episodes_metadata_for_multi_trim(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
trimmed_frame_counts: dict[int, int],
) -> None:
"""Copy and update episode metadata for trimmed dataset."""
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
# Calculate new frame counts and indices
episodes_data = []
global_idx = 0
for old_ep_idx in range(src_dataset.meta.total_episodes):
src_ep = src_dataset.meta.episodes[old_ep_idx]
if old_ep_idx in trimmed_frame_counts:
ep_length = trimmed_frame_counts[old_ep_idx]
else:
ep_length = src_ep["length"]
ep_data = {
"episode_index": old_ep_idx,
"tasks": src_ep["tasks"],
"length": ep_length,
"data/chunk_index": src_ep["data/chunk_index"],
"data/file_index": src_ep["data/file_index"],
"dataset_from_index": global_idx,
"dataset_to_index": global_idx + ep_length,
}
# Copy video metadata - preserve timestamps for concatenated videos
for video_key in src_dataset.meta.video_keys:
ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
# Keep original from_timestamp (start position in concatenated video)
orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
# For trimmed episodes, update to_timestamp based on new length
# For non-trimmed episodes, keep original to_timestamp
if old_ep_idx in trimmed_frame_counts:
ep_data[f"videos/{video_key}/to_timestamp"] = orig_from_ts + (ep_length / src_dataset.meta.fps)
else:
ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
ep_data["meta/episodes/chunk_index"] = 0
ep_data["meta/episodes/file_index"] = 0
episodes_data.append(ep_data)
global_idx += ep_length
# Save episodes metadata
df = pd.DataFrame(episodes_data)
episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
episodes_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(episodes_path)
# Update info.json
info = load_info(src_dataset.root)
info["total_episodes"] = len(episodes_data)
info["total_frames"] = global_idx
write_info(info, dst_meta.root)
+2 -2
View File
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Check if cached dataset contains all requested episodes
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (FileNotFoundError, NotADirectoryError):
except (AssertionError, FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hub_api.upload_folder(**upload_kwargs)
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
+20 -26
View File
@@ -227,17 +227,16 @@ def decode_video_frames_torchvision(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -249,11 +248,7 @@ def decode_video_frames_torchvision(
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
if len(timestamps) != len(closest_frames):
raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})"
)
assert len(timestamps) == len(closest_frames)
return closest_frames
@@ -358,16 +353,15 @@ def decode_video_frames_torchcodec(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -139,10 +139,6 @@ class DiffusionConfig(PreTrainedConfig):
# Inference
num_inference_steps: int | None = None
# Optimization
compile_model: bool = False
compile_mode: str = "reduce-overhead"
# Loss computation
do_mask_loss_for_padding: bool = False
@@ -142,9 +142,6 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
for key in self.config.image_features:
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
@@ -185,11 +182,6 @@ class DiffusionModel(nn.Module):
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.compile_model:
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
# common in diffusion inference.
self.unet = torch.compile(self.unet, mode=config.compile_mode)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig
from .openarm_mini import OpenArmMini
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
__all__ = [
"ActionInterpolator",
"ActionQueue",
"LatencyTracker",
"RTCConfig",
"RTCProcessor",
]
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
self._buffer = [action]
self._prev = action
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)
+3 -1
View File
@@ -277,7 +277,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
if self.dataset_meta is not None:
episodes_df = self.dataset_meta.episodes.to_pandas()
episodes_df = None
if self.sparse_subtask_names != ["task"]:
episodes_df = self.dataset_meta.episodes.to_pandas()
# Generate sparse targets
if self.sparse_temporal_proportions is not None:
-133
View File
@@ -104,28 +104,6 @@ Convert image dataset to video format and push to hub:
--operation.type convert_image_to_video \
--push_to_hub true
Trim single episode to keep only frames within timestamp range:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_trimmed \
--operation.type trim_episode \
--operation.episode_index 0 \
--operation.start_timestamp 10.0 \
--operation.end_timestamp 30.0
Trim multiple episodes at once (use null for no limit):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type trim_episode \
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
Trim and re-upload to same repo (overwrites original):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type trim_episode \
--operation.episode_index 0 \
--operation.start_timestamp 10.0 \
--push_to_hub true
Show dataset information:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
@@ -226,32 +204,9 @@ class InfoConfig(OperationConfig):
show_features: bool = False
@dataclass
class TrimEpisodeConfig:
"""Trim episodes to keep only frames within timestamp ranges.
Supports multiple episodes via episode_trims dict:
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
Or single episode via legacy parameters:
--operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
"""
type: str = "trim_episode"
# Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
# Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
episode_trims: dict[str, list[float | None]] | None = None
# Legacy single-episode parameters (used if episode_trims is None)
episode_index: int | None = None
start_timestamp: float | None = None # Keep frames from this timestamp (inclusive)
end_timestamp: float | None = None # Keep frames until this timestamp (inclusive)
@dataclass
class EditDatasetConfig:
repo_id: str
operation: (
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig | TrimEpisodeConfig
)
operation: OperationConfig
root: str | None = None
new_repo_id: str | None = None
@@ -396,92 +351,6 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_trim_episode(cfg: EditDatasetConfig) -> None:
"""Trim episodes to keep only frames within timestamp ranges."""
if not isinstance(cfg.operation, TrimEpisodeConfig):
raise ValueError("Operation config must be TrimEpisodeConfig")
# Parse episode trims - support both multi-episode dict and legacy single episode
episode_trims: dict[int, tuple[float | None, float | None]] = {}
if cfg.operation.episode_trims is not None:
# Multi-episode mode
for ep_str, ts_range in cfg.operation.episode_trims.items():
ep_idx = int(ep_str)
start_ts = ts_range[0] if len(ts_range) > 0 else None
end_ts = ts_range[1] if len(ts_range) > 1 else None
episode_trims[ep_idx] = (start_ts, end_ts)
elif cfg.operation.episode_index is not None:
# Legacy single-episode mode
if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
episode_trims[cfg.operation.episode_index] = (
cfg.operation.start_timestamp,
cfg.operation.end_timestamp,
)
else:
raise ValueError("Either episode_trims or episode_index must be specified")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
)
if cfg.new_repo_id is None:
dataset.root = Path(str(dataset.root) + "_old")
logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
# Get episode boundaries and find frames to keep for each episode
episodes_info = dataset.meta.episodes
all_frames_to_keep: dict[int, list[int]] = {}
for ep_idx, (start_ts, end_ts) in episode_trims.items():
if ep_idx >= len(episodes_info["episode_index"]):
raise ValueError(f"Episode {ep_idx} does not exist (dataset has {len(episodes_info['episode_index'])} episodes)")
from_frame = episodes_info["dataset_from_index"][ep_idx]
to_frame = episodes_info["dataset_to_index"][ep_idx]
logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
logging.info(f" Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
# Find frames within timestamp range
frames_to_keep = []
for frame_idx in range(from_frame, to_frame):
frame = dataset.hf_dataset[frame_idx]
ts = frame["timestamp"]
in_range = True
if start_ts is not None and ts < start_ts:
in_range = False
if end_ts is not None and ts > end_ts:
in_range = False
if in_range:
frames_to_keep.append(frame_idx)
if not frames_to_keep:
raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
logging.info(f" Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
all_frames_to_keep[ep_idx] = frames_to_keep
from lerobot.datasets.dataset_tools import trim_episodes_by_frames
new_dataset = trim_episodes_by_frames(
dataset,
episode_frames_to_keep=all_frames_to_keep,
output_dir=output_dir,
repo_id=output_repo_id,
)
logging.info(f"Dataset saved to {output_dir}")
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}")
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, ModifyTasksConfig):
raise ValueError("Operation config must be ModifyTasksConfig")
@@ -646,8 +515,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "trim_episode":
handle_trim_episode(cfg)
elif operation_type == "info":
handle_info(cfg)
else:
+69 -20
View File
@@ -74,6 +74,8 @@ from pathlib import Path
from pprint import pformat
from typing import Any
import torch
from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
)
@@ -90,6 +92,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
@@ -225,6 +228,9 @@ class RecordConfig:
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
# Only applies when using a policy (not teleop)
interpolation_multiplier: int = 1
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -297,6 +303,7 @@ def record_loop(
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
interpolator: ActionInterpolator | None = None,
display_compressed_images: bool = False,
):
if dataset is not None and dataset.fps != fps:
@@ -333,6 +340,14 @@ def record_loop(
preprocessor.reset()
postprocessor.reset()
# Reset interpolator if provided
if interpolator is not None:
interpolator.reset()
# Calculate control interval based on interpolation
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
@@ -353,24 +368,58 @@ def record_loop(
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
# With interpolation: only call policy when interpolator needs new action
if use_interpolation:
# Get action keys from robot
action_keys = sorted(robot.action_features)
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
if interpolator.needs_new_action():
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
# Convert to tensor for interpolator
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
interpolator.add(action_tensor)
# Get interpolated action
interp_action = interpolator.get()
if interp_action is not None:
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
action_values = robot_action_to_send
else:
# No action available yet, skip this iteration
continue
else:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
elif policy is None and isinstance(teleop, list):
arm_action = teleop_arm.get_action()
@@ -379,6 +428,8 @@ def record_loop(
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
else:
logging.info(
"No policy or teleoperator provided, skipping action generation."
@@ -387,14 +438,6 @@ def record_loop(
)
continue
# Applies a pipeline to the action, default is IdentityProcessor
if policy is not None and act_processed_policy is not None:
action_values = act_processed_policy
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
else:
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
@@ -414,7 +457,7 @@ def record_loop(
dt_s = time.perf_counter() - start_loop_t
sleep_time_s: float = 1 / fps - dt_s
sleep_time_s: float = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
@@ -501,6 +544,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor = None
postprocessor = None
interpolator = None
if cfg.policy is not None:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
@@ -511,6 +555,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
# Create interpolator for smoother policy control
if cfg.interpolation_multiplier > 1:
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
robot.connect()
if teleop is not None:
@@ -542,6 +590,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
interpolator=interpolator,
display_compressed_images=display_compressed_images,
)
@@ -43,7 +43,6 @@ from lerobot.teleoperators import ( # noqa: F401
koch_leader,
make_teleoperator_from_config,
omx_leader,
openarm_mini,
so_leader,
)
@@ -52,7 +51,6 @@ COMPATIBLE_DEVICES = [
"koch_leader",
"omx_follower",
"omx_leader",
"openarm_mini",
"so100_follower",
"so100_leader",
"so101_follower",
-15
View File
@@ -24,7 +24,6 @@ import torch
from accelerate import Accelerator
from termcolor import colored
from torch.optim import Optimizer
from tqdm import tqdm
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
@@ -52,7 +51,6 @@ from lerobot.utils.utils import (
format_big_number,
has_method,
init_logging,
inside_slurm,
)
@@ -392,14 +390,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
if is_main_process:
progbar = tqdm(
total=cfg.steps - step,
desc="Training",
unit="step",
disable=inside_slurm(),
position=0,
leave=True,
)
logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
)
@@ -424,8 +414,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
if is_main_process:
progbar.update(1)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
@@ -519,9 +507,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
accelerator.wait_for_everyone()
if is_main_process:
progbar.close()
if eval_env:
close_envs(eval_env)
@@ -1,30 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig):
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
port_right: str = "/dev/ttyUSB0"
port_left: str = "/dev/ttyUSB1"
use_degrees: bool = True
@@ -1,296 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
from lerobot.processor import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..teleoperator import Teleoperator
from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
class OpenArmMini(Teleoperator):
"""
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
"""
config_class = OpenArmMiniConfig
name = "openarm_mini"
def __init__(self, config: OpenArmMiniConfig):
super().__init__(config)
self.config = config
norm_mode_body = MotorNormMode.DEGREES
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
motors_left = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
)
@property
def action_features(self) -> dict[str, type]:
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.bus_right.is_connected and self.bus_left.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate:
self.calibrate()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArm Mini.
1. Disable torque
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions)
5. Save calibration
"""
if self.calibration:
user_input = input(
f"Press ENTER to use existing calibration for {self.id}, "
f"or type 'c' and press ENTER to run new calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
cal_right = {
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return
logger.info(f"\nRunning calibration for {self}")
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
"""Calibrate a single arm with Feetech motors."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input(
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
homing_offsets = bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
if self.calibration is None:
self.calibration = {}
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
max_res = motor_resolution - 1
for motor_name, motor in bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper":
input(
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
f"Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..."
)
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos:
range_min = int(closed_pos)
range_max = int(open_pos)
drive_mode = 0
else:
range_min = int(open_pos)
range_max = int(closed_pos)
drive_mode = 1
logger.info(
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})"
)
else:
range_min = 0
range_max = max_res
drive_mode = 0
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[prefixed_name] = MotorCalibration(
id=motor.id,
drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name],
range_min=range_min,
range_max=range_max,
)
cal_for_bus = {
k.replace(f"{arm_name}_", ""): v
for k, v in self.calibration.items()
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None:
self.bus_right.disable_torque()
self.bus_right.configure_motors()
for motor in self.bus_right.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None:
print("\nSetting up RIGHT arm motors...")
for motor in reversed(self.bus_right.motors):
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
self.bus_right.setup_motor(motor)
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> RobotAction:
"""Get current action from both arms (read positions from all motors)."""
start = time.perf_counter()
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
action: dict[str, Any] = {}
for motor, val in right_positions.items():
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
@check_if_not_connected
def disconnect(self) -> None:
self.bus_right.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.")
-4
View File
@@ -95,10 +95,6 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .bi_openarm_leader import BiOpenArmLeader
return BiOpenArmLeader(config)
elif config.type == "openarm_mini":
from .openarm_mini import OpenArmMini
return OpenArmMini(config)
else:
try:
return cast("Teleoperator", make_device_from_device_class(config))
+1 -1
View File
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
# Check if dataset_name starts with "eval_" but policy is missing
if dataset_name.startswith("eval_") and policy_cfg is None:
raise ValueError(
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
)
# Check if dataset_name does not start with "eval_" but policy is provided