mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d734b0afa3 | |||
| 2aac055fee | |||
| 5c6a97cc4b | |||
| c0b9621d8a | |||
| d3fb69e85b | |||
| e011fcbad4 | |||
| fee1934af4 | |||
| 1bddfd2cc3 | |||
| 5ac8fe3243 | |||
| 467981eaef | |||
| 27eeff7535 |
@@ -173,8 +173,6 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -85,8 +85,6 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
|
||||
# Copy the rest of the application source code
|
||||
# Make sure to have the git-LFS files for testing
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
title: Train RL in Simulation
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
- local: hil_collection
|
||||
title: Human In the Loop Data Collection
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
title: "Tutorials"
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
# Human-In-the-Loop Data Collection
|
||||
|
||||
Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data — recovery movements and corrections — is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.
|
||||
|
||||
---
|
||||
|
||||
## Why Human-In-the-Loop?
|
||||
|
||||
Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:
|
||||
|
||||
- Running the trained policy on the real robot
|
||||
- Having a human intervene when the robot is about to fail
|
||||
- Recording the human's recovery and correction as training data
|
||||
- Fine-tuning the policy on the combined dataset
|
||||
|
||||
This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
During a HIL session, the human operator follows this loop within each episode:
|
||||
|
||||
1. **Watch** the policy run autonomously
|
||||
2. **Pause** when failure is imminent — the robot holds its position
|
||||
3. **Take control** — teleoperate the robot back to a good state (recovery), then correct the behavior
|
||||
4. **Return control to the policy** — the policy resumes autonomous execution
|
||||
5. Repeat steps 2–4 as many times as needed during the episode
|
||||
6. **End the episode** when the task is complete, save and move on to the next rollout
|
||||
|
||||
Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset required just because intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.
|
||||
|
||||
This process can be repeated iteratively: deploy, collect, fine-tune, repeat — each round targeting the current policy's failure modes.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Policy v0 (trained on demos) │
|
||||
│ ↓ │
|
||||
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1 │
|
||||
│ ↓ │
|
||||
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2 │
|
||||
│ ↓ │
|
||||
│ ... (repeat until satisfactory performance) │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
### Teleoperator Requirements
|
||||
|
||||
The HIL data collection scripts require **teleoperators with active motors** that can:
|
||||
|
||||
- Enable/disable torque programmatically
|
||||
- Move to target positions (to mirror the robot state when pausing)
|
||||
|
||||
**Compatible teleoperators:**
|
||||
|
||||
- `so101_leader` - SO-101 Leader Arm
|
||||
- `openarms_mini` - OpenArms Mini (via third-party plugin)
|
||||
|
||||
---
|
||||
|
||||
## Scripts
|
||||
|
||||
Two scripts are provided depending on your policy's inference speed:
|
||||
|
||||
| Script | Use Case | Models |
|
||||
| ---------------------------- | ------------------------------------------ | --------------------- |
|
||||
| `hil_data_collection.py` | Standard synchronous inference | ACT, Diffusion Policy |
|
||||
| `hil_data_collection_rtc.py` | Real-Time Chunking for high-latency models | Pi0, Pi0.5, SmolVLA |
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Guide
|
||||
|
||||
### Step 1: Pre-train a Base Policy
|
||||
|
||||
First, train a policy on your demonstration dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/demo-dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=outputs/pretrain \
|
||||
--batch_size=32 \
|
||||
--steps=50000
|
||||
```
|
||||
|
||||
### Step 2: Collect HIL Data
|
||||
|
||||
**Standard inference (ACT, Diffusion Policy):**
|
||||
|
||||
```bash
|
||||
python examples/rac/hil_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--dataset.single_task="Pick up the cube and place it in the bowl" \
|
||||
--dataset.num_episodes=50
|
||||
```
|
||||
|
||||
**With RTC for large models (Pi0, Pi0.5, SmolVLA):**
|
||||
|
||||
For models with high inference latency, use the RTC script for smooth execution:
|
||||
|
||||
```bash
|
||||
python examples/rac/hil_data_collection_rtc.py \
|
||||
--robot.type=so100_follower \
|
||||
--teleop.type=so100_leader \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-rtc-dataset \
|
||||
--dataset.single_task="Pick up the cube" \
|
||||
--rtc.execution_horizon=20 \
|
||||
--interpolation=true
|
||||
```
|
||||
|
||||
**Controls (Conceptual):**
|
||||
|
||||
The interaction model is:
|
||||
|
||||
- **Pause input**: pause autonomous policy execution
|
||||
- **Takeover input**: transfer control to the human operator and record intervention data
|
||||
- **Return-to-policy input**: hand control back to the policy and continue the same episode
|
||||
- **Episode control inputs**: save/re-record/stop/reset as needed
|
||||
|
||||
Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.
|
||||
|
||||
**The HIL Protocol:**
|
||||
|
||||
1. Watch the policy run autonomously (teleop is idle/free)
|
||||
2. When you see imminent failure, trigger the **pause input**
|
||||
- Policy stops
|
||||
- Teleoperator moves to match robot position (torque enabled)
|
||||
- No frames recorded during pause
|
||||
3. Trigger the **takeover input** to take control
|
||||
- Teleoperator torque disabled, free to move
|
||||
- **Recovery**: Teleoperate the robot back to a good state
|
||||
- **Correction**: Correct the behavior
|
||||
- All movements are recorded
|
||||
4. Trigger the **return-to-policy input**
|
||||
- Policy resumes autonomous execution from the current state
|
||||
- You can intervene again at any time (repeat steps 2–4)
|
||||
5. End and save the episode when the task is complete (or episode time limit is reached)
|
||||
6. **Reset**: Teleop moves to robot position, you can move the robot to the starting position
|
||||
7. Start the next episode
|
||||
|
||||
**Foot Pedal Setup (Linux):**
|
||||
|
||||
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
|
||||
|
||||
```bash
|
||||
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
|
||||
```
|
||||
|
||||
### Step 3: Fine-tune the Policy
|
||||
|
||||
Fine-tune on the combined demonstration + HIL data:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--output_dir=outputs/hil_finetune \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.
|
||||
|
||||
---
|
||||
|
||||
## Tips for Effective HIL Collection
|
||||
|
||||
### When to Intervene
|
||||
|
||||
Intervene when you see:
|
||||
|
||||
- Robot about to make an irreversible mistake
|
||||
- Robot hesitating or showing uncertain behavior
|
||||
- Robot deviating from the expected trajectory
|
||||
|
||||
### Recovery: Teleoperating Back to a Good State
|
||||
|
||||
During recovery, teleoperate the robot back to a state where:
|
||||
|
||||
- The robot is in a familiar, in-distribution configuration
|
||||
- The current subtask can still be completed
|
||||
- The recovery trajectory itself is informative training data
|
||||
|
||||
### Quality of Corrections
|
||||
|
||||
During correction:
|
||||
|
||||
- Provide **confident, clean** trajectories
|
||||
- Complete the current subtask fully
|
||||
- Don't overcorrect or add unnecessary movements
|
||||
|
||||
---
|
||||
|
||||
## Related Work
|
||||
|
||||
This HIL data collection approach builds on ideas from interactive imitation learning, including DAgger (Ross et al., 2011), HG-DAgger (Kelly et al., 2019), RaC (Hu et al., 2025), and RECAP (Physical Intelligence, 2025). See those works for a deeper treatment of the theory behind human-in-the-loop policy improvement.
|
||||
|
||||
```bibtex
|
||||
@article{ross2011dagger,
|
||||
title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
|
||||
author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
|
||||
journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
|
||||
year={2011}
|
||||
}
|
||||
|
||||
@article{kelly2019hgdagger,
|
||||
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
|
||||
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
|
||||
journal={arXiv preprint arXiv:1810.02890},
|
||||
year={2019}
|
||||
}
|
||||
|
||||
@article{hu2025rac,
|
||||
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
|
||||
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
|
||||
journal={arXiv preprint arXiv:2509.07953},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@article{pi2025recap,
|
||||
title={π0.6: a VLA That Learns From Experience},
|
||||
author={Physical Intelligence},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,351 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Human-in-the-Loop (HIL) Data Collection with Policy Rollout.
|
||||
|
||||
Implements the RaC paradigm (Hu et al., 2025) for LeRobot with standard synchronous
|
||||
inference. For large models with high inference latency, use hil_data_collection_rtc.py.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously
|
||||
2. Press SPACE to pause - robot holds position
|
||||
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/hil_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/hil_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILConfig:
|
||||
robot: RobotConfig
|
||||
teleop: TeleoperatorConfig
|
||||
dataset: HILDatasetConfig
|
||||
policy: PreTrainedConfig | None = None
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rollout_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: HILConfig,
|
||||
):
|
||||
"""Rollout loop with standard synchronous inference."""
|
||||
fps = cfg.dataset.fps
|
||||
device = get_safe_torch_device(cfg.device)
|
||||
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
frame_buffer = []
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
robot_action: dict[str, Any] = {}
|
||||
action_keys = sorted(robot.action_features.keys())
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < cfg.dataset.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
interpolator.reset()
|
||||
|
||||
# Takeover
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
teleop_disable_torque(teleop)
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
if events["correction_active"]:
|
||||
robot_action = teleop.get_action()
|
||||
robot.send_action(robot_action)
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
elif waiting_for_takeover or events["policy_paused"]:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
else:
|
||||
# Policy execution with optional interpolation
|
||||
if interpolator.needs_new_action():
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=cfg.dataset.single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action = make_robot_action(action_values, dataset.features)
|
||||
action_tensor = torch.tensor([robot_action[k] for k in action_keys])
|
||||
interpolator.add(action_tensor)
|
||||
|
||||
interp_action = interpolator.get()
|
||||
if interp_action is not None:
|
||||
robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
|
||||
robot.send_action(robot_action)
|
||||
last_action = robot_action
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
if cfg.display_data and robot_action:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_time := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_time)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def hil_collect(cfg: HILConfig) -> LeRobotDataset:
|
||||
"""Main HIL data collection function."""
|
||||
init_logging()
|
||||
logger.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="hil_collection")
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot, "cameras") and robot.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
print_controls(rtc=False)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(robot, teleop, events, cfg.dataset.fps)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
hil_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,513 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Human-in-the-Loop (HIL) Data Collection with Real-Time Chunking (RTC).
|
||||
|
||||
Implements the RaC paradigm (Hu et al., 2025) with RTC for large flow-matching models
|
||||
(Pi0, Pi0.5, SmolVLA) that have high inference latency. RTC generates action chunks
|
||||
asynchronously in a background thread for smooth robot control.
|
||||
|
||||
For fast models (ACT, Diffusion), use hil_data_collection.py instead.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously with RTC
|
||||
2. Press SPACE to pause - robot holds position
|
||||
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/hil_data_collection_rtc.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/train/pi0_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/hil_rtc_dataset \
|
||||
--dataset.single_task="Pick up the cube" \
|
||||
--rtc.execution_horizon=20
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pprint import pformat
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILRTCConfig:
|
||||
robot: RobotConfig
|
||||
teleop: TeleoperatorConfig
|
||||
dataset: HILDatasetConfig
|
||||
policy: PreTrainedConfig | None = None
|
||||
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(enabled=True, execution_horizon=20))
|
||||
interpolation_multiplier: int = 2 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
use_torch_compile: bool = False # First compile takes minutes, disable for real-time
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
self.rtc.enabled = True
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
class ThreadSafeRobot:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: Robot):
|
||||
self._robot = robot
|
||||
self._lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
return self._robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict) -> None:
|
||||
with self._lock:
|
||||
self._robot.send_action(action)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
return self._robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
return self._robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._robot.name
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str:
|
||||
return self._robot.robot_type
|
||||
|
||||
@property
|
||||
def cameras(self):
|
||||
return getattr(self._robot, "cameras", {})
|
||||
|
||||
|
||||
def rtc_inference_thread(
|
||||
policy: PreTrainedPolicy,
|
||||
obs_holder: dict,
|
||||
hw_features: dict,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
queue_holder: dict,
|
||||
shutdown_event: Event,
|
||||
policy_active: Event,
|
||||
cfg: HILRTCConfig,
|
||||
):
|
||||
"""Background thread for RTC action chunk generation."""
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.dataset.fps
|
||||
threshold = 30
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not policy_active.is_set():
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
queue = queue_holder.get("queue")
|
||||
obs = obs_holder.get("obs")
|
||||
if queue is None or obs is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if queue.qsize() <= threshold:
|
||||
try:
|
||||
current_time = time.perf_counter()
|
||||
idx_before = queue.get_action_index()
|
||||
prev_actions = queue.get_left_over()
|
||||
|
||||
latency = latency_tracker.max()
|
||||
delay = math.ceil(latency / time_per_chunk) if latency else 0
|
||||
|
||||
obs_batch = build_dataset_frame(hw_features, obs, prefix="observation")
|
||||
for name in obs_batch:
|
||||
obs_batch[name] = torch.from_numpy(obs_batch[name])
|
||||
if "image" in name:
|
||||
obs_batch[name] = obs_batch[name].float() / 255
|
||||
obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
|
||||
obs_batch[name] = obs_batch[name].unsqueeze(0).to(cfg.device)
|
||||
|
||||
obs_batch["task"] = [cfg.dataset.single_task]
|
||||
obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown")
|
||||
|
||||
preprocessed = preprocessor(obs_batch)
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = postprocessor(actions).squeeze(0)
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
queue.merge(original, processed, new_delay, idx_before)
|
||||
logger.debug(f"[RTC] Inference latency={new_latency:.2f}s, queue={queue.qsize()}")
|
||||
except Exception as e:
|
||||
logger.error(f"[RTC] Error: {e}")
|
||||
time.sleep(0.5)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rollout_loop(
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: HILRTCConfig,
|
||||
queue_holder: dict,
|
||||
obs_holder: dict,
|
||||
policy_active: Event,
|
||||
hw_features: dict,
|
||||
):
|
||||
"""Rollout loop with RTC for asynchronous inference."""
|
||||
fps = cfg.dataset.fps
|
||||
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
frame_buffer = []
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
robot_action: dict[str, Any] = {}
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < cfg.dataset.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
policy_active.clear()
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
interpolator.reset()
|
||||
|
||||
# Takeover
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
teleop_disable_torque(teleop)
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
obs_holder["obs"] = obs_filtered
|
||||
|
||||
if events["correction_active"]:
|
||||
robot_action = teleop.get_action()
|
||||
robot.send_action(robot_action)
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
elif waiting_for_takeover or events["policy_paused"]:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
else:
|
||||
# Policy execution with RTC
|
||||
if not policy_active.is_set():
|
||||
policy_active.set()
|
||||
|
||||
queue = queue_holder["queue"]
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get() if queue else None
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action_tensor = interpolator.get()
|
||||
if action_tensor is not None:
|
||||
robot_action = {
|
||||
k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)
|
||||
}
|
||||
robot.send_action(robot_action)
|
||||
last_action = robot_action
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
if cfg.display_data and robot_action:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_time := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_time)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
policy_active.clear()
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
|
||||
"""Main HIL data collection function with RTC."""
|
||||
init_logging()
|
||||
logger.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="hil_rtc_collection")
|
||||
|
||||
robot_raw = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot_raw.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot_raw.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
shutdown_event = Event()
|
||||
policy_active = Event()
|
||||
rtc_thread = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot_raw.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
# Load policy with RTC
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot_raw.connect()
|
||||
robot = ThreadSafeRobot(robot_raw)
|
||||
teleop.connect()
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
queue_holder = {"queue": ActionQueue(cfg.rtc)}
|
||||
obs_holder = {"obs": None, "robot_type": robot.robot_type}
|
||||
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
|
||||
|
||||
rtc_thread = Thread(
|
||||
target=rtc_inference_thread,
|
||||
args=(
|
||||
policy,
|
||||
obs_holder,
|
||||
hw_features,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
queue_holder,
|
||||
shutdown_event,
|
||||
policy_active,
|
||||
cfg,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
rtc_thread.start()
|
||||
|
||||
print_controls(rtc=True)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
queue_holder["queue"] = ActionQueue(cfg.rtc)
|
||||
|
||||
rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
queue_holder=queue_holder,
|
||||
obs_holder=obs_holder,
|
||||
policy_active=policy_active,
|
||||
hw_features=hw_features,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(robot, teleop, events, cfg.dataset.fps)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
shutdown_event.set()
|
||||
policy_active.clear()
|
||||
|
||||
if rtc_thread and rtc_thread.is_alive():
|
||||
rtc_thread.join(timeout=2.0)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot_raw.is_connected:
|
||||
robot_raw.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
hil_rtc_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Shared utilities for Human-in-the-Loop data collection scripts."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
||||
"""Check if teleoperator has motor control capabilities."""
|
||||
return hasattr(teleop, "bus") and hasattr(teleop.bus, "disable_torque")
|
||||
|
||||
|
||||
def teleop_disable_torque(teleop: Teleoperator) -> None:
|
||||
"""Disable teleop torque if supported."""
|
||||
if teleop_has_motor_control(teleop):
|
||||
teleop.bus.disable_torque()
|
||||
|
||||
|
||||
def teleop_enable_torque(teleop: Teleoperator) -> None:
|
||||
"""Enable teleop torque if supported."""
|
||||
if teleop_has_motor_control(teleop):
|
||||
teleop.bus.enable_torque()
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move teleop to target position if motor control is available."""
|
||||
if not teleop_has_motor_control(teleop):
|
||||
logger.warning("Teleop does not support motor control - cannot mirror robot position")
|
||||
return
|
||||
|
||||
teleop_enable_torque(teleop)
|
||||
current = teleop.get_action()
|
||||
steps = int(duration_s * fps)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.bus.sync_write("Goal_Position", {k.replace(".pos", ""): v for k, v in interp.items()})
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""Initialize keyboard listener with HIL controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logger.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key in [keyboard.Key.space, keyboard.Key.right]:
|
||||
print("\n[HIL] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[HIL] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[HIL] ⏸ PAUSED - Press 'c' to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[HIL] ▶ Taking control...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[HIL] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[HIL] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[HIL] ESC - Stop recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
return listener, events
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, obs_proc
|
||||
|
||||
|
||||
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
print("\n" + "=" * 60)
|
||||
print(" [HIL] RESET")
|
||||
print("=" * 60)
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
print(" Press any key to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop_disable_torque(teleop)
|
||||
print(" Teleop enabled - press any key to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
def print_controls(rtc: bool = False):
|
||||
"""Print control instructions."""
|
||||
print("\n" + "=" * 60)
|
||||
print(" Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else ""))
|
||||
print("=" * 60)
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy")
|
||||
print(" c - Take control")
|
||||
print(" → - End episode")
|
||||
print(" ESC - Stop and push to hub")
|
||||
print("=" * 60 + "\n")
|
||||
@@ -83,9 +83,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
@@ -151,6 +149,7 @@ class RTCDemoConfig(HubMixin):
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
@@ -351,20 +350,22 @@ def actor_control(
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
action_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Try to get an action from the queue with timeout
|
||||
action = action_queue.get()
|
||||
if interpolator.needs_new_action():
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
|
||||
@@ -214,9 +214,6 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
|
||||
@@ -7,13 +7,6 @@
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
{% if repo_id is defined and repo_id %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
@@ -47,7 +47,6 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
load_info,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -568,22 +567,20 @@ def _copy_and_reindex_data(
|
||||
def _keep_episodes_from_video_with_av(
|
||||
input_path: Path,
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
episodes_to_keep: list[tuple[float, float]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
This function decodes frames from specified frame ranges and re-encodes them with
|
||||
This function decodes frames from specified time ranges and re-encodes them with
|
||||
properly reset timestamps to ensure monotonic progression.
|
||||
|
||||
Args:
|
||||
input_path: Source video file path.
|
||||
output_path: Destination video file path.
|
||||
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
@@ -625,10 +622,9 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Create set of (start, end) ranges for fast lookup.
|
||||
# Convert to a sorted list for efficient checking.
|
||||
frame_ranges = sorted(episodes_to_keep)
|
||||
time_ranges = sorted(episodes_to_keep)
|
||||
|
||||
# Track frame index for setting PTS and current range being processed.
|
||||
src_frame_count = 0
|
||||
frame_count = 0
|
||||
range_idx = 0
|
||||
|
||||
@@ -638,20 +634,21 @@ def _keep_episodes_from_video_with_av(
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
# Check if frame is in any of our desired frame ranges.
|
||||
# Get frame timestamp.
|
||||
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
|
||||
|
||||
# Check if frame is in any of our desired time ranges.
|
||||
# Skip ranges that have already passed.
|
||||
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
|
||||
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
|
||||
range_idx += 1
|
||||
|
||||
# If we've passed all ranges, stop processing.
|
||||
if range_idx >= len(frame_ranges):
|
||||
if range_idx >= len(time_ranges):
|
||||
break
|
||||
|
||||
# Check if frame is in current range.
|
||||
start_frame = frame_ranges[range_idx][0]
|
||||
|
||||
if src_frame_count < start_frame:
|
||||
src_frame_count += 1
|
||||
start_ts, end_ts = time_ranges[range_idx]
|
||||
if frame_time < start_ts:
|
||||
continue
|
||||
|
||||
# Frame is in range - create a new frame with reset timestamps.
|
||||
@@ -664,7 +661,6 @@ def _keep_episodes_from_video_with_av(
|
||||
for pkt in v_out.encode(new_frame):
|
||||
out.mux(pkt)
|
||||
|
||||
src_frame_count += 1
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder.
|
||||
@@ -753,17 +749,15 @@ def _copy_and_reindex_videos(
|
||||
f"videos/{video_key}/to_timestamp"
|
||||
]
|
||||
else:
|
||||
# Build list of frame ranges to keep, in sorted order.
|
||||
# Build list of time ranges to keep, in sorted order.
|
||||
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
||||
episodes_to_keep_ranges: list[tuple[int, int]] = []
|
||||
episodes_to_keep_ranges: list[tuple[float, float]] = []
|
||||
|
||||
for old_idx in sorted_keep_episodes:
|
||||
src_ep = src_dataset.meta.episodes[old_idx]
|
||||
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
|
||||
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
|
||||
assert src_ep["length"] == to_frame - from_frame, (
|
||||
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
|
||||
)
|
||||
episodes_to_keep_ranges.append((from_frame, to_frame))
|
||||
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
episodes_to_keep_ranges.append((from_ts, to_ts))
|
||||
|
||||
# Use PyAV filters to efficiently re-encode only the desired segments.
|
||||
assert src_dataset.meta.video_path is not None
|
||||
@@ -1775,296 +1769,3 @@ def convert_image_to_video_dataset(
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def trim_episodes_by_frames(
|
||||
dataset: LeRobotDataset,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim multiple episodes to keep only specific frames.
|
||||
|
||||
This function creates a new dataset where the specified episodes contain only
|
||||
the frames at the given indices. All other episodes are copied as-is.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_frames_to_keep: Dict mapping episode indices to lists of global frame indices to keep.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
|
||||
|
||||
Returns:
|
||||
A new LeRobotDataset with the trimmed episodes.
|
||||
"""
|
||||
if not episode_frames_to_keep:
|
||||
raise ValueError("No episodes to trim")
|
||||
|
||||
for ep_idx in episode_frames_to_keep:
|
||||
if ep_idx >= dataset.meta.total_episodes:
|
||||
raise ValueError(f"Episode {ep_idx} does not exist")
|
||||
if not episode_frames_to_keep[ep_idx]:
|
||||
raise ValueError(f"No frames to keep for episode {ep_idx}")
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_trimmed"
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
total_trimmed = sum(len(frames) for frames in episode_frames_to_keep.values())
|
||||
logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
|
||||
|
||||
# Create new metadata
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=dataset.meta.features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Build set of all frames to keep (for episodes being trimmed)
|
||||
# and compute new frame counts per episode
|
||||
all_keep_frames: set[int] = set()
|
||||
trimmed_frame_counts: dict[int, int] = {}
|
||||
for ep_idx, frames in episode_frames_to_keep.items():
|
||||
all_keep_frames.update(frames)
|
||||
trimmed_frame_counts[ep_idx] = len(frames)
|
||||
|
||||
# Copy and filter data
|
||||
_copy_and_reindex_data_with_multi_frame_filter(
|
||||
dataset, new_meta, episode_frames_to_keep, all_keep_frames
|
||||
)
|
||||
|
||||
# Handle videos if present
|
||||
if dataset.meta.video_keys:
|
||||
_copy_and_reindex_videos_with_multi_frame_filter(
|
||||
dataset, new_meta, episode_frames_to_keep
|
||||
)
|
||||
|
||||
# Copy episode metadata
|
||||
_copy_and_reindex_episodes_metadata_for_multi_trim(
|
||||
dataset, new_meta, trimmed_frame_counts
|
||||
)
|
||||
|
||||
logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
|
||||
|
||||
# Return the metadata instead of trying to load as LeRobotDataset
|
||||
# This avoids Hub validation issues when the repo doesn't exist yet
|
||||
return new_meta
|
||||
|
||||
|
||||
# Keep old function for backward compatibility
|
||||
def trim_episode_by_frames(
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
keep_frame_indices: list[int],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim a single episode. Wrapper around trim_episodes_by_frames."""
|
||||
return trim_episodes_by_frames(
|
||||
dataset,
|
||||
episode_frames_to_keep={episode_index: keep_frame_indices},
|
||||
output_dir=output_dir,
|
||||
repo_id=repo_id,
|
||||
)
|
||||
|
||||
|
||||
def _copy_and_reindex_data_with_multi_frame_filter(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
all_keep_frames: set[int],
|
||||
) -> None:
|
||||
"""Copy data files with frame-level filtering for multiple episodes."""
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
# Copy tasks
|
||||
if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
|
||||
# Tasks are stored with task string as index
|
||||
dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
|
||||
|
||||
# Get all parquet files
|
||||
data_dir = src_dataset.root / "data"
|
||||
parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
|
||||
|
||||
trim_episode_set = set(episode_frames_to_keep.keys())
|
||||
global_index = 0
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Filter: keep all frames from non-trimmed episodes,
|
||||
# and only specified frames from trimmed episodes
|
||||
mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
|
||||
df = df[mask].copy().reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Reindex
|
||||
df["index"] = range(global_index, global_index + len(df))
|
||||
|
||||
# Recalculate frame_index within each episode
|
||||
for ep_idx in df["episode_index"].unique():
|
||||
ep_mask = df["episode_index"] == ep_idx
|
||||
df.loc[ep_mask, "frame_index"] = range(ep_mask.sum())
|
||||
|
||||
# Recalculate timestamps based on frame_index and fps
|
||||
df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
|
||||
|
||||
# Determine output path (keep same structure)
|
||||
rel_path = parquet_path.relative_to(src_dataset.root)
|
||||
dst_path = dst_meta.root / rel_path
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_write_parquet(df, dst_path, dst_meta)
|
||||
global_index += len(df)
|
||||
|
||||
|
||||
def _copy_and_reindex_videos_with_multi_frame_filter(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
) -> None:
|
||||
"""Copy video files for trimmed dataset.
|
||||
|
||||
In v3.0 datasets, multiple episodes are concatenated into single video files.
|
||||
Each episode has from_timestamp/to_timestamp indicating its portion of the video.
|
||||
|
||||
For trimming, we copy the original video files as-is and update the metadata
|
||||
timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
|
||||
"""
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
video_dir = src_dataset.root / "videos" / video_key
|
||||
dst_video_dir = dst_meta.root / "videos" / video_key
|
||||
|
||||
if not video_dir.exists():
|
||||
logging.warning(f"Video directory not found: {video_dir}")
|
||||
continue
|
||||
|
||||
# Copy all video files (they contain concatenated episodes)
|
||||
# The metadata timestamps will handle which portions to use
|
||||
copied_files = set()
|
||||
for chunk_dir in video_dir.glob("chunk-*"):
|
||||
dst_chunk_dir = dst_video_dir / chunk_dir.name
|
||||
dst_chunk_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for video_file in chunk_dir.glob("*.mp4"):
|
||||
if video_file.name not in copied_files:
|
||||
dst_path = dst_chunk_dir / video_file.name
|
||||
if not dst_path.exists():
|
||||
shutil.copy(video_file, dst_path)
|
||||
copied_files.add(video_file.name)
|
||||
|
||||
logging.info(f"Copied {len(copied_files)} video files for {video_key}")
|
||||
|
||||
|
||||
def _trim_video_frames(
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
keep_frame_indices: list[int],
|
||||
fps: float,
|
||||
episode_start_idx: int,
|
||||
) -> None:
|
||||
"""Trim a video to keep only specific frames using ffmpeg."""
|
||||
import subprocess
|
||||
|
||||
# Convert global indices to local indices within the episode
|
||||
local_indices = sorted([idx - episode_start_idx for idx in keep_frame_indices])
|
||||
|
||||
if not local_indices:
|
||||
logging.warning(f"No frames to keep for video {src_path}")
|
||||
return
|
||||
|
||||
# Calculate start and end times
|
||||
start_frame = local_indices[0]
|
||||
end_frame = local_indices[-1]
|
||||
|
||||
start_time = start_frame / fps
|
||||
duration = (end_frame - start_frame + 1) / fps
|
||||
|
||||
# Use ffmpeg to trim
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-ss", str(start_time),
|
||||
"-i", str(src_path),
|
||||
"-t", str(duration),
|
||||
"-c", "copy", # Fast copy without re-encoding
|
||||
str(dst_path)
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"Failed to trim video: {e.stderr.decode()}")
|
||||
# Fallback: copy the whole video
|
||||
shutil.copy(src_path, dst_path)
|
||||
|
||||
|
||||
def _copy_and_reindex_episodes_metadata_for_multi_trim(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
trimmed_frame_counts: dict[int, int],
|
||||
) -> None:
|
||||
"""Copy and update episode metadata for trimmed dataset."""
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
# Calculate new frame counts and indices
|
||||
episodes_data = []
|
||||
global_idx = 0
|
||||
|
||||
for old_ep_idx in range(src_dataset.meta.total_episodes):
|
||||
src_ep = src_dataset.meta.episodes[old_ep_idx]
|
||||
|
||||
if old_ep_idx in trimmed_frame_counts:
|
||||
ep_length = trimmed_frame_counts[old_ep_idx]
|
||||
else:
|
||||
ep_length = src_ep["length"]
|
||||
|
||||
ep_data = {
|
||||
"episode_index": old_ep_idx,
|
||||
"tasks": src_ep["tasks"],
|
||||
"length": ep_length,
|
||||
"data/chunk_index": src_ep["data/chunk_index"],
|
||||
"data/file_index": src_ep["data/file_index"],
|
||||
"dataset_from_index": global_idx,
|
||||
"dataset_to_index": global_idx + ep_length,
|
||||
}
|
||||
|
||||
# Copy video metadata - preserve timestamps for concatenated videos
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
|
||||
ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
|
||||
|
||||
# Keep original from_timestamp (start position in concatenated video)
|
||||
orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
|
||||
|
||||
# For trimmed episodes, update to_timestamp based on new length
|
||||
# For non-trimmed episodes, keep original to_timestamp
|
||||
if old_ep_idx in trimmed_frame_counts:
|
||||
ep_data[f"videos/{video_key}/to_timestamp"] = orig_from_ts + (ep_length / src_dataset.meta.fps)
|
||||
else:
|
||||
ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
|
||||
ep_data["meta/episodes/chunk_index"] = 0
|
||||
ep_data["meta/episodes/file_index"] = 0
|
||||
|
||||
episodes_data.append(ep_data)
|
||||
global_idx += ep_length
|
||||
|
||||
# Save episodes metadata
|
||||
df = pd.DataFrame(episodes_data)
|
||||
episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(episodes_path)
|
||||
|
||||
# Update info.json
|
||||
info = load_info(src_dataset.root)
|
||||
info["total_episodes"] = len(episodes_data)
|
||||
info["total_frames"] = global_idx
|
||||
write_info(info, dst_meta.root)
|
||||
|
||||
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Check if cached dataset contains all requested episodes
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
|
||||
@@ -227,17 +227,16 @@ def decode_video_frames_torchvision(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -249,11 +248,7 @@ def decode_video_frames_torchvision(
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
if len(timestamps) != len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -358,16 +353,15 @@ def decode_video_frames_torchcodec(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
|
||||
@@ -139,10 +139,6 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# Optimization
|
||||
compile_model: bool = False
|
||||
compile_mode: str = "reduce-overhead"
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
|
||||
@@ -142,9 +142,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
for key in self.config.image_features:
|
||||
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
@@ -185,11 +182,6 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
if config.compile_model:
|
||||
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||
# common in diffusion inference.
|
||||
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
|
||||
+14
-5
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,7 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
from .openarm_mini import OpenArmMini
|
||||
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""
|
||||
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
|
||||
__all__ = [
|
||||
"ActionInterpolator",
|
||||
"ActionQueue",
|
||||
"LatencyTracker",
|
||||
"RTCConfig",
|
||||
"RTCProcessor",
|
||||
]
|
||||
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action interpolation for smoother robot control.
|
||||
|
||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
||||
"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolates between consecutive actions for smoother control.
|
||||
|
||||
When enabled with multiplier N, produces N actions per policy action
|
||||
by linearly interpolating between the previous and current action.
|
||||
|
||||
Example with multiplier=3:
|
||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
||||
|
||||
This effectively multiplies the control rate for smoother motion.
|
||||
|
||||
Usage:
|
||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
||||
|
||||
# In control loop:
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get()
|
||||
if new_action:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action:
|
||||
robot.send_action(action)
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier: int = 1):
|
||||
"""Initialize the interpolator.
|
||||
|
||||
Args:
|
||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
||||
"""
|
||||
if multiplier < 1:
|
||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
||||
self.multiplier = multiplier
|
||||
self._prev: Tensor | None = None
|
||||
self._buffer: list[Tensor] = []
|
||||
self._idx = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether interpolation is active (multiplier > 1)."""
|
||||
return self.multiplier > 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset interpolation state (call between episodes)."""
|
||||
self._prev = None
|
||||
self._buffer = []
|
||||
self._idx = 0
|
||||
|
||||
def needs_new_action(self) -> bool:
|
||||
"""Check if a new action is needed from the queue."""
|
||||
return self._idx >= len(self._buffer)
|
||||
|
||||
def add(self, action: Tensor) -> None:
|
||||
"""Add a new action and compute interpolated sequence.
|
||||
|
||||
Args:
|
||||
action: New action tensor from policy/queue (already on CPU).
|
||||
"""
|
||||
if self.multiplier > 1 and self._prev is not None:
|
||||
self._buffer = []
|
||||
for i in range(1, self.multiplier + 1):
|
||||
t = i / self.multiplier
|
||||
interp = self._prev + t * (action - self._prev)
|
||||
self._buffer.append(interp)
|
||||
else:
|
||||
self._buffer = [action]
|
||||
self._prev = action
|
||||
self._idx = 0
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next interpolated action.
|
||||
|
||||
Returns:
|
||||
Next action tensor, or None if buffer is exhausted.
|
||||
"""
|
||||
if self._idx >= len(self._buffer):
|
||||
return None
|
||||
action = self._buffer[self._idx]
|
||||
self._idx += 1
|
||||
return action
|
||||
|
||||
def get_control_interval(self, fps: float) -> float:
|
||||
"""Get the control interval based on interpolation multiplier.
|
||||
|
||||
Args:
|
||||
fps: Base frames per second.
|
||||
|
||||
Returns:
|
||||
Control interval in seconds (divided by multiplier).
|
||||
"""
|
||||
return 1.0 / (fps * self.multiplier)
|
||||
@@ -277,7 +277,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
|
||||
@@ -104,28 +104,6 @@ Convert image dataset to video format and push to hub:
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Trim single episode to keep only frames within timestamp range:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_trimmed \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_index 0 \
|
||||
--operation.start_timestamp 10.0 \
|
||||
--operation.end_timestamp 30.0
|
||||
|
||||
Trim multiple episodes at once (use null for no limit):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
|
||||
|
||||
Trim and re-upload to same repo (overwrites original):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_index 0 \
|
||||
--operation.start_timestamp 10.0 \
|
||||
--push_to_hub true
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -226,32 +204,9 @@ class InfoConfig(OperationConfig):
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrimEpisodeConfig:
|
||||
"""Trim episodes to keep only frames within timestamp ranges.
|
||||
|
||||
Supports multiple episodes via episode_trims dict:
|
||||
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
|
||||
|
||||
Or single episode via legacy parameters:
|
||||
--operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
|
||||
"""
|
||||
type: str = "trim_episode"
|
||||
# Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
|
||||
# Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
|
||||
episode_trims: dict[str, list[float | None]] | None = None
|
||||
# Legacy single-episode parameters (used if episode_trims is None)
|
||||
episode_index: int | None = None
|
||||
start_timestamp: float | None = None # Keep frames from this timestamp (inclusive)
|
||||
end_timestamp: float | None = None # Keep frames until this timestamp (inclusive)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: (
|
||||
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig | TrimEpisodeConfig
|
||||
)
|
||||
operation: OperationConfig
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
@@ -396,92 +351,6 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_trim_episode(cfg: EditDatasetConfig) -> None:
|
||||
"""Trim episodes to keep only frames within timestamp ranges."""
|
||||
if not isinstance(cfg.operation, TrimEpisodeConfig):
|
||||
raise ValueError("Operation config must be TrimEpisodeConfig")
|
||||
|
||||
# Parse episode trims - support both multi-episode dict and legacy single episode
|
||||
episode_trims: dict[int, tuple[float | None, float | None]] = {}
|
||||
|
||||
if cfg.operation.episode_trims is not None:
|
||||
# Multi-episode mode
|
||||
for ep_str, ts_range in cfg.operation.episode_trims.items():
|
||||
ep_idx = int(ep_str)
|
||||
start_ts = ts_range[0] if len(ts_range) > 0 else None
|
||||
end_ts = ts_range[1] if len(ts_range) > 1 else None
|
||||
episode_trims[ep_idx] = (start_ts, end_ts)
|
||||
elif cfg.operation.episode_index is not None:
|
||||
# Legacy single-episode mode
|
||||
if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
|
||||
raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
|
||||
episode_trims[cfg.operation.episode_index] = (
|
||||
cfg.operation.start_timestamp,
|
||||
cfg.operation.end_timestamp,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either episode_trims or episode_index must be specified")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is None:
|
||||
dataset.root = Path(str(dataset.root) + "_old")
|
||||
|
||||
logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
|
||||
|
||||
# Get episode boundaries and find frames to keep for each episode
|
||||
episodes_info = dataset.meta.episodes
|
||||
all_frames_to_keep: dict[int, list[int]] = {}
|
||||
|
||||
for ep_idx, (start_ts, end_ts) in episode_trims.items():
|
||||
if ep_idx >= len(episodes_info["episode_index"]):
|
||||
raise ValueError(f"Episode {ep_idx} does not exist (dataset has {len(episodes_info['episode_index'])} episodes)")
|
||||
|
||||
from_frame = episodes_info["dataset_from_index"][ep_idx]
|
||||
to_frame = episodes_info["dataset_to_index"][ep_idx]
|
||||
|
||||
logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
|
||||
logging.info(f" Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
|
||||
|
||||
# Find frames within timestamp range
|
||||
frames_to_keep = []
|
||||
for frame_idx in range(from_frame, to_frame):
|
||||
frame = dataset.hf_dataset[frame_idx]
|
||||
ts = frame["timestamp"]
|
||||
|
||||
in_range = True
|
||||
if start_ts is not None and ts < start_ts:
|
||||
in_range = False
|
||||
if end_ts is not None and ts > end_ts:
|
||||
in_range = False
|
||||
|
||||
if in_range:
|
||||
frames_to_keep.append(frame_idx)
|
||||
|
||||
if not frames_to_keep:
|
||||
raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
|
||||
|
||||
logging.info(f" Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
|
||||
all_frames_to_keep[ep_idx] = frames_to_keep
|
||||
|
||||
from lerobot.datasets.dataset_tools import trim_episodes_by_frames
|
||||
|
||||
new_dataset = trim_episodes_by_frames(
|
||||
dataset,
|
||||
episode_frames_to_keep=all_frames_to_keep,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}")
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, ModifyTasksConfig):
|
||||
raise ValueError("Operation config must be ModifyTasksConfig")
|
||||
@@ -646,8 +515,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "trim_episode":
|
||||
handle_trim_episode(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
|
||||
@@ -74,6 +74,8 @@ from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
)
|
||||
@@ -90,6 +92,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
@@ -225,6 +228,9 @@ class RecordConfig:
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
|
||||
# Only applies when using a policy (not teleop)
|
||||
interpolation_multiplier: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
@@ -297,6 +303,7 @@ def record_loop(
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
interpolator: ActionInterpolator | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
):
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
@@ -333,6 +340,14 @@ def record_loop(
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Reset interpolator if provided
|
||||
if interpolator is not None:
|
||||
interpolator.reset()
|
||||
|
||||
# Calculate control interval based on interpolation
|
||||
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
|
||||
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -353,24 +368,58 @@ def record_loop(
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
# With interpolation: only call policy when interpolator needs new action
|
||||
if use_interpolation:
|
||||
# Get action keys from robot
|
||||
action_keys = sorted(robot.action_features)
|
||||
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
if interpolator.needs_new_action():
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy = make_robot_action(action_values, dataset.features)
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
|
||||
# Convert to tensor for interpolator
|
||||
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
|
||||
interpolator.add(action_tensor)
|
||||
|
||||
# Get interpolated action
|
||||
interp_action = interpolator.get()
|
||||
if interp_action is not None:
|
||||
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
|
||||
action_values = robot_action_to_send
|
||||
else:
|
||||
# No action available yet, skip this iteration
|
||||
continue
|
||||
else:
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
act = teleop.get_action()
|
||||
|
||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
arm_action = teleop_arm.get_action()
|
||||
@@ -379,6 +428,8 @@ def record_loop(
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
@@ -387,14 +438,6 @@ def record_loop(
|
||||
)
|
||||
continue
|
||||
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
if policy is not None and act_processed_policy is not None:
|
||||
action_values = act_processed_policy
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
else:
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
@@ -414,7 +457,7 @@ def record_loop(
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
sleep_time_s: float = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
@@ -501,6 +544,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
interpolator = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
@@ -511,6 +555,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
# Create interpolator for smoother policy control
|
||||
if cfg.interpolation_multiplier > 1:
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
@@ -542,6 +590,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
interpolator=interpolator,
|
||||
display_compressed_images=display_compressed_images,
|
||||
)
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
|
||||
@@ -52,7 +51,6 @@ COMPATIBLE_DEVICES = [
|
||||
"koch_leader",
|
||||
"omx_follower",
|
||||
"omx_leader",
|
||||
"openarm_mini",
|
||||
"so100_follower",
|
||||
"so100_leader",
|
||||
"so101_follower",
|
||||
|
||||
@@ -24,7 +24,6 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -52,7 +51,6 @@ from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
has_method,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
)
|
||||
|
||||
|
||||
@@ -392,14 +390,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
progbar = tqdm(
|
||||
total=cfg.steps - step,
|
||||
desc="Training",
|
||||
unit="step",
|
||||
disable=inside_slurm(),
|
||||
position=0,
|
||||
leave=True,
|
||||
)
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
@@ -424,8 +414,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
@@ -519,9 +507,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_main_process:
|
||||
progbar.close()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||
|
||||
port_right: str = "/dev/ttyUSB0"
|
||||
port_left: str = "/dev/ttyUSB1"
|
||||
|
||||
use_degrees: bool = True
|
||||
@@ -1,296 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors whose direction is inverted during readout
|
||||
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
|
||||
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||
|
||||
|
||||
class OpenArmMini(Teleoperator):
|
||||
"""
|
||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||
|
||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||
"""
|
||||
|
||||
config_class = OpenArmMiniConfig
|
||||
name = "openarm_mini"
|
||||
|
||||
def __init__(self, config: OpenArmMiniConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
norm_mode_body = MotorNormMode.DEGREES
|
||||
|
||||
motors_right = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
motors_left = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||
}
|
||||
|
||||
self.bus_right = FeetechMotorsBus(
|
||||
port=self.config.port_right,
|
||||
motors=motors_right,
|
||||
calibration=cal_right,
|
||||
)
|
||||
|
||||
self.bus_left = FeetechMotorsBus(
|
||||
port=self.config.port_left,
|
||||
motors=motors_left,
|
||||
calibration=cal_left,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus_right.motors:
|
||||
features[f"right_{motor}.pos"] = float
|
||||
for motor in self.bus_left.motors:
|
||||
features[f"left_{motor}.pos"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||
self.bus_right.connect()
|
||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||
self.bus_left.connect()
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArm Mini.
|
||||
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position via half-turn homing
|
||||
4. Interactive gripper calibration (open/close positions)
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
user_input = input(
|
||||
f"Press ENTER to use existing calibration for {self.id}, "
|
||||
f"or type 'c' and press ENTER to run new calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Using existing calibration for {self.id}")
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||
}
|
||||
self.bus_right.write_calibration(cal_right)
|
||||
self.bus_left.write_calibration(cal_left)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
self._calibrate_arm("right", self.bus_right)
|
||||
self._calibrate_arm("left", self.bus_left)
|
||||
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
|
||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||
"""Calibrate a single arm with Feetech motors."""
|
||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||
|
||||
bus.disable_torque()
|
||||
|
||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||
for motor in bus.motors:
|
||||
bus.write("Phase", motor, 12)
|
||||
|
||||
for motor in bus.motors:
|
||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input(
|
||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
homing_offsets = bus.set_half_turn_homings()
|
||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||
|
||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||
|
||||
if self.calibration is None:
|
||||
self.calibration = {}
|
||||
|
||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||
max_res = motor_resolution - 1
|
||||
|
||||
for motor_name, motor in bus.motors.items():
|
||||
prefixed_name = f"{arm_name}_{motor_name}"
|
||||
|
||||
if motor_name == "gripper":
|
||||
input(
|
||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||
f"Step 1: CLOSE the gripper fully\n"
|
||||
f"Press ENTER when gripper is closed..."
|
||||
)
|
||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||
|
||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||
|
||||
if closed_pos < open_pos:
|
||||
range_min = int(closed_pos)
|
||||
range_max = int(open_pos)
|
||||
drive_mode = 0
|
||||
else:
|
||||
range_min = int(open_pos)
|
||||
range_max = int(closed_pos)
|
||||
drive_mode = 1
|
||||
|
||||
logger.info(
|
||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||
)
|
||||
else:
|
||||
range_min = 0
|
||||
range_max = max_res
|
||||
drive_mode = 0
|
||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||
|
||||
self.calibration[prefixed_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_mode,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
range_min=range_min,
|
||||
range_max=range_max,
|
||||
)
|
||||
|
||||
cal_for_bus = {
|
||||
k.replace(f"{arm_name}_", ""): v
|
||||
for k, v in self.calibration.items()
|
||||
if k.startswith(f"{arm_name}_")
|
||||
}
|
||||
bus.write_calibration(cal_for_bus)
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_right.configure_motors()
|
||||
for motor in self.bus_right.motors:
|
||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
self.bus_left.disable_torque()
|
||||
self.bus_left.configure_motors()
|
||||
for motor in self.bus_left.motors:
|
||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
print("\nSetting up RIGHT arm motors...")
|
||||
for motor in reversed(self.bus_right.motors):
|
||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||
self.bus_right.setup_motor(motor)
|
||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||
|
||||
print("\nSetting up LEFT arm motors...")
|
||||
for motor in reversed(self.bus_left.motors):
|
||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||
self.bus_left.setup_motor(motor)
|
||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""Get current action from both arms (read positions from all motors)."""
|
||||
start = time.perf_counter()
|
||||
|
||||
right_positions = self.bus_right.sync_read("Present_Position")
|
||||
left_positions = self.bus_left.sync_read("Present_Position")
|
||||
|
||||
action: dict[str, Any] = {}
|
||||
for motor, val in right_positions.items():
|
||||
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||
for motor, val in left_positions.items():
|
||||
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.bus_right.disconnect()
|
||||
self.bus_left.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -95,10 +95,6 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
|
||||
return BiOpenArmLeader(config)
|
||||
elif config.type == "openarm_mini":
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
return OpenArmMini(config)
|
||||
else:
|
||||
try:
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
|
||||
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
|
||||
Reference in New Issue
Block a user