Compare commits

...

11 Commits

Author SHA1 Message Date
Khalil Meftah d734b0afa3 docs: reframe HIL guide around general intervention workflow 2026-02-26 11:21:21 +01:00
Khalil Meftah 2aac055fee docs: make hil data collection tutorial general-purpose 2026-02-25 17:58:04 +01:00
Steven Palma 5c6a97cc4b Merge branch 'main' into feat/RaC 2026-02-24 13:25:56 +01:00
Steven Palma c0b9621d8a Merge branch 'main' into feat/RaC 2026-02-23 15:23:24 +01:00
Steven Palma d3fb69e85b chore(examples): remove pedal 2026-02-23 11:37:29 +01:00
Steven Palma e011fcbad4 chore(style): fix pre-commit 2026-02-23 11:31:57 +01:00
Steven Palma fee1934af4 Merge branch 'main' into feat/RaC
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-23 11:17:37 +01:00
Pepijn 1bddfd2cc3 Merge branch 'main' into feat/RaC 2026-01-26 17:27:46 +01:00
Pepijn 5ac8fe3243 Rename to hil, use two seperta scripts one rtc one synch 2026-01-23 12:47:54 +01:00
Pepijn 467981eaef Add changes from openarms experiments 2026-01-21 16:39:53 +01:00
Pepijn 27eeff7535 Add RaC doc and example 2025-12-30 09:57:40 +01:00
9 changed files with 1533 additions and 28 deletions
+2
View File
@@ -17,6 +17,8 @@
title: Train RL in Simulation
- local: multi_gpu_training
title: Multi GPU training
- local: hil_collection
title: Human In the Loop Data Collection
- local: peft_training
title: Training with PEFT (e.g., LoRA)
title: "Tutorials"
+237
View File
@@ -0,0 +1,237 @@
# Human-In-the-Loop Data Collection
Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data — recovery movements and corrections — is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.
---
## Why Human-In-the-Loop?
Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:
- Running the trained policy on the real robot
- Having a human intervene when the robot is about to fail
- Recording the human's recovery and correction as training data
- Fine-tuning the policy on the combined dataset
This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.
---
## How It Works
During a HIL session, the human operator follows this loop within each episode:
1. **Watch** the policy run autonomously
2. **Pause** when failure is imminent — the robot holds its position
3. **Take control** — teleoperate the robot back to a good state (recovery), then correct the behavior
4. **Return control to the policy** — the policy resumes autonomous execution
5. Repeat steps 24 as many times as needed during the episode
6. **End the episode** when the task is complete, save and move on to the next rollout
Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset required just because intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.
This process can be repeated iteratively: deploy, collect, fine-tune, repeat — each round targeting the current policy's failure modes.
```
┌─────────────────────────────────────────────────────────────────────────┐
│ Policy v0 (trained on demos) │
│ ↓ │
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1 │
│ ↓ │
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2 │
│ ↓ │
│ ... (repeat until satisfactory performance) │
└─────────────────────────────────────────────────────────────────────────┘
```
---
## Hardware Requirements
### Teleoperator Requirements
The HIL data collection scripts require **teleoperators with active motors** that can:
- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators:**
- `so101_leader` - SO-101 Leader Arm
- `openarms_mini` - OpenArms Mini (via third-party plugin)
---
## Scripts
Two scripts are provided depending on your policy's inference speed:
| Script | Use Case | Models |
| ---------------------------- | ------------------------------------------ | --------------------- |
| `hil_data_collection.py` | Standard synchronous inference | ACT, Diffusion Policy |
| `hil_data_collection_rtc.py` | Real-Time Chunking for high-latency models | Pi0, Pi0.5, SmolVLA |
---
## Step-by-Step Guide
### Step 1: Pre-train a Base Policy
First, train a policy on your demonstration dataset:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/demo-dataset \
--policy.type=pi0 \
--output_dir=outputs/pretrain \
--batch_size=32 \
--steps=50000
```
### Step 2: Collect HIL Data
**Standard inference (ACT, Diffusion Policy):**
```bash
python examples/rac/hil_data_collection.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-dataset \
--dataset.single_task="Pick up the cube and place it in the bowl" \
--dataset.num_episodes=50
```
**With RTC for large models (Pi0, Pi0.5, SmolVLA):**
For models with high inference latency, use the RTC script for smooth execution:
```bash
python examples/rac/hil_data_collection_rtc.py \
--robot.type=so100_follower \
--teleop.type=so100_leader \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-rtc-dataset \
--dataset.single_task="Pick up the cube" \
--rtc.execution_horizon=20 \
--interpolation=true
```
**Controls (Conceptual):**
The interaction model is:
- **Pause input**: pause autonomous policy execution
- **Takeover input**: transfer control to the human operator and record intervention data
- **Return-to-policy input**: hand control back to the policy and continue the same episode
- **Episode control inputs**: save/re-record/stop/reset as needed
Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.
**The HIL Protocol:**
1. Watch the policy run autonomously (teleop is idle/free)
2. When you see imminent failure, trigger the **pause input**
- Policy stops
- Teleoperator moves to match robot position (torque enabled)
- No frames recorded during pause
3. Trigger the **takeover input** to take control
- Teleoperator torque disabled, free to move
- **Recovery**: Teleoperate the robot back to a good state
- **Correction**: Correct the behavior
- All movements are recorded
4. Trigger the **return-to-policy input**
- Policy resumes autonomous execution from the current state
- You can intervene again at any time (repeat steps 24)
5. End and save the episode when the task is complete (or episode time limit is reached)
6. **Reset**: Teleop moves to robot position, you can move the robot to the starting position
7. Start the next episode
**Foot Pedal Setup (Linux):**
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
```bash
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
```
### Step 3: Fine-tune the Policy
Fine-tune on the combined demonstration + HIL data:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/hil-dataset \
--policy.type=pi0 \
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
--output_dir=outputs/hil_finetune \
--steps=20000
```
Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.
---
## Tips for Effective HIL Collection
### When to Intervene
Intervene when you see:
- Robot about to make an irreversible mistake
- Robot hesitating or showing uncertain behavior
- Robot deviating from the expected trajectory
### Recovery: Teleoperating Back to a Good State
During recovery, teleoperate the robot back to a state where:
- The robot is in a familiar, in-distribution configuration
- The current subtask can still be completed
- The recovery trajectory itself is informative training data
### Quality of Corrections
During correction:
- Provide **confident, clean** trajectories
- Complete the current subtask fully
- Don't overcorrect or add unnecessary movements
---
## Related Work
This HIL data collection approach builds on ideas from interactive imitation learning, including DAgger (Ross et al., 2011), HG-DAgger (Kelly et al., 2019), RaC (Hu et al., 2025), and RECAP (Physical Intelligence, 2025). See those works for a deeper treatment of the theory behind human-in-the-loop policy improvement.
```bibtex
@article{ross2011dagger,
title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
year={2011}
}
@article{kelly2019hgdagger,
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
journal={arXiv preprint arXiv:1810.02890},
year={2019}
}
@article{hu2025rac,
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
journal={arXiv preprint arXiv:2509.07953},
year={2025}
}
@article{pi2025recap,
title={π0.6: a VLA That Learns From Experience},
author={Physical Intelligence},
year={2025}
}
```
+351
View File
@@ -0,0 +1,351 @@
#!/usr/bin/env python
"""
Human-in-the-Loop (HIL) Data Collection with Policy Rollout.
Implements the RaC paradigm (Hu et al., 2025) for LeRobot with standard synchronous
inference. For large models with high inference latency, use hil_data_collection_rtc.py.
The workflow:
1. Policy runs autonomously
2. Press SPACE to pause - robot holds position
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Keyboard Controls:
SPACE - Pause policy (robot holds position, no recording)
c - Take control (start correction, recording resumes)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage:
python examples/rac/hil_data_collection.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/hil_dataset \
--dataset.single_task="Pick up the cube"
"""
import logging
import time
from dataclasses import dataclass
from pprint import pformat
from typing import Any
import torch
from hil_utils import (
HILDatasetConfig,
init_keyboard_listener,
make_identity_processors,
print_controls,
reset_loop,
teleop_disable_torque,
teleop_smooth_move_to,
)
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless, predict_action
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
logger = logging.getLogger(__name__)
@dataclass
class HILConfig:
robot: RobotConfig
teleop: TeleoperatorConfig
dataset: HILDatasetConfig
policy: PreTrainedConfig | None = None
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
display_data: bool = True
play_sounds: bool = True
resume: bool = False
device: str = "cuda"
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("policy.path is required")
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
@safe_stop_image_writer
def rollout_loop(
robot: Robot,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
dataset: LeRobotDataset,
events: dict,
cfg: HILConfig,
):
"""Rollout loop with standard synchronous inference."""
fps = cfg.dataset.fps
device = get_safe_torch_device(cfg.device)
policy.reset()
preprocessor.reset()
postprocessor.reset()
frame_buffer = []
teleop_disable_torque(teleop)
was_paused = False
waiting_for_takeover = False
last_action: dict[str, Any] | None = None
robot_action: dict[str, Any] = {}
action_keys = sorted(robot.action_features.keys())
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(fps)
timestamp = 0
start_t = time.perf_counter()
while timestamp < cfg.dataset.episode_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# Transition to paused state
if events["policy_paused"] and not was_paused:
obs = robot.get_observation()
robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
interpolator.reset()
# Takeover
if waiting_for_takeover and events["start_next_episode"]:
teleop_disable_torque(teleop)
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
if events["correction_active"]:
robot_action = teleop.get_action()
robot.send_action(robot_action)
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
elif waiting_for_takeover or events["policy_paused"]:
if last_action:
robot.send_action(last_action)
else:
# Policy execution with optional interpolation
if interpolator.needs_new_action():
action_values = predict_action(
observation=obs_frame,
policy=policy,
device=device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=cfg.dataset.single_task,
robot_type=robot.robot_type,
)
robot_action = make_robot_action(action_values, dataset.features)
action_tensor = torch.tensor([robot_action[k] for k in action_keys])
interpolator.add(action_tensor)
interp_action = interpolator.get()
if interp_action is not None:
robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
robot.send_action(robot_action)
last_action = robot_action
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
if cfg.display_data and robot_action:
log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start
if (sleep_time := control_interval - dt) > 0:
precise_sleep(sleep_time)
timestamp = time.perf_counter() - start_t
teleop_disable_torque(teleop)
for frame in frame_buffer:
dataset.add_frame(frame)
@parser.wrap()
def hil_collect(cfg: HILConfig) -> LeRobotDataset:
"""Main HIL data collection function."""
init_logging()
logger.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="hil_collection")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot, "cameras") and robot.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
teleop.connect()
listener, events = init_keyboard_listener()
print_controls(rtc=False)
print(f" Policy: {cfg.policy.pretrained_path}")
print(f" Task: {cfg.dataset.single_task}")
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
cfg=cfg,
)
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(robot, teleop, events, cfg.dataset.fps)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
if dataset:
dataset.finalize()
if robot.is_connected:
robot.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
hil_collect()
if __name__ == "__main__":
main()
+513
View File
@@ -0,0 +1,513 @@
#!/usr/bin/env python
"""
Human-in-the-Loop (HIL) Data Collection with Real-Time Chunking (RTC).
Implements the RaC paradigm (Hu et al., 2025) with RTC for large flow-matching models
(Pi0, Pi0.5, SmolVLA) that have high inference latency. RTC generates action chunks
asynchronously in a background thread for smooth robot control.
For fast models (ACT, Diffusion), use hil_data_collection.py instead.
The workflow:
1. Policy runs autonomously with RTC
2. Press SPACE to pause - robot holds position
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Keyboard Controls:
SPACE - Pause policy (robot holds position, no recording)
c - Take control (start correction, recording resumes)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage:
python examples/rac/hil_data_collection_rtc.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/train/pi0_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/hil_rtc_dataset \
--dataset.single_task="Pick up the cube" \
--rtc.execution_horizon=20
"""
import logging
import math
import time
from dataclasses import dataclass, field
from pprint import pformat
from threading import Event, Lock, Thread
from typing import Any
import torch
from hil_utils import (
HILDatasetConfig,
init_keyboard_listener,
make_identity_processors,
print_controls,
reset_loop,
teleop_disable_torque,
teleop_smooth_move_to,
)
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
logger = logging.getLogger(__name__)
@dataclass
class HILRTCConfig:
robot: RobotConfig
teleop: TeleoperatorConfig
dataset: HILDatasetConfig
policy: PreTrainedConfig | None = None
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(enabled=True, execution_horizon=20))
interpolation_multiplier: int = 2 # Control rate multiplier (1=off, 2=2x, 3=3x)
display_data: bool = True
play_sounds: bool = True
resume: bool = False
device: str = "cuda"
use_torch_compile: bool = False # First compile takes minutes, disable for real-time
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("policy.path is required")
self.rtc.enabled = True
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
class ThreadSafeRobot:
"""Thread-safe wrapper for robot operations."""
def __init__(self, robot: Robot):
self._robot = robot
self._lock = Lock()
def get_observation(self) -> dict[str, Any]:
with self._lock:
return self._robot.get_observation()
def send_action(self, action: dict) -> None:
with self._lock:
self._robot.send_action(action)
@property
def observation_features(self) -> dict:
return self._robot.observation_features
@property
def action_features(self) -> dict:
return self._robot.action_features
@property
def name(self) -> str:
return self._robot.name
@property
def robot_type(self) -> str:
return self._robot.robot_type
@property
def cameras(self):
return getattr(self._robot, "cameras", {})
def rtc_inference_thread(
policy: PreTrainedPolicy,
obs_holder: dict,
hw_features: dict,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
queue_holder: dict,
shutdown_event: Event,
policy_active: Event,
cfg: HILRTCConfig,
):
"""Background thread for RTC action chunk generation."""
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / cfg.dataset.fps
threshold = 30
while not shutdown_event.is_set():
if not policy_active.is_set():
time.sleep(0.01)
continue
queue = queue_holder.get("queue")
obs = obs_holder.get("obs")
if queue is None or obs is None:
time.sleep(0.01)
continue
if queue.qsize() <= threshold:
try:
current_time = time.perf_counter()
idx_before = queue.get_action_index()
prev_actions = queue.get_left_over()
latency = latency_tracker.max()
delay = math.ceil(latency / time_per_chunk) if latency else 0
obs_batch = build_dataset_frame(hw_features, obs, prefix="observation")
for name in obs_batch:
obs_batch[name] = torch.from_numpy(obs_batch[name])
if "image" in name:
obs_batch[name] = obs_batch[name].float() / 255
obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
obs_batch[name] = obs_batch[name].unsqueeze(0).to(cfg.device)
obs_batch["task"] = [cfg.dataset.single_task]
obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown")
preprocessed = preprocessor(obs_batch)
actions = policy.predict_action_chunk(
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
)
original = actions.squeeze(0).clone()
processed = postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
queue.merge(original, processed, new_delay, idx_before)
logger.debug(f"[RTC] Inference latency={new_latency:.2f}s, queue={queue.qsize()}")
except Exception as e:
logger.error(f"[RTC] Error: {e}")
time.sleep(0.5)
else:
time.sleep(0.01)
@safe_stop_image_writer
def rollout_loop(
robot: ThreadSafeRobot,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
dataset: LeRobotDataset,
events: dict,
cfg: HILRTCConfig,
queue_holder: dict,
obs_holder: dict,
policy_active: Event,
hw_features: dict,
):
"""Rollout loop with RTC for asynchronous inference."""
fps = cfg.dataset.fps
policy.reset()
preprocessor.reset()
postprocessor.reset()
frame_buffer = []
teleop_disable_torque(teleop)
was_paused = False
waiting_for_takeover = False
last_action: dict[str, Any] | None = None
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(fps)
robot_action: dict[str, Any] = {}
timestamp = 0
start_t = time.perf_counter()
while timestamp < cfg.dataset.episode_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# Transition to paused state
if events["policy_paused"] and not was_paused:
policy_active.clear()
obs = robot.get_observation()
robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
interpolator.reset()
# Takeover
if waiting_for_takeover and events["start_next_episode"]:
teleop_disable_torque(teleop)
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
obs_holder["obs"] = obs_filtered
if events["correction_active"]:
robot_action = teleop.get_action()
robot.send_action(robot_action)
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
elif waiting_for_takeover or events["policy_paused"]:
if last_action:
robot.send_action(last_action)
else:
# Policy execution with RTC
if not policy_active.is_set():
policy_active.set()
queue = queue_holder["queue"]
if interpolator.needs_new_action():
new_action = queue.get() if queue else None
if new_action is not None:
interpolator.add(new_action.cpu())
action_tensor = interpolator.get()
if action_tensor is not None:
robot_action = {
k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)
}
robot.send_action(robot_action)
last_action = robot_action
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
if cfg.display_data and robot_action:
log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start
if (sleep_time := control_interval - dt) > 0:
precise_sleep(sleep_time)
timestamp = time.perf_counter() - start_t
policy_active.clear()
teleop_disable_torque(teleop)
for frame in frame_buffer:
dataset.add_frame(frame)
@parser.wrap()
def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
"""Main HIL data collection function with RTC."""
init_logging()
logger.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="hil_rtc_collection")
robot_raw = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot_raw.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot_raw.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
shutdown_event = Event()
policy_active = Event()
rtc_thread = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot_raw.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
# Load policy with RTC
policy_class = get_policy_class(cfg.policy.type)
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if hasattr(policy_config, "compile_model"):
policy_config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
policy.config.rtc_config = cfg.rtc
if hasattr(policy, "init_rtc_processor"):
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot_raw.connect()
robot = ThreadSafeRobot(robot_raw)
teleop.connect()
listener, events = init_keyboard_listener()
queue_holder = {"queue": ActionQueue(cfg.rtc)}
obs_holder = {"obs": None, "robot_type": robot.robot_type}
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
rtc_thread = Thread(
target=rtc_inference_thread,
args=(
policy,
obs_holder,
hw_features,
preprocessor,
postprocessor,
queue_holder,
shutdown_event,
policy_active,
cfg,
),
daemon=True,
)
rtc_thread.start()
print_controls(rtc=True)
print(f" Policy: {cfg.policy.pretrained_path}")
print(f" Task: {cfg.dataset.single_task}")
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
queue_holder["queue"] = ActionQueue(cfg.rtc)
rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
cfg=cfg,
queue_holder=queue_holder,
obs_holder=obs_holder,
policy_active=policy_active,
hw_features=hw_features,
)
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(robot, teleop, events, cfg.dataset.fps)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
shutdown_event.set()
policy_active.clear()
if rtc_thread and rtc_thread.is_alive():
rtc_thread.join(timeout=2.0)
if dataset:
dataset.finalize()
if robot_raw.is_connected:
robot_raw.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
hil_rtc_collect()
if __name__ == "__main__":
main()
+206
View File
@@ -0,0 +1,206 @@
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return hasattr(teleop, "bus") and hasattr(teleop.bus, "disable_torque")
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if teleop_has_motor_control(teleop):
teleop.bus.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if teleop_has_motor_control(teleop):
teleop.bus.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = int(duration_s * fps)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.bus.sync_write("Goal_Position", {k.replace(".pos", ""): v for k, v in interp.items()})
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
print("\n[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
print("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
print("\n[HIL] ⏸ PAUSED - Press 'c' to take control")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
print("\n[HIL] ▶ Taking control...")
events["start_next_episode"] = True
elif key == keyboard.Key.right:
print("[HIL] → End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("[HIL] ← Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
print("\n" + "=" * 60)
print(" [HIL] RESET")
print("=" * 60)
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
print(" Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
print(" Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
print("\n" + "=" * 60)
print(" Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else ""))
print("=" * 60)
print()
print(" Controls:")
print(" SPACE - Pause policy")
print(" c - Take control")
print(" → - End episode")
print(" ESC - Stop and push to hub")
print("=" * 60 + "\n")
+9 -8
View File
@@ -83,9 +83,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
@@ -151,6 +149,7 @@ class RTCDemoConfig(HubMixin):
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
@@ -351,20 +350,22 @@ def actor_control(
logger.info("[ACTOR] Starting actor thread")
action_count = 0
action_interval = 1.0 / cfg.fps
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
while not shutdown_event.is_set():
start_time = time.perf_counter()
# Try to get an action from the queue with timeout
action = action_queue.get()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
+29
View File
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
__all__ = [
"ActionInterpolator",
"ActionQueue",
"LatencyTracker",
"RTCConfig",
"RTCProcessor",
]
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
self._buffer = [action]
self._prev = action
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)
+69 -20
View File
@@ -74,6 +74,8 @@ from pathlib import Path
from pprint import pformat
from typing import Any
import torch
from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
)
@@ -90,6 +92,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
@@ -225,6 +228,9 @@ class RecordConfig:
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
# Only applies when using a policy (not teleop)
interpolation_multiplier: int = 1
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -297,6 +303,7 @@ def record_loop(
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
interpolator: ActionInterpolator | None = None,
display_compressed_images: bool = False,
):
if dataset is not None and dataset.fps != fps:
@@ -333,6 +340,14 @@ def record_loop(
preprocessor.reset()
postprocessor.reset()
# Reset interpolator if provided
if interpolator is not None:
interpolator.reset()
# Calculate control interval based on interpolation
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
@@ -353,24 +368,58 @@ def record_loop(
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
# With interpolation: only call policy when interpolator needs new action
if use_interpolation:
# Get action keys from robot
action_keys = sorted(robot.action_features)
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
if interpolator.needs_new_action():
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
# Convert to tensor for interpolator
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
interpolator.add(action_tensor)
# Get interpolated action
interp_action = interpolator.get()
if interp_action is not None:
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
action_values = robot_action_to_send
else:
# No action available yet, skip this iteration
continue
else:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
elif policy is None and isinstance(teleop, list):
arm_action = teleop_arm.get_action()
@@ -379,6 +428,8 @@ def record_loop(
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
else:
logging.info(
"No policy or teleoperator provided, skipping action generation."
@@ -387,14 +438,6 @@ def record_loop(
)
continue
# Applies a pipeline to the action, default is IdentityProcessor
if policy is not None and act_processed_policy is not None:
action_values = act_processed_policy
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
else:
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
@@ -414,7 +457,7 @@ def record_loop(
dt_s = time.perf_counter() - start_loop_t
sleep_time_s: float = 1 / fps - dt_s
sleep_time_s: float = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
@@ -501,6 +544,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor = None
postprocessor = None
interpolator = None
if cfg.policy is not None:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
@@ -511,6 +555,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
# Create interpolator for smoother policy control
if cfg.interpolation_multiplier > 1:
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
robot.connect()
if teleop is not None:
@@ -542,6 +590,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
interpolator=interpolator,
display_compressed_images=display_compressed_images,
)