From ca87ccd9413c59c30f524967222d2e3f1b7bb549 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 28 Apr 2026 00:57:35 +0200 Subject: [PATCH] feat(rollout): decouple policy deployment from data recording with new `lerobot-rollout` CLI (#3413) * feat(scripts): lerobot-rollout * fix(rollout) require dataset in dagger + use duration too * fix(docs): dagger num_episodes * test(rollout): fix expectations * fix(rollout): features check * fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config * docs(rollout): edit rename_map instructions * chore(rollout): multiple minor improvements * chore(rollout): address comments + minor improvements * fix(rollout): enable default * fix(tests): default value RTCConfig * fix(rollout): robot_observation_processor and notify_observation at policy frequency instead of interpolator rate Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): prevent relative actions with sync inference engine Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): rtc reanchor to non normalized state Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): fixing the episode length to use hwc (#3469) also reducing default length to 5 minutes * feat(rollout): go back to initial position is now a config * fix(rollout): properly propagating video_files_size_in_mb to lerobot_dataset (#3470) * chore(rollout): note about dagger correction stage * chore(docs): update comments and docstring * fix(test): move rtc relative out of rollout module * fix(rollout): address the review comments --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Maxime Ellerbach --- docs/source/_toctree.yml | 2 + docs/source/hil_data_collection.mdx | 40 +- docs/source/il_robots.mdx | 131 +- docs/source/inference.mdx | 261 ++++ docs/source/rename_map.mdx | 25 +- docs/source/rtc.mdx | 10 +- docs/source/unitree_g1.mdx | 5 +- 
examples/hil/hil_data_collection.py | 1184 ----------------- examples/hil/hil_utils.py | 226 ---- examples/lekiwi/evaluate.py | 93 +- examples/lekiwi/record.py | 19 +- examples/lekiwi/rollout.py | 77 ++ examples/phone_to_so100/evaluate.py | 95 +- examples/phone_to_so100/record.py | 26 +- examples/phone_to_so100/rollout.py | 126 ++ examples/rtc/eval_with_real_robot.py | 673 ---------- examples/so100_to_so100_EE/evaluate.py | 95 +- examples/so100_to_so100_EE/record.py | 32 +- examples/so100_to_so100_EE/rollout.py | 134 ++ pyproject.toml | 1 + src/lerobot/configs/__init__.py | 2 + src/lerobot/configs/dataset.py | 80 ++ src/lerobot/datasets/lerobot_dataset.py | 4 + src/lerobot/policies/__init__.py | 3 +- src/lerobot/policies/rtc/__init__.py | 2 + .../policies/rtc/action_interpolator.py | 118 +- src/lerobot/policies/rtc/action_queue.py | 20 +- src/lerobot/policies/rtc/configuration_rtc.py | 2 +- src/lerobot/policies/rtc/relative.py | 58 + .../processor/relative_action_processor.py | 11 +- src/lerobot/rl/actor.py | 2 +- src/lerobot/rl/learner.py | 2 +- src/lerobot/rollout/__init__.py | 87 ++ src/lerobot/rollout/configs.py | 323 +++++ src/lerobot/rollout/context.py | 459 +++++++ src/lerobot/rollout/inference/__init__.py | 39 + src/lerobot/rollout/inference/base.py | 89 ++ src/lerobot/rollout/inference/factory.py | 128 ++ src/lerobot/rollout/inference/rtc.py | 360 +++++ src/lerobot/rollout/inference/sync.py | 122 ++ src/lerobot/rollout/ring_buffer.py | 112 ++ src/lerobot/rollout/robot_wrapper.py | 79 ++ src/lerobot/rollout/strategies/__init__.py | 36 + src/lerobot/rollout/strategies/base.py | 85 ++ src/lerobot/rollout/strategies/core.py | 304 +++++ src/lerobot/rollout/strategies/dagger.py | 767 +++++++++++ src/lerobot/rollout/strategies/factory.py | 45 + src/lerobot/rollout/strategies/highlight.py | 283 ++++ src/lerobot/rollout/strategies/sentry.py | 231 ++++ src/lerobot/scripts/lerobot_record.py | 341 ++--- src/lerobot/scripts/lerobot_rollout.py | 211 +++ 
src/lerobot/utils/action_interpolator.py | 116 ++ src/lerobot/utils/pedal.py | 83 ++ src/lerobot/{rl => utils}/process.py | 0 tests/datasets/test_lerobot_dataset.py | 12 + .../policies/rtc/test_action_interpolator.py | 2 +- tests/policies/rtc/test_configuration_rtc.py | 2 +- .../policies/rtc/test_rtc_relative_actions.py | 117 +- tests/test_cli_peft.py | 82 -- tests/test_control_robot.py | 3 +- tests/test_rollout.py | 345 +++++ tests/utils/test_process.py | 2 +- 62 files changed, 5577 insertions(+), 2847 deletions(-) create mode 100644 docs/source/inference.mdx delete mode 100644 examples/hil/hil_data_collection.py delete mode 100644 examples/hil/hil_utils.py create mode 100644 examples/lekiwi/rollout.py create mode 100644 examples/phone_to_so100/rollout.py delete mode 100644 examples/rtc/eval_with_real_robot.py create mode 100644 examples/so100_to_so100_EE/rollout.py create mode 100644 src/lerobot/configs/dataset.py create mode 100644 src/lerobot/policies/rtc/relative.py create mode 100644 src/lerobot/rollout/__init__.py create mode 100644 src/lerobot/rollout/configs.py create mode 100644 src/lerobot/rollout/context.py create mode 100644 src/lerobot/rollout/inference/__init__.py create mode 100644 src/lerobot/rollout/inference/base.py create mode 100644 src/lerobot/rollout/inference/factory.py create mode 100644 src/lerobot/rollout/inference/rtc.py create mode 100644 src/lerobot/rollout/inference/sync.py create mode 100644 src/lerobot/rollout/ring_buffer.py create mode 100644 src/lerobot/rollout/robot_wrapper.py create mode 100644 src/lerobot/rollout/strategies/__init__.py create mode 100644 src/lerobot/rollout/strategies/base.py create mode 100644 src/lerobot/rollout/strategies/core.py create mode 100644 src/lerobot/rollout/strategies/dagger.py create mode 100644 src/lerobot/rollout/strategies/factory.py create mode 100644 src/lerobot/rollout/strategies/highlight.py create mode 100644 src/lerobot/rollout/strategies/sentry.py create mode 100644 
src/lerobot/scripts/lerobot_rollout.py create mode 100644 src/lerobot/utils/action_interpolator.py create mode 100644 src/lerobot/utils/pedal.py rename src/lerobot/{rl => utils}/process.py (100%) create mode 100644 tests/test_rollout.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f5e1129f3..01e8bfb76 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -61,6 +61,8 @@ title: SARM title: "Reward Models" - sections: + - local: inference + title: Policy Deployment (lerobot-rollout) - local: async title: Use Async Inference - local: rtc diff --git a/docs/source/hil_data_collection.mdx b/docs/source/hil_data_collection.mdx index c4839577f..ba68959d1 100644 --- a/docs/source/hil_data_collection.mdx +++ b/docs/source/hil_data_collection.mdx @@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea ### Teleoperator Requirements -The `examples/hil` HIL scripts require **teleoperators with active motors** that can: +The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can: - Enable/disable torque programmatically - Move to target positions (to mirror the robot state when pausing) -**Compatible teleoperators in the current `examples/hil` scripts:** +**Compatible teleoperators:** - `openarm_mini` - OpenArm Mini - `so_leader` - SO100 / SO101 leader arm > [!IMPORTANT] -> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`. +> The provided commands default to `bi_openarm_follower` + `openarm_mini`. > `so_follower` + `so_leader` configs are also registered and can be used via CLI flags. --- ## Script -A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`: +Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. 
Select the inference backend with `--inference.type=sync|rtc`: -| Mode | Flag | Models | -| ------------------------ | -------------------- | --------------------- | -| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy | -| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA | +| Mode | Flag | Models | +| ------------------------ | ---------------------- | --------------------- | +| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy | +| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA | --- @@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \ **Standard inference (ACT, Diffusion Policy):** ```bash -python examples/hil/hil_data_collection.py \ +lerobot-rollout --strategy.type=dagger \ --robot.type=bi_openarm_follower \ --robot.left_arm_config.port=can1 \ --robot.left_arm_config.side=left \ @@ -108,11 +108,10 @@ python examples/hil/hil_data_collection.py \ --teleop.port_left=/dev/ttyACM0 \ --teleop.port_right=/dev/ttyACM1 \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ - --dataset.repo_id=your-username/hil-dataset \ + --dataset.repo_id=your-username/rollout_hil_dataset \ --dataset.single_task="Fold the T-shirt properly" \ --dataset.fps=30 \ - --dataset.episode_time_s=1000 \ - --dataset.num_episodes=50 \ + --strategy.num_episodes=50 \ --interpolation_multiplier=2 ``` @@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \ For models with high inference latency, enable RTC for smooth execution: ```bash -python examples/hil/hil_data_collection.py \ - --rtc.enabled=true \ - --rtc.execution_horizon=20 \ - --rtc.max_guidance_weight=5.0 \ - --rtc.prefix_attention_schedule=LINEAR \ +lerobot-rollout --strategy.type=dagger \ + --inference.type=rtc \ + --inference.rtc.execution_horizon=20 \ + --inference.rtc.max_guidance_weight=5.0 \ + --inference.rtc.prefix_attention_schedule=LINEAR \ --robot.type=bi_openarm_follower \ --robot.left_arm_config.port=can1 \ 
--robot.left_arm_config.side=left \ @@ -136,11 +135,10 @@ python examples/hil/hil_data_collection.py \ --teleop.port_left=/dev/ttyACM0 \ --teleop.port_right=/dev/ttyACM1 \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ - --dataset.repo_id=your-username/hil-rtc-dataset \ + --dataset.repo_id=your-username/rollout_hil_rtc_dataset \ --dataset.single_task="Fold the T-shirt properly" \ --dataset.fps=30 \ - --dataset.episode_time_s=1000 \ - --dataset.num_episodes=50 \ + --strategy.num_episodes=50 \ --interpolation_multiplier=3 ``` @@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea - **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here. -- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`. +- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`. - **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP. 
diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 2356a93cc..ff0a6229e 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -509,121 +509,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \ ## Run inference and evaluate your policy -You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes: +Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs: - + ```bash -lerobot-record \ +lerobot-rollout \ + --strategy.type=base \ + --policy.path=${HF_USER}/my_policy \ --robot.type=so100_follower \ --robot.port=/dev/ttyACM1 \ --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \ - --robot.id=my_awesome_follower_arm \ - --display_data=false \ - --dataset.repo_id=${HF_USER}/eval_so100 \ - --dataset.single_task="Put lego brick into the transparent box" \ - --dataset.streaming_encoding=true \ - --dataset.encoder_threads=2 \ - # --dataset.vcodec=auto \ - # <- Teleop optional if you want to teleoperate in between episodes \ - # --teleop.type=so100_leader \ - # --teleop.port=/dev/ttyACM0 \ - # --teleop.id=my_awesome_leader_arm \ - --policy.path=${HF_USER}/my_policy + --task="Put lego brick into the transparent box" \ + --duration=60 ``` - - - -```python -from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.datasets import LeRobotDataset -from lerobot.utils.feature_utils import hw_to_dataset_features -from lerobot.policies.act import ACTPolicy -from lerobot.policies import make_pre_post_processors -from lerobot.robots.so_follower import SO100Follower, 
SO100FollowerConfig -from lerobot.scripts.lerobot_record import record_loop -from lerobot.common.control_utils import init_keyboard_listener -from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import init_rerun - - -NUM_EPISODES = 5 -FPS = 30 -EPISODE_TIME_SEC = 60 -TASK_DESCRIPTION = "My task description" -HF_MODEL_ID = "/" -HF_DATASET_ID = "/" - -# Create the robot configuration -camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} -robot_config = SO100FollowerConfig( - port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config -) - -# Initialize the robot -robot = SO100Follower(robot_config) - -# Initialize the policy -policy = ACTPolicy.from_pretrained(HF_MODEL_ID) - -# Configure the dataset features -action_features = hw_to_dataset_features(robot.action_features, "action") -obs_features = hw_to_dataset_features(robot.observation_features, "observation") -dataset_features = {**action_features, **obs_features} - -# Create the dataset -dataset = LeRobotDataset.create( - repo_id=HF_DATASET_ID, - fps=FPS, - features=dataset_features, - robot_type=robot.name, - use_videos=True, - image_writer_threads=4, -) - -# Initialize the keyboard listener and rerun visualization -_, events = init_keyboard_listener() -init_rerun(session_name="recording") - -# Connect the robot -robot.connect() - -preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=policy, - pretrained_path=HF_MODEL_ID, - dataset_stats=dataset.meta.stats, -) - -for episode_idx in range(NUM_EPISODES): - log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - - # Run the policy inference loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - ) - - dataset.save_episode() - 
-# Clean up -robot.disconnect() -dataset.push_to_hub() + +```bash +lerobot-rollout \ + --strategy.type=sentry \ + --strategy.upload_every_n_episodes=5 \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM1 \ + --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \ + --dataset.repo_id=${HF_USER}/eval_so100 \ + --dataset.single_task="Put lego brick into the transparent box" \ + --duration=600 ``` - - -As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: +The `--strategy.type` flag selects the execution mode: -1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`). -2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`). +- `base`: Autonomous rollout with no data recording (useful for quick evaluation) +- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation) +- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events) +- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection)) + +All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA). diff --git a/docs/source/inference.mdx b/docs/source/inference.mdx new file mode 100644 index 000000000..b2874d823 --- /dev/null +++ b/docs/source/inference.mdx @@ -0,0 +1,261 @@ +# Policy Deployment (lerobot-rollout) + +`lerobot-rollout` is the single CLI for deploying trained policies on real robots. 
It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection. + +## Quick Start + +No extra dependencies are needed beyond your robot and policy extras. + +```bash +lerobot-rollout \ + --strategy.type=base \ + --policy.path=lerobot/act_koch_real \ + --robot.type=koch_follower \ + --robot.port=/dev/ttyACM0 \ + --task="pick up cube" \ + --duration=30 +``` + +This runs the policy for 30 seconds with no recording. + +--- + +## Strategies + +Select a strategy with `--strategy.type=`. Each strategy defines a different control loop with its own recording and interaction semantics. + +### Base (`--strategy.type=base`) + +Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot. + +```bash +lerobot-rollout \ + --strategy.type=base \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --task="Put lego brick into the box" \ + --duration=60 +``` + +| Flag | Description | +| ---------------- | ------------------------------------------------------ | +| `--duration` | Run time in seconds (0 = infinite) | +| `--task` | Task description passed to the policy | +| `--display_data` | Stream observations/actions to Rerun for visualization | + +### Sentry (`--strategy.type=sentry`) + +Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient. + +Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes. 
+ +```bash +lerobot-rollout \ + --strategy.type=sentry \ + --strategy.upload_every_n_episodes=5 \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --dataset.repo_id=${HF_USER}/rollout_eval_data \ + --dataset.single_task="Put lego brick into the box" \ + --duration=3600 +``` + +| Flag | Description | +| -------------------------------------- | ----------------------------------------------------------- | +| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) | +| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) | +| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset | +| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) | + +### Highlight (`--strategy.type=highlight`) + +Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode. 
+ +```bash +lerobot-rollout \ + --strategy.type=highlight \ + --strategy.ring_buffer_seconds=30 \ + --strategy.save_key=s \ + --strategy.push_key=h \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=koch_follower \ + --robot.port=/dev/ttyACM0 \ + --dataset.repo_id=${HF_USER}/rollout_highlight_data \ + --dataset.single_task="Pick up the red cube" +``` + +**Keyboard controls:** + +| Key | Action | +| ------------------ | -------------------------------------------------------- | +| `s` (configurable) | Start recording (flushes buffer) / stop and save episode | +| `h` (configurable) | Push dataset to Hub | +| `ESC` | Stop the session | + +| Flag | Description | +| -------------------------------------- | ---------------------------------------------- | +| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) | +| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) | +| `--strategy.save_key` | Key to toggle recording (default: `s`) | +| `--strategy.push_key` | Key to push to Hub (default: `h`) | + +### DAgger (`--strategy.type=dagger`) + +Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`). + +See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough. + +**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode. 
+ +```bash +lerobot-rollout \ + --strategy.type=dagger \ + --strategy.num_episodes=20 \ + --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ + --robot.type=bi_openarm_follower \ + --teleop.type=openarm_mini \ + --dataset.repo_id=${HF_USER}/rollout_hil_data \ + --dataset.single_task="Fold the T-shirt" +``` + +**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry). + +```bash +lerobot-rollout \ + --strategy.type=dagger \ + --strategy.record_autonomous=true \ + --strategy.num_episodes=50 \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --teleop.type=so101_leader \ + --teleop.port=/dev/ttyACM1 \ + --dataset.repo_id=${HF_USER}/rollout_dagger_data \ + --dataset.single_task="Grasp the block" +``` + +**Keyboard controls** (default input device): + +| Key | Action | +| ------- | ------------------------------------------- | +| `Space` | Pause / resume policy execution | +| `Tab` | Start / stop human correction | +| `Enter` | Push dataset to Hub (corrections-only mode) | +| `ESC` | Stop the session | + +Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags. + +| Flag | Description | +| ------------------------------------ | ------------------------------------------------------- | +| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) | +| `--strategy.record_autonomous` | Record autonomous frames too (default: false) | +| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) | +| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) | +| `--teleop.type` | **Required.** Teleoperator type | + +--- + +## Inference Backends + +Select a backend with `--inference.type=`. All strategies work with both backends. 
+ +### Sync (default) + +One policy call per control tick. The main loop blocks until the action is computed. + +Works with all policies. No extra flags needed. + +### Real-Time Chunking (`--inference.type=rtc`) + +A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel. + +Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency. + +```bash +lerobot-rollout \ + --strategy.type=base \ + --inference.type=rtc \ + --inference.rtc.execution_horizon=10 \ + --inference.rtc.max_guidance_weight=10.0 \ + --policy.path=${HF_USER}/pi0_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --task="Pick up the cube" \ + --duration=60 \ + --device=cuda +``` + +| Flag | Description | +| ------------------------------------------- | -------------------------------------------------------------- | +| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) | +| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) | +| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` | +| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) | + +See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters. + +--- + +## Common Flags + +| Flag | Description | Default | +| --------------------------------- | ----------------------------------------------------------------- | ------- | +| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- | +| `--robot.type` | **Required.** Robot type (e.g. 
`so100_follower`, `koch_follower`) | -- | +| `--robot.port` | Serial port for the robot | -- | +| `--robot.cameras` | Camera configuration (JSON dict) | -- | +| `--fps` | Control loop frequency | 30 | +| `--duration` | Run time in seconds (0 = infinite) | 0 | +| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto | +| `--task` | Task description (used when no dataset is provided) | -- | +| `--display_data` | Stream telemetry to Rerun visualization | false | +| `--display_ip` / `--display_port` | Remote Rerun server address | -- | +| `--interpolation_multiplier` | Action interpolation factor | 1 | +| `--use_torch_compile` | Enable `torch.compile` for inference | false | +| `--resume` | Resume a previous recording session | false | +| `--play_sounds` | Vocal synthesis for events | true | + +--- + +## Programmatic Usage + +For custom deployments (e.g. with kinematics processors), use the rollout module API directly: + +```python +from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context +from lerobot.rollout.inference import SyncInferenceConfig +from lerobot.rollout.strategies import BaseStrategy +from lerobot.utils.process import ProcessSignalHandler + +cfg = RolloutConfig( + robot=my_robot_config, + policy=my_policy_config, + strategy=BaseStrategyConfig(), + inference=SyncInferenceConfig(), + fps=30, + duration=60, + task="my task", +) + +signal_handler = ProcessSignalHandler(use_threads=True) +ctx = build_rollout_context( + cfg, + signal_handler.shutdown_event, + robot_action_processor=my_custom_action_processor, # optional + robot_observation_processor=my_custom_obs_processor, # optional +) + +strategy = BaseStrategy(cfg.strategy) +try: + strategy.setup(ctx) + strategy.run(ctx) +finally: + strategy.teardown(ctx) +``` + +See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors. 
diff --git a/docs/source/rename_map.mdx b/docs/source/rename_map.mdx index 6249faaca..16ee6344a 100644 --- a/docs/source/rename_map.mdx +++ b/docs/source/rename_map.mdx @@ -61,17 +61,6 @@ lerobot-eval \ --rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}' ``` -### Recording - -`lerobot-record` also supports rename maps, nested under the dataset config: - -```bash -lerobot-record \ # When running inference - --policy.path="/smolVLA_finetuned" \ - ... \ - --dataset.rename_map='{"observation.images.glove2": "observation.images.image"}' -``` - ## Alternative: edit the policy config directly If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed. @@ -105,10 +94,10 @@ XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset ## Quick reference -| Goal | What to do | -| ----------------------------------------- | --------------------------------------------------------------------------- | -| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` | -| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` | -| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. 
| -| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) | -| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source | +| Goal | What to do | +| --------------------------------------- | --------------------------------------------------------------------------- | +| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` | +| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` | +| Rollout with different keys (inference) | `--rename_map='{"source_key": "policy_key", ...}'`. | +| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) | +| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source | diff --git a/docs/source/rtc.mdx b/docs/source/rtc.mdx index 9485d8b66..eadc34344 100644 --- a/docs/source/rtc.mdx +++ b/docs/source/rtc.mdx @@ -34,7 +34,7 @@ pip install -e ".[smolvla]" ### Using RTC with Pi0 -You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py). +You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots. 
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline: ```python @@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar ## Testing RTC with a Real Robot ```bash -python examples/rtc/eval_with_real_robot.py \ +lerobot-rollout \ + --strategy.type=base \ --policy.path=${HF_USERNAME}/policy_repo_id \ + --inference.type=rtc \ + --inference.rtc.execution_horizon=10 \ + --inference.rtc.max_guidance_weight=10.0 \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58FA0834591 \ --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ @@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer() # ... create plots ``` -See `examples/rtc/eval_dataset.py` for a complete example of visualization. +See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization. ## References diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index 2e615085e..69965a561 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -274,7 +274,8 @@ python src/lerobot/scripts/lerobot_train.py \ Once trained, we recommend deploying policies using inference-time RTC: ```bash -python examples/rtc/eval_with_real_robot.py \ +lerobot-rollout \ + --strategy.type=base \ --policy.path=your-username/your-repo-id \ --policy.device=cuda \ --robot.type=unitree_g1 \ @@ -284,7 +285,7 @@ python examples/rtc/eval_with_real_robot.py \ --task="task_description" \ --duration=1000 \ --fps=30 \ - --rtc.enabled=true + --inference.type=rtc ``` --- diff --git a/examples/hil/hil_data_collection.py b/examples/hil/hil_data_collection.py deleted file mode 100644 index 09a36dbe1..000000000 --- a/examples/hil/hil_data_collection.py +++ /dev/null @@ -1,1184 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Human-in-the-Loop (HIL) Data Collection with optional Real-Time Chunking (RTC). - -Implements the RaC paradigm (https://arxiv.org/abs/2509.07953) for LeRobot. By default uses synchronous -inference (best for fast models like ACT / Diffusion Policy). Set --rtc.enabled=true for -asynchronous background inference (recommended for large models like Pi0 / Pi0.5 / SmolVLA). - -The workflow: -1. Policy runs autonomously -2. Press SPACE to pause - robot holds position -3. Press 'c' to take control - human provides RECOVERY + CORRECTION -4. Press 'p' to hand control back to policy and continue recording -5. Press → to end episode (save and continue to next) -6. 
Reset, then do next rollout - -Keyboard Controls: - SPACE - Pause policy (robot holds position, no recording) - c - Take control (start correction, recording resumes) - p - Resume policy after pause/correction (recording continues) - → - End episode (save and continue to next) - ← - Re-record episode - ESC - Stop recording and push dataset to hub - -Usage: - # Standard synchronous inference (ACT, Diffusion Policy) - python examples/hil/hil_data_collection.py \ - --robot.type=bi_openarm_follower \ - --teleop.type=openarm_mini \ - --policy.path=path/to/pretrained_model \ - --dataset.repo_id=user/hil-dataset \ - --dataset.single_task="Fold the T-shirt properly" \ - --dataset.fps=30 \ - --interpolation_multiplier=2 - - # With RTC for large models (Pi0, Pi0.5, SmolVLA) - python examples/hil/hil_data_collection.py \ - --rtc.enabled=true \ - --rtc.execution_horizon=20 \ - --rtc.max_guidance_weight=5.0 \ - --rtc.prefix_attention_schedule=LINEAR \ - --robot.type=bi_openarm_follower \ - --teleop.type=openarm_mini \ - --policy.path=path/to/pretrained_model \ - --dataset.repo_id=user/hil-dataset \ - --dataset.single_task="Fold the T-shirt properly" \ - --dataset.fps=30 \ - --interpolation_multiplier=3 - - # RTC with bi_openarm_follower + OpenArm Mini teleop and pi0.5 policy - python examples/hil/hil_data_collection.py \ - --policy.path=lerobot-data-collection/folding_final \ - --robot.type=bi_openarm_follower \ - --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \ - --robot.left_arm_config.port=can0 \ - --robot.left_arm_config.side=left \ - --robot.left_arm_config.can_interface=socketcan \ - --robot.left_arm_config.disable_torque_on_disconnect=true \ - --robot.left_arm_config.max_relative_target=8.0 \ - --robot.right_arm_config.port=can1 \ 
- --robot.right_arm_config.side=right \ - --robot.right_arm_config.can_interface=socketcan \ - --robot.right_arm_config.disable_torque_on_disconnect=true \ - --robot.right_arm_config.max_relative_target=8.0 \ - --teleop.type=openarm_mini \ - --teleop.port_left=/dev/ttyACM1 \ - --teleop.port_right=/dev/ttyACM0 \ - --dataset.repo_id=lerobot-data-collection/hil_folding \ - --dataset.single_task="Fold the T-shirt properly" \ - --dataset.fps=30 \ - --dataset.num_episodes=50 \ - --rtc.enabled=true \ - --rtc.execution_horizon=20 \ - --rtc.max_guidance_weight=5.0 \ - --rtc.prefix_attention_schedule=LINEAR \ - --interpolation_multiplier=3 \ - --calibrate=true \ - --device=cuda -""" - -import logging -import math -import time -from dataclasses import dataclass, field -from pprint import pformat -from threading import Event, Lock, Thread -from typing import Any - -import torch -from hil_utils import ( - HILDatasetConfig, - init_keyboard_listener, - make_identity_processors, - print_controls, - reset_loop, - teleop_disable_torque, - teleop_smooth_move_to, -) - -from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.common.control_utils import is_headless, predict_action -from lerobot.configs import PreTrainedConfig, parser -from lerobot.datasets import ( - LeRobotDataset, - VideoEncodingManager, - aggregate_pipeline_dataset_features, - create_initial_features, - safe_stop_image_writer, -) -from lerobot.policies import PreTrainedPolicy, get_policy_class, make_policy, make_pre_post_processors -from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig -from lerobot.policies.utils import make_robot_action -from lerobot.processor import ( - NormalizerProcessorStep, - PolicyProcessorPipeline, - RelativeActionsProcessorStep, - TransitionKey, - create_transition, - rename_stats, - to_relative_actions, -) -from lerobot.robots import Robot, RobotConfig, 
make_robot_from_config -from lerobot.robots.bi_openarm_follower import BiOpenArmFollowerConfig -from lerobot.robots.so_follower import SOFollowerRobotConfig # noqa: F401 -from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config -from lerobot.teleoperators.openarm_mini import OpenArmMiniConfig # noqa: F401 -from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig # noqa: F401 -from lerobot.utils import get_safe_torch_device -from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR -from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features -from lerobot.utils.robot_utils import precise_sleep -from lerobot.utils.utils import init_logging, log_say -from lerobot.utils.visualization_utils import init_rerun, log_rerun_data - -logger = logging.getLogger(__name__) - - -# RTC helpers - - -class ThreadSafeRobot: - """Thread-safe wrapper for robot operations (used with RTC background thread).""" - - def __init__(self, robot: Robot): - self._robot = robot - self._lock = Lock() - - def get_observation(self) -> dict[str, Any]: - with self._lock: - return self._robot.get_observation() - - def send_action(self, action: dict) -> None: - with self._lock: - self._robot.send_action(action) - - @property - def observation_features(self) -> dict: - return self._robot.observation_features - - @property - def action_features(self) -> dict: - return self._robot.action_features - - @property - def name(self) -> str: - return self._robot.name - - @property - def robot_type(self) -> str: - return self._robot.robot_type - - @property - def cameras(self): - return getattr(self._robot, "cameras", {}) - - -def _set_openarm_max_relative_target_if_missing( - robot_cfg: RobotConfig, max_relative_target: float = 8.0 -) -> None: - """Set a safe default max_relative_target for OpenArm followers when not provided.""" - if isinstance(robot_cfg, BiOpenArmFollowerConfig): - if 
robot_cfg.left_arm_config.max_relative_target is None: - robot_cfg.left_arm_config.max_relative_target = max_relative_target - if robot_cfg.right_arm_config.max_relative_target is None: - robot_cfg.right_arm_config.max_relative_target = max_relative_target - - -def _reanchor_relative_rtc_prefix( - prev_actions_absolute: torch.Tensor, - current_state: torch.Tensor, - relative_step: RelativeActionsProcessorStep | None, - normalizer_step: NormalizerProcessorStep | None, - policy_device: torch.device | str, -) -> torch.Tensor: - """Convert absolute leftovers into model space for relative-action RTC policies.""" - if relative_step is None: - return prev_actions_absolute.to(policy_device) - - state = current_state.detach().cpu() - if state.dim() == 1: - state = state.unsqueeze(0) - - action_cpu = prev_actions_absolute.detach().cpu() - mask = relative_step._build_mask(action_cpu.shape[-1]) - relative_actions = to_relative_actions(action_cpu, state, mask) - - transition = create_transition(action=relative_actions) - if normalizer_step is not None: - transition = normalizer_step(transition) - - return transition[TransitionKey.ACTION].to(policy_device) - - -def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor: - """Pad/truncate RTC prefix actions to a fixed length for stable compiled inference.""" - if prev_actions.ndim != 2: - raise ValueError(f"Expected prev_actions to be 2D [T, A], got shape={tuple(prev_actions.shape)}") - - steps, action_dim = prev_actions.shape - if steps == target_steps: - return prev_actions - if steps > target_steps: - return prev_actions[:target_steps] - - padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device) - padded[:steps] = prev_actions - return padded - - -def _resolve_action_key_order(cfg, dataset_action_names: list[str]) -> list[str]: - """Choose action name ordering used to map policy tensor outputs to robot action dict.""" - policy_action_names = 
getattr(cfg.policy, "action_feature_names", None) - if not policy_action_names: - return dataset_action_names - - policy_action_names = list(policy_action_names) - if len(policy_action_names) != len(dataset_action_names): - logger.warning( - "[RTC] policy.action_feature_names length (%d) != dataset action dim (%d); " - "falling back to dataset order", - len(policy_action_names), - len(dataset_action_names), - ) - return dataset_action_names - - if set(dataset_action_names) != set(policy_action_names): - logger.warning( - "[RTC] policy.action_feature_names keys do not match dataset action keys; " - "falling back to dataset order" - ) - return dataset_action_names - - return policy_action_names - - -def _resolve_state_joint_order( - policy_action_names: list[str] | None, - available_joint_names: list[str], -) -> list[str]: - """Resolve joint-state ordering used to build observation.state.""" - if not policy_action_names: - return available_joint_names - - policy_action_names = list(policy_action_names) - available_set = set(available_joint_names) - policy_set = set(policy_action_names) - - if len(policy_action_names) != len(available_joint_names) or policy_set != available_set: - logger.warning( - "policy.action_feature_names does not match available state joints; " - "falling back to robot observation order" - ) - return available_joint_names - - logger.info("Using policy.action_feature_names order for observation.state mapping") - return policy_action_names - - -def _start_pedal_listener(events: dict): - """Start foot pedal listener thread if evdev is available. - - Pedal input is restricted to HIL control handoff only: - policy -> pause -> takeover -> resume policy. - Episode save/advance remains keyboard-only (right arrow). 
- """ - import threading - - try: - from evdev import InputDevice, categorize, ecodes - except ImportError: - logging.warning("[Pedal] evdev not installed - pedal support disabled") - return - - pedal_device = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" - key_left = "KEY_A" - key_right = "KEY_C" - - def pedal_reader(): - try: - dev = InputDevice(pedal_device) - logger.info(f"[Pedal] Connected: {dev.name}") - - for ev in dev.read_loop(): - if ev.type != ecodes.EV_KEY: - continue - - key = categorize(ev) - code = key.keycode - if isinstance(code, (list, tuple)): - code = code[0] - - if key.keystate != 1: - continue - - if events["in_reset"]: - if code in [key_left, key_right]: - events["start_next_episode"] = True - else: - if code not in [key_left, key_right]: - continue - - if events["correction_active"]: - events["resume_policy"] = True - elif events["policy_paused"]: - events["start_next_episode"] = True - else: - events["policy_paused"] = True - - except FileNotFoundError: - logging.info(f"[Pedal] Device not found: {pedal_device}") - except PermissionError: - logging.warning(f"[Pedal] Permission denied for {pedal_device}") - except Exception as e: - logging.warning(f"[Pedal] Error: {e}") - - thread = threading.Thread(target=pedal_reader, daemon=True) - thread.start() - - -def _rtc_inference_thread( - policy: PreTrainedPolicy, - obs_holder: dict, - obs_lock: Lock, - hw_features: dict, - preprocessor: PolicyProcessorPipeline, - postprocessor: PolicyProcessorPipeline, - queue_holder: dict, - shutdown_event: Event, - policy_active: Event, - compile_warmup_done: Event, - cfg, -): - """Background thread for RTC action chunk generation.""" - latency_tracker = LatencyTracker() - time_per_chunk = 1.0 / cfg.dataset.fps - threshold = 30 - policy_device = policy.config.device - stats_window_start = time.perf_counter() - policy_inference_count = 0 - latency_sum_s = 0.0 - inference_count = 0 - warmup_required = max(1, int(cfg.compile_warmup_inferences)) if 
cfg.use_torch_compile else 0 - - relative_step = next( - ( - step - for step in preprocessor.steps - if isinstance(step, RelativeActionsProcessorStep) and step.enabled - ), - None, - ) - normalizer_step = next( - (step for step in preprocessor.steps if isinstance(step, NormalizerProcessorStep)), - None, - ) - if relative_step is not None: - if relative_step.action_names is None: - cfg_action_names = getattr(cfg.policy, "action_feature_names", None) - if cfg_action_names: - relative_step.action_names = list(cfg_action_names) - else: - fallback_action_names = obs_holder.get("action_feature_names") - if fallback_action_names: - relative_step.action_names = list(fallback_action_names) - logger.info("[RTC] Relative actions enabled: re-anchoring RTC prefix to current state") - - while not shutdown_event.is_set(): - if not policy_active.is_set(): - time.sleep(0.01) - continue - - queue = queue_holder.get("queue") - with obs_lock: - obs = obs_holder.get("obs") - if queue is None or obs is None: - time.sleep(0.01) - continue - - if queue.qsize() <= threshold: - try: - current_time = time.perf_counter() - idx_before = queue.get_action_index() - prev_actions = queue.get_left_over() - - latency = latency_tracker.max() - delay = math.ceil(latency / time_per_chunk) if latency else 0 - - obs_batch = build_dataset_frame(hw_features, obs, prefix="observation") - for name in obs_batch: - obs_batch[name] = torch.from_numpy(obs_batch[name]) - if "image" in name: - obs_batch[name] = obs_batch[name].float() / 255 - obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous() - obs_batch[name] = obs_batch[name].unsqueeze(0).to(policy_device) - - obs_batch["task"] = [cfg.dataset.single_task] - obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown") - - preprocessed = preprocessor(obs_batch) - - if prev_actions is not None and relative_step is not None and OBS_STATE in obs_batch: - prev_actions_absolute = queue.get_processed_left_over() - if prev_actions_absolute is not None 
and prev_actions_absolute.numel() > 0: - prev_actions = _reanchor_relative_rtc_prefix( - prev_actions_absolute=prev_actions_absolute, - current_state=obs_batch[OBS_STATE], - relative_step=relative_step, - normalizer_step=normalizer_step, - policy_device=policy_device, - ) - - if prev_actions is not None: - prev_actions = _normalize_prev_actions_length( - prev_actions, target_steps=cfg.rtc.execution_horizon - ) - - actions = policy.predict_action_chunk( - preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions - ) - - original = actions.squeeze(0).clone() - processed = postprocessor(actions).squeeze(0) - new_latency = time.perf_counter() - current_time - new_delay = math.ceil(new_latency / time_per_chunk) - inference_count += 1 - is_warmup_inference = cfg.use_torch_compile and inference_count <= warmup_required - if is_warmup_inference: - latency_tracker.reset() - else: - latency_tracker.add(new_latency) - queue.merge(original, processed, new_delay, idx_before) - policy_inference_count += 1 - latency_sum_s += new_latency - if ( - is_warmup_inference - and inference_count >= warmup_required - and not compile_warmup_done.is_set() - ): - compile_warmup_done.set() - logger.info( - "[RTC] Compile warmup complete (%d/%d inferences)", - inference_count, - warmup_required, - ) - logger.debug("[RTC] Inference latency=%.2fs, queue=%d", new_latency, queue.qsize()) - except Exception as e: - logger.error("[RTC] Error: %s", e) - time.sleep(0.5) - else: - time.sleep(0.01) - - now = time.perf_counter() - if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s: - policy_hz = policy_inference_count / window_elapsed - avg_latency_ms = ( - (latency_sum_s / policy_inference_count * 1000.0) if policy_inference_count else 0.0 - ) - logger.info( - "[HIL RTC rates] policy=%.1f Hz | avg_inference=%.1f ms | queue=%d", - policy_hz, - avg_latency_ms, - queue.qsize(), - ) - stats_window_start = now - policy_inference_count = 0 - latency_sum_s = 0.0 
- - -# Config - - -@dataclass -class HILConfig: - robot: RobotConfig - teleop: TeleoperatorConfig - dataset: HILDatasetConfig - policy: PreTrainedConfig | None = None - rtc: RTCConfig = field(default_factory=RTCConfig) - interpolation_multiplier: int = 2 - record_interpolated_actions: bool = False - display_data: bool = True - play_sounds: bool = True - resume: bool = False - device: str = "cuda" - use_torch_compile: bool = False - compile_warmup_inferences: int = 2 - calibrate: bool = False - log_hz: bool = True - hz_log_interval_s: float = 2.0 - - def __post_init__(self): - policy_path = parser.get_path_arg("policy") - if policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - if self.policy is None: - raise ValueError("policy.path is required") - - @classmethod - def __get_path_fields__(cls) -> list[str]: - return ["policy"] - - -# Rollout loops - - -@safe_stop_image_writer -def _rollout_sync( - robot: Robot, - teleop: Teleoperator, - policy: PreTrainedPolicy, - preprocessor: PolicyProcessorPipeline, - postprocessor: PolicyProcessorPipeline, - dataset: LeRobotDataset, - events: dict, - cfg: HILConfig, -): - """Rollout loop with standard synchronous inference.""" - fps = cfg.dataset.fps - device = get_safe_torch_device(cfg.device) - stream_online = bool(cfg.dataset.streaming_encoding) - record_stride = 1 if cfg.record_interpolated_actions else max(1, cfg.interpolation_multiplier) - - policy.reset() - preprocessor.reset() - postprocessor.reset() - - frame_buffer: list[dict] = [] - teleop_disable_torque(teleop) - - was_paused = False - waiting_for_takeover = False - last_action: dict[str, Any] | None = None - robot_action: dict[str, Any] = {} - action_keys = list(dataset.features[ACTION]["names"]) - obs_state_names = list(dataset.features[f"{OBS_STR}.state"]["names"]) - obs_image_names = [ - 
key.removeprefix(f"{OBS_STR}.images.") - for key in dataset.features - if key.startswith(f"{OBS_STR}.images.") - ] - - interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) - control_interval = interpolator.get_control_interval(fps) - - timestamp = 0.0 - record_tick = 0 - start_t = time.perf_counter() - stats_window_start = start_t - policy_inference_count = 0 - robot_command_count = 0 - - while timestamp < cfg.dataset.episode_time_s: - loop_start = time.perf_counter() - - if events["exit_early"]: - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - events["resume_policy"] = False - break - - if events["resume_policy"] and ( - events["policy_paused"] or events["correction_active"] or waiting_for_takeover - ): - events["resume_policy"] = False - events["start_next_episode"] = False - events["policy_paused"] = False - events["correction_active"] = False - waiting_for_takeover = False - was_paused = False - last_action = None - interpolator.reset() - policy.reset() - preprocessor.reset() - postprocessor.reset() - - if events["policy_paused"] and not was_paused: - obs = robot.get_observation() - robot_pos = { - k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features - } - teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) - events["start_next_episode"] = False - waiting_for_takeover = True - was_paused = True - interpolator.reset() - - if waiting_for_takeover and events["start_next_episode"]: - teleop_disable_torque(teleop) - events["start_next_episode"] = False - events["correction_active"] = True - waiting_for_takeover = False - - obs = robot.get_observation() - obs_filtered = {k: obs[k] for k in obs_state_names if k in obs} - obs_filtered.update({k: obs[k] for k in obs_image_names if k in obs}) - obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) - - if events["correction_active"]: - robot_action = teleop.get_action() - 
robot.send_action(robot_action) - robot_command_count += 1 - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - if record_tick % record_stride == 0: - frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task} - if stream_online: - dataset.add_frame(frame) - else: - frame_buffer.append(frame) - record_tick += 1 - - elif waiting_for_takeover or events["policy_paused"]: - if last_action: - robot.send_action(last_action) - robot_command_count += 1 - - else: - if interpolator.needs_new_action(): - action_values = predict_action( - observation=obs_frame, - policy=policy, - device=device, - preprocessor=preprocessor, - postprocessor=postprocessor, - use_amp=policy.config.use_amp, - task=cfg.dataset.single_task, - robot_type=robot.robot_type, - ) - policy_inference_count += 1 - robot_action = make_robot_action(action_values, dataset.features) - action_tensor = torch.tensor([robot_action[k] for k in action_keys]) - interpolator.add(action_tensor) - - interp_action = interpolator.get() - if interp_action is not None: - robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)} - robot.send_action(robot_action) - robot_command_count += 1 - last_action = robot_action - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - if record_tick % record_stride == 0: - frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task} - if stream_online: - dataset.add_frame(frame) - else: - frame_buffer.append(frame) - record_tick += 1 - - if cfg.display_data and robot_action: - log_rerun_data(observation=obs_filtered, action=robot_action) - - dt = time.perf_counter() - loop_start - if (sleep_time := control_interval - dt) > 0: - precise_sleep(sleep_time) - now = time.perf_counter() - timestamp = now - start_t - - if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s: - policy_hz = policy_inference_count / window_elapsed - robot_hz = robot_command_count 
/ window_elapsed - logger.info( - "[HIL rates] policy=%.1f Hz (target=%.1f) | robot=%.1f Hz (target=%.1f)", - policy_hz, - fps, - robot_hz, - fps * cfg.interpolation_multiplier, - ) - stats_window_start = now - policy_inference_count = 0 - robot_command_count = 0 - - teleop_disable_torque(teleop) - - if not stream_online: - for frame in frame_buffer: - dataset.add_frame(frame) - - -@safe_stop_image_writer -def _rollout_rtc( - robot, - teleop: Teleoperator, - policy: PreTrainedPolicy, - preprocessor: PolicyProcessorPipeline, - postprocessor: PolicyProcessorPipeline, - dataset: LeRobotDataset, - events: dict, - cfg: HILConfig, - queue_holder: dict, - obs_holder: dict, - obs_lock: Lock, - policy_active: Event, - compile_warmup_done: Event, - hw_features: dict, -): - """Rollout loop with RTC for asynchronous inference.""" - fps = cfg.dataset.fps - stream_online = bool(cfg.dataset.streaming_encoding) - record_stride = 1 if cfg.record_interpolated_actions else max(1, cfg.interpolation_multiplier) - - policy.reset() - preprocessor.reset() - postprocessor.reset() - - frame_buffer: list[dict] = [] - teleop_disable_torque(teleop) - - was_paused = False - waiting_for_takeover = False - last_action: dict[str, Any] | None = None - dataset_action_keys = list(dataset.features[ACTION]["names"]) - action_keys = _resolve_action_key_order(cfg, dataset_action_keys) - if action_keys != dataset_action_keys: - logger.info("[RTC] Using policy.action_feature_names order for action tensor mapping") - else: - logger.info("[RTC] Using dataset action feature order for action tensor mapping") - obs_state_names = list(dataset.features[f"{OBS_STR}.state"]["names"]) - obs_image_names = [ - key.removeprefix(f"{OBS_STR}.images.") - for key in dataset.features - if key.startswith(f"{OBS_STR}.images.") - ] - - interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) - control_interval = interpolator.get_control_interval(fps) - - robot_action: dict[str, Any] = {} - timestamp = 0.0 - 
start_t = time.perf_counter() - stats_window_start = start_t - robot_command_count = 0 - record_tick = 0 - obs_poll_interval = 1.0 / fps - last_obs_poll_t = 0.0 - obs_filtered: dict[str, Any] = {} - obs_frame: dict[str, Any] = {} - warmup_wait_logged = False - warmup_queue_flushed = False - - while timestamp < cfg.dataset.episode_time_s: - loop_start = time.perf_counter() - - if events["exit_early"]: - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - events["resume_policy"] = False - break - - if events["resume_policy"] and ( - events["policy_paused"] or events["correction_active"] or waiting_for_takeover - ): - events["resume_policy"] = False - events["start_next_episode"] = False - events["policy_paused"] = False - events["correction_active"] = False - waiting_for_takeover = False - was_paused = False - last_action = None - interpolator.reset() - queue_holder["queue"] = ActionQueue(cfg.rtc) - policy_active.clear() - policy.reset() - preprocessor.reset() - postprocessor.reset() - - if events["policy_paused"] and not was_paused: - policy_active.clear() - obs = robot.get_observation() - robot_pos = { - k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features - } - teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) - events["start_next_episode"] = False - waiting_for_takeover = True - was_paused = True - interpolator.reset() - - if waiting_for_takeover and events["start_next_episode"]: - teleop_disable_torque(teleop) - events["start_next_episode"] = False - events["correction_active"] = True - waiting_for_takeover = False - queue_holder["queue"] = ActionQueue(cfg.rtc) - - now_for_obs = time.perf_counter() - should_poll_obs = ( - not obs_filtered - or (now_for_obs - last_obs_poll_t) >= obs_poll_interval - or events["correction_active"] - or waiting_for_takeover - or events["policy_paused"] - ) - if should_poll_obs: - obs = robot.get_observation() - obs_filtered = {k: 
obs[k] for k in obs_state_names if k in obs} - obs_filtered.update({k: obs[k] for k in obs_image_names if k in obs}) - obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) - with obs_lock: - obs_holder["obs"] = obs_filtered - last_obs_poll_t = now_for_obs - - if events["correction_active"]: - robot_action = teleop.get_action() - robot.send_action(robot_action) - robot_command_count += 1 - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - if record_tick % record_stride == 0: - frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task} - if stream_online: - dataset.add_frame(frame) - else: - frame_buffer.append(frame) - record_tick += 1 - - elif waiting_for_takeover or events["policy_paused"]: - if last_action: - robot.send_action(last_action) - robot_command_count += 1 - - else: - if not policy_active.is_set(): - policy_active.set() - - if cfg.use_torch_compile and not compile_warmup_done.is_set(): - if not warmup_wait_logged: - logger.info( - "[RTC] Waiting for compile warmup (%d inferences) before policy rollout", - max(1, int(cfg.compile_warmup_inferences)), - ) - warmup_wait_logged = True - else: - if cfg.use_torch_compile and not warmup_queue_flushed: - queue_holder["queue"] = ActionQueue(cfg.rtc) - interpolator.reset() - warmup_queue_flushed = True - logger.info("[RTC] Warmup queue cleared; starting live policy rollout") - - queue = queue_holder["queue"] - - if interpolator.needs_new_action(): - new_action = queue.get() if queue else None - if new_action is not None: - interpolator.add(new_action.cpu()) - - action_tensor = interpolator.get() - if action_tensor is not None: - robot_action = { - k: action_tensor[i].item() - for i, k in enumerate(action_keys) - if i < len(action_tensor) - } - robot.send_action(robot_action) - robot_command_count += 1 - last_action = robot_action - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - if record_tick % 
record_stride == 0: - frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task} - if stream_online: - dataset.add_frame(frame) - else: - frame_buffer.append(frame) - record_tick += 1 - - dt = time.perf_counter() - loop_start - if (sleep_time := control_interval - dt) > 0: - precise_sleep(sleep_time) - now = time.perf_counter() - timestamp = now - start_t - - if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s: - robot_hz = robot_command_count / window_elapsed - logger.info( - "[HIL RTC rates] robot=%.1f Hz (target=%.1f)", - robot_hz, - fps * cfg.interpolation_multiplier, - ) - stats_window_start = now - robot_command_count = 0 - - policy_active.clear() - teleop_disable_torque(teleop) - - if not stream_online: - for frame in frame_buffer: - dataset.add_frame(frame) - - -# Main collection function - - -@parser.wrap() -def hil_collect(cfg: HILConfig) -> LeRobotDataset: - """Main HIL data collection function (supports both sync and RTC modes).""" - init_logging() - logger.info(pformat(cfg.__dict__)) - - use_rtc = cfg.rtc.enabled - - if use_rtc: - _set_openarm_max_relative_target_if_missing(cfg.robot, max_relative_target=8.0) - - if cfg.display_data: - init_rerun(session_name="hil_collection") - - robot_raw = make_robot_from_config(cfg.robot) - teleop = make_teleoperator_from_config(cfg.teleop) - - teleop_proc, obs_proc = make_identity_processors() - - action_features_hw = {k: v for k, v in robot_raw.action_features.items() if k.endswith(".pos")} - all_observation_features = robot_raw.observation_features - available_joint_names = [ - key for key, value in all_observation_features.items() if key.endswith(".pos") and value is float - ] - ordered_joint_names = _resolve_state_joint_order( - getattr(cfg.policy, "action_feature_names", None), - available_joint_names, - ) - observation_features_hw = { - joint_name: all_observation_features[joint_name] for joint_name in ordered_joint_names - } - for key, value in 
all_observation_features.items(): - if isinstance(value, tuple): - observation_features_hw[key] = value - - dataset_features = combine_feature_dicts( - aggregate_pipeline_dataset_features( - pipeline=teleop_proc, - initial_features=create_initial_features(action=action_features_hw), - use_videos=cfg.dataset.video, - ), - aggregate_pipeline_dataset_features( - pipeline=obs_proc, - initial_features=create_initial_features(observation=observation_features_hw), - use_videos=cfg.dataset.video, - ), - ) - - dataset = None - listener = None - shutdown_event = Event() - policy_active = Event() - compile_warmup_done = Event() - if not cfg.use_torch_compile: - compile_warmup_done.set() - rtc_thread = None - - try: - if cfg.resume: - dataset = LeRobotDataset( - cfg.dataset.repo_id, - root=cfg.dataset.root, - batch_encoding_size=cfg.dataset.video_encoding_batch_size, - vcodec=cfg.dataset.vcodec, - streaming_encoding=cfg.dataset.streaming_encoding, - encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, - encoder_threads=cfg.dataset.encoder_threads, - ) - if hasattr(robot_raw, "cameras") and robot_raw.cameras: - dataset.start_image_writer( - num_processes=cfg.dataset.num_image_writer_processes, - num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras), - ) - else: - dataset = LeRobotDataset.create( - cfg.dataset.repo_id, - cfg.dataset.fps, - root=cfg.dataset.root, - robot_type=robot_raw.name, - features=dataset_features, - use_videos=cfg.dataset.video, - image_writer_processes=cfg.dataset.num_image_writer_processes, - image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera - * len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []), - batch_encoding_size=cfg.dataset.video_encoding_batch_size, - vcodec=cfg.dataset.vcodec, - streaming_encoding=cfg.dataset.streaming_encoding, - encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, - encoder_threads=cfg.dataset.encoder_threads, - ) - - # Load policy — RTC needs manual loading 
for predict_action_chunk support - if use_rtc: - policy_class = get_policy_class(cfg.policy.type) - policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) - if hasattr(policy_config, "compile_model"): - policy_config.compile_model = cfg.use_torch_compile - policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config) - policy.config.rtc_config = cfg.rtc - if hasattr(policy, "init_rtc_processor"): - policy.init_rtc_processor() - policy = policy.to(cfg.device) - policy.eval() - else: - policy = make_policy(cfg.policy, ds_meta=dataset.meta) - - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), - preprocessor_overrides={ - "device_processor": {"device": cfg.device}, - "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, - }, - ) - - # Connect hardware - if use_rtc: - logger.info("Connecting robot (calibrate=%s)", cfg.calibrate) - robot_raw.connect(calibrate=False) - if cfg.calibrate and hasattr(robot_raw, "calibrate"): - robot_raw.calibrate() - robot_raw.disconnect() - robot_raw.connect(calibrate=False) - else: - robot_raw.connect() - - robot = ThreadSafeRobot(robot_raw) if use_rtc else robot_raw - teleop.connect() - listener, events = init_keyboard_listener() - - # RTC-specific setup - queue_holder = None - obs_holder = None - obs_lock = Lock() - hw_features = None - if use_rtc: - _start_pedal_listener(events) - queue_holder = {"queue": ActionQueue(cfg.rtc)} - obs_holder = { - "obs": None, - "robot_type": robot.robot_type, - "action_feature_names": [key for key in robot.action_features if key.endswith(".pos")], - } - hw_features = hw_to_dataset_features(observation_features_hw, "observation") - - rtc_thread = Thread( - target=_rtc_inference_thread, - args=( - policy, - obs_holder, - obs_lock, - hw_features, - preprocessor, - postprocessor, - 
queue_holder, - shutdown_event, - policy_active, - compile_warmup_done, - cfg, - ), - daemon=True, - ) - rtc_thread.start() - - print_controls(rtc=use_rtc) - logger.info(f" Policy: {cfg.policy.pretrained_path}") - logger.info(f" Task: {cfg.dataset.single_task}") - logger.info(f" Interpolation: {cfg.interpolation_multiplier}x") - if use_rtc: - logger.info(f" RTC: enabled (execution_horizon={cfg.rtc.execution_horizon})") - - with VideoEncodingManager(dataset): - recorded = 0 - while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: - log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds) - - if use_rtc: - queue_holder["queue"] = ActionQueue(cfg.rtc) - _rollout_rtc( - robot=robot, - teleop=teleop, - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - dataset=dataset, - events=events, - cfg=cfg, - queue_holder=queue_holder, - obs_holder=obs_holder, - obs_lock=obs_lock, - policy_active=policy_active, - compile_warmup_done=compile_warmup_done, - hw_features=hw_features, - ) - else: - _rollout_sync( - robot=robot, - teleop=teleop, - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - dataset=dataset, - events=events, - cfg=cfg, - ) - - if events["rerecord_episode"]: - log_say("Re-recording", cfg.play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue - - dataset.save_episode() - recorded += 1 - - if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: - reset_loop(robot, teleop, events, cfg.dataset.fps) - - finally: - log_say("Stop recording", cfg.play_sounds, blocking=True) - - shutdown_event.set() - policy_active.clear() - - if rtc_thread and rtc_thread.is_alive(): - rtc_thread.join(timeout=2.0) - - if dataset: - dataset.finalize() - - if robot_raw.is_connected: - robot_raw.disconnect() - if teleop.is_connected: - teleop.disconnect() - - if not is_headless() and listener: - listener.stop() - - if 
cfg.dataset.push_to_hub and dataset is not None: - dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) - - return dataset - - -def main(): - from lerobot.utils.import_utils import register_third_party_plugins - - register_third_party_plugins() - hil_collect() - - -if __name__ == "__main__": - main() diff --git a/examples/hil/hil_utils.py b/examples/hil/hil_utils.py deleted file mode 100644 index 0f433d83a..000000000 --- a/examples/hil/hil_utils.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Shared utilities for Human-in-the-Loop data collection scripts.""" - -import logging -import time -from dataclasses import dataclass, field -from pathlib import Path - -from lerobot.common.control_utils import is_headless -from lerobot.processor import ( - IdentityProcessorStep, - RobotAction, - RobotObservation, - RobotProcessorPipeline, - observation_to_transition, - robot_action_observation_to_transition, - transition_to_observation, - transition_to_robot_action, -) -from lerobot.robots import Robot -from lerobot.teleoperators import Teleoperator -from lerobot.utils.robot_utils import precise_sleep - -logger = logging.getLogger(__name__) - - -@dataclass -class HILDatasetConfig: - repo_id: str - single_task: str - root: str | Path | None = None - fps: int = 30 - episode_time_s: float = 120 - num_episodes: int = 50 - video: bool = True - push_to_hub: bool = True - private: bool = False - tags: list[str] | None = None - num_image_writer_processes: int = 0 - num_image_writer_threads_per_camera: int = 4 - video_encoding_batch_size: int = 1 - vcodec: str = "auto" - streaming_encoding: bool = True - encoder_queue_maxsize: int = 30 - encoder_threads: int | None = None - rename_map: dict[str, str] = field(default_factory=dict) - - -def teleop_has_motor_control(teleop: Teleoperator) -> bool: - """Check if teleoperator has motor control capabilities.""" - return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions")) - - -def teleop_disable_torque(teleop: Teleoperator) -> None: - """Disable teleop torque if supported.""" - if hasattr(teleop, "disable_torque"): - teleop.disable_torque() - - -def teleop_enable_torque(teleop: Teleoperator) -> None: - """Enable teleop torque if supported.""" - if hasattr(teleop, "enable_torque"): - teleop.enable_torque() - - -def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50): - """Smoothly move teleop to target position if motor control is 
available.""" - if not teleop_has_motor_control(teleop): - logger.warning("Teleop does not support motor control - cannot mirror robot position") - return - - teleop_enable_torque(teleop) - current = teleop.get_action() - steps = max(int(duration_s * fps), 1) - - for step in range(steps + 1): - t = step / steps - interp = {} - for k in current: - if k in target_pos: - interp[k] = current[k] * (1 - t) + target_pos[k] * t - else: - interp[k] = current[k] - teleop.write_goal_positions(interp) - time.sleep(1 / fps) - - -def init_keyboard_listener(): - """Initialize keyboard listener with HIL controls.""" - events = { - "exit_early": False, - "rerecord_episode": False, - "stop_recording": False, - "policy_paused": False, - "correction_active": False, - "resume_policy": False, - "in_reset": False, - "start_next_episode": False, - } - - if is_headless(): - logger.warning("Headless environment - keyboard controls unavailable") - return None, events - - from pynput import keyboard - - def on_press(key): - try: - if events["in_reset"]: - if key in [keyboard.Key.space, keyboard.Key.right]: - logger.info("[HIL] Starting next episode...") - events["start_next_episode"] = True - elif hasattr(key, "char") and key.char == "c": - events["start_next_episode"] = True - elif key == keyboard.Key.esc: - logger.info("[HIL] ESC - Stop recording, pushing to hub...") - events["stop_recording"] = True - events["start_next_episode"] = True - else: - if key == keyboard.Key.space: - if not events["policy_paused"] and not events["correction_active"]: - logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy") - events["policy_paused"] = True - elif hasattr(key, "char") and key.char == "c": - if events["policy_paused"] and not events["correction_active"]: - logger.info("[HIL] Taking control...") - events["start_next_episode"] = True - elif hasattr(key, "char") and key.char == "p": - if events["policy_paused"] or events["correction_active"]: - logger.info("[HIL] Resuming 
policy...") - events["resume_policy"] = True - elif key == keyboard.Key.right: - logger.info("[HIL] End episode") - events["exit_early"] = True - elif key == keyboard.Key.left: - logger.info("[HIL] Re-record episode") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - logger.info("[HIL] ESC - Stop recording...") - events["stop_recording"] = True - events["exit_early"] = True - except Exception as e: - logger.info(f"Key error: {e}") - - listener = keyboard.Listener(on_press=on_press) - listener.start() - return listener, events - - -def make_identity_processors(): - """Create identity processors for recording.""" - teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( - steps=[IdentityProcessorStep()], - to_transition=robot_action_observation_to_transition, - to_output=transition_to_robot_action, - ) - obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation]( - steps=[IdentityProcessorStep()], - to_transition=observation_to_transition, - to_output=transition_to_observation, - ) - return teleop_proc, obs_proc - - -def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int): - """Reset period where human repositions environment.""" - logger.info("[HIL] RESET") - - events["in_reset"] = True - events["start_next_episode"] = False - - obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} - teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) - - logger.info("Press any key to enable teleoperation") - while not events["start_next_episode"] and not events["stop_recording"]: - precise_sleep(0.05) - - if events["stop_recording"]: - return - - events["start_next_episode"] = False - teleop_disable_torque(teleop) - logger.info("Teleop enabled - press any key to start episode") - - while not events["start_next_episode"] and not events["stop_recording"]: - loop_start = 
time.perf_counter() - action = teleop.get_action() - robot.send_action(action) - precise_sleep(1 / fps - (time.perf_counter() - loop_start)) - - events["in_reset"] = False - events["start_next_episode"] = False - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - events["resume_policy"] = False - - -def print_controls(rtc: bool = False): - """Print control instructions.""" - mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "") - logger.info( - "%s\n Controls:\n" - " SPACE - Pause policy\n" - " c - Take control\n" - " p - Resume policy after pause/correction\n" - " → - End episode\n" - " ESC - Stop and push to hub", - mode, - ) diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index d8c53829e..3ddcd1f14 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -14,17 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.common.control_utils import init_keyboard_listener +import logging +import time + +from lerobot.common.control_utils import init_keyboard_listener, predict_action from lerobot.datasets import LeRobotDataset from lerobot.policies import make_pre_post_processors from lerobot.policies.act import ACTPolicy +from lerobot.policies.utils import make_robot_action from lerobot.processor import make_default_processors from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig -from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.feature_utils import hw_to_dataset_features +from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features +from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import init_rerun +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data NUM_EPISODES = 2 FPS = 
30 @@ -35,6 +39,9 @@ HF_DATASET_ID = "/" def main(): + # NOTE: For production policy deployment, use `lerobot-rollout` CLI instead. + # This script provides a self-contained example for educational purposes. + # Create the robot configuration & robot robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") @@ -83,43 +90,67 @@ def main(): raise ValueError("Robot is not connected!") print("Starting evaluate loop...") + control_interval = 1 / FPS recorded_episodes = 0 while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - ) + # Inline evaluation loop: predict actions and send to robot + timestamp = 0 + start_episode_t = time.perf_counter() + while timestamp < EPISODE_TIME_SEC: + start_loop_t = time.perf_counter() + + if events["exit_early"]: + events["exit_early"] = False + break + + # Get robot observation + obs = robot.get_observation() + obs_processed = robot_observation_processor(obs) + observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR) + + # Predict action using the policy + action_tensor = predict_action( + observation=observation_frame, + policy=policy, + device=policy.config.device, + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.device.type == "cuda", + task=TASK_DESCRIPTION, + robot_type=robot.name, + ) + + # Convert policy output to robot action dict + action_values = make_robot_action(action_tensor, 
dataset.features) + + # Process and send action to robot + robot_action_to_send = robot_action_processor((action_values, obs)) + robot.send_action(robot_action_to_send) + + # Write to dataset + action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION) + frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION} + dataset.add_frame(frame) + + log_rerun_data(observation=obs_processed, action=action_values) + + dt_s = time.perf_counter() - start_loop_t + sleep_time_s = control_interval - dt_s + if sleep_time_s < 0: + logging.warning( + f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)." + ) + precise_sleep(max(sleep_time_s, 0.0)) + timestamp = time.perf_counter() - start_episode_t # Reset the environment if not stopping or re-recording if not events["stop_recording"] and ( (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] ): log_say("Reset the environment") - record_loop( - robot=robot, - events=events, - fps=FPS, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - ) + log_say("Waiting for environment reset, press right arrow key when ready...") if events["rerecord_episode"]: log_say("Re-record episode") diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index de5df7756..2c581f5ff 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -45,9 +45,6 @@ def main(): leader_arm = SO100Leader(leader_arm_config) keyboard = KeyboardTeleop(keyboard_config) - # TODO(Steven): Update this example to use pipelines - teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() - # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, ACTION) obs_features = 
hw_to_dataset_features(robot.observation_features, OBS_STR) @@ -77,6 +74,10 @@ def main(): if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: raise ValueError("Robot or teleop is not connected!") + teleop_action_processor, robot_action_processor, robot_observation_processor = ( + make_default_processors() + ) + print("Starting record loop...") recorded_episodes = 0 while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: @@ -87,14 +88,14 @@ def main(): robot=robot, events=events, fps=FPS, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, dataset=dataset, teleop=[leader_arm, keyboard], control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, ) # Reset the environment if not stopping or re-recording @@ -106,13 +107,13 @@ def main(): robot=robot, events=events, fps=FPS, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, teleop=[leader_arm, keyboard], control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, ) if events["rerecord_episode"]: diff --git a/examples/lekiwi/rollout.py b/examples/lekiwi/rollout.py new file mode 100644 index 000000000..4fb103c8c --- /dev/null +++ b/examples/lekiwi/rollout.py @@ -0,0 +1,77 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a trained policy on LeKiwi without recording (base rollout). + +Uses the rollout engine's :class:`BaseStrategy` (autonomous execution, +no dataset) with :class:`SyncInferenceConfig` (inline policy call per +control tick). For a CLI entry point with the same capabilities plus +recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``. +""" + +from lerobot.configs import PreTrainedConfig +from lerobot.robots.lekiwi import LeKiwiClientConfig +from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context +from lerobot.rollout.inference import SyncInferenceConfig +from lerobot.rollout.strategies import BaseStrategy +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.utils import init_logging + +FPS = 30 +DURATION_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" + + +def main(): + init_logging() + + # Robot: LeKiwi client — make sure lekiwi_host is already running on the robot. + robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") + + # Policy: load the pretrained config. ``pretrained_path`` is read downstream + # by ``build_rollout_context`` to reload the full model. + policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) + policy_config.pretrained_path = HF_MODEL_ID + + # Assemble the rollout config: base strategy (no recording) + sync inference. 
+ cfg = RolloutConfig( + robot=robot_config, + policy=policy_config, + strategy=BaseStrategyConfig(), + inference=SyncInferenceConfig(), + fps=FPS, + duration=DURATION_SEC, + task=TASK_DESCRIPTION, + ) + + # Graceful Ctrl-C: the strategy loop exits when shutdown_event is set. + signal_handler = ProcessSignalHandler(use_threads=True) + + # Build the context (connects robot, loads policy, wires the inference strategy). + # No custom processors here — LeKiwi runs on raw joint features. + ctx = build_rollout_context(cfg, signal_handler.shutdown_event) + + strategy = BaseStrategy(cfg.strategy) + try: + strategy.setup(ctx) + strategy.run(ctx) + finally: + strategy.teardown(ctx) + + +if __name__ == "__main__": + main() diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 267e67c48..e859123d0 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -14,13 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +import time + from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.common.control_utils import init_keyboard_listener +from lerobot.common.control_utils import init_keyboard_listener, predict_action from lerobot.configs import FeatureType, PolicyFeature from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features from lerobot.model.kinematics import RobotKinematics from lerobot.policies import make_pre_post_processors from lerobot.policies.act import ACTPolicy +from lerobot.policies.utils import make_robot_action from lerobot.processor import ( RobotProcessorPipeline, make_default_teleop_action_processor, @@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.scripts.lerobot_record import record_loop from lerobot.types import RobotAction, RobotObservation -from lerobot.utils.feature_utils import combine_feature_dicts +from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts +from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import init_rerun +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data NUM_EPISODES = 5 FPS = 30 @@ -49,6 +54,9 @@ HF_DATASET_ID = "/" def main(): + # NOTE: For production policy deployment, use `lerobot-rollout` CLI instead. + # This script provides a self-contained example for educational purposes. 
+ # Create the robot configuration & robot camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} robot_config = SO100FollowerConfig( @@ -143,43 +151,67 @@ def main(): raise ValueError("Robot is not connected!") print("Starting evaluate loop...") + control_interval = 1 / FPS episode_idx = 0 for episode_idx in range(NUM_EPISODES): log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) + # Inline evaluation loop: predict actions and send to robot + timestamp = 0 + start_episode_t = time.perf_counter() + while timestamp < EPISODE_TIME_SEC: + start_loop_t = time.perf_counter() + + if events["exit_early"]: + events["exit_early"] = False + break + + # Get robot observation + obs = robot.get_observation() + obs_processed = robot_joints_to_ee_pose_processor(obs) + observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR) + + # Predict action using the policy + action_tensor = predict_action( + observation=observation_frame, + policy=policy, + device=policy.config.device, + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.device.type == "cuda", + task=TASK_DESCRIPTION, + robot_type=robot.name, + ) + + # Convert policy output to robot action dict + action_values = make_robot_action(action_tensor, dataset.features) + + # Process and send action to robot (EE -> joints via IK) + robot_action_to_send = robot_ee_to_joints_processor((action_values, obs)) + 
robot.send_action(robot_action_to_send) + + # Write to dataset + action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION) + frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION} + dataset.add_frame(frame) + + log_rerun_data(observation=obs_processed, action=action_values) + + dt_s = time.perf_counter() - start_loop_t + sleep_time_s = control_interval - dt_s + if sleep_time_s < 0: + logging.warning( + f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)." + ) + precise_sleep(max(sleep_time_s, 0.0)) + timestamp = time.perf_counter() - start_episode_t # Reset the environment if not stopping or re-recording if not events["stop_recording"] and ( (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] ): log_say("Reset the environment") - record_loop( - robot=robot, - events=events, - fps=FPS, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) + log_say("Waiting for environment reset, press right arrow key when ready...") if events["rerecord_episode"]: log_say("Re-record episode") @@ -190,7 +222,6 @@ def main(): # Save episode dataset.save_episode() - episode_idx += 1 finally: # Clean up log_say("Stop recording") diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 6a8d38ec3..87b8e49fd 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -65,14 +65,15 @@ def main(): robot = SO100Follower(robot_config) phone = Phone(teleop_config) - # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf + # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: + # 
https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf kinematics_solver = RobotKinematics( urdf_path="./SO101/so101_new_calib.urdf", target_frame_name="gripper_frame_link", joint_names=list(robot.bus.motors.keys()), ) - # Build pipeline to convert phone action to EE action + # Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint). phone_to_robot_ee_pose_processor = RobotProcessorPipeline[ tuple[RobotAction, RobotObservation], RobotAction ]( @@ -94,7 +95,7 @@ def main(): to_output=transition_to_robot_action, ) - # Build pipeline to convert EE action to joints action + # Build pipeline to convert EE action to joints action (IK). robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( steps=[ InverseKinematicsEEToJoints( @@ -107,7 +108,7 @@ def main(): to_output=transition_to_robot_action, ) - # Build pipeline to convert joint observation to EE observation + # Build pipeline to convert joint observation to EE observation (FK). robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation]( steps=[ ForwardKinematicsJointsToEE( @@ -118,13 +119,12 @@ def main(): to_output=transition_to_observation, ) - # Create the dataset + # Create the dataset, deriving features from the pipelines so the on-disk schema + # matches exactly what the pipelines produce at runtime. 
dataset = LeRobotDataset.create( repo_id=HF_REPO_ID, fps=FPS, features=combine_feature_dicts( - # Run the feature contract of the pipelines - # This tells you how the features would look like after the pipeline steps aggregate_pipeline_dataset_features( pipeline=phone_to_robot_ee_pose_processor, initial_features=create_initial_features(action=phone.action_features), @@ -163,14 +163,14 @@ def main(): robot=robot, events=events, fps=FPS, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, teleop=phone, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, - teleop_action_processor=phone_to_robot_ee_pose_processor, - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose, ) # Reset the environment if not stopping or re-recording @@ -182,13 +182,13 @@ def main(): robot=robot, events=events, fps=FPS, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, teleop=phone, control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, - teleop_action_processor=phone_to_robot_ee_pose_processor, - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose, ) if events["rerecord_episode"]: diff --git a/examples/phone_to_so100/rollout.py b/examples/phone_to_so100/rollout.py new file mode 100644 index 000000000..ca6706c52 --- /dev/null +++ b/examples/phone_to_so100/rollout.py @@ -0,0 +1,126 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a trained EE-space policy on SO100 (phone-trained) without recording. + +Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained +with phone teleoperation in EE space, so at deployment we only need the +joint↔EE conversion on the robot side; the phone is not used. + +Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig` +(inline policy call). For recording during rollout, switch to Sentry, +Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``. +""" + +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.configs import PreTrainedConfig +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + RobotProcessorPipeline, + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context +from lerobot.rollout.inference import SyncInferenceConfig +from lerobot.rollout.strategies import BaseStrategy +from lerobot.types import RobotAction, RobotObservation +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.utils import init_logging + +FPS = 30 +DURATION_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" + + +def main(): + init_logging() + + camera_config = {"front": 
OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} + robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, + ) + + # Peek at motor names once to build the kinematic solver. + temp_robot = SO100Follower(robot_config) + motor_names = list(temp_robot.bus.motors.keys()) + + kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=motor_names, + ) + + robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + + robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=motor_names, + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + + policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) + policy_config.pretrained_path = HF_MODEL_ID + + cfg = RolloutConfig( + robot=robot_config, + policy=policy_config, + strategy=BaseStrategyConfig(), + inference=SyncInferenceConfig(), + fps=FPS, + duration=DURATION_SEC, + task=TASK_DESCRIPTION, + ) + + signal_handler = ProcessSignalHandler(use_threads=True) + + ctx = build_rollout_context( + cfg, + signal_handler.shutdown_event, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + strategy = BaseStrategy(cfg.strategy) + try: + strategy.setup(ctx) + strategy.run(ctx) + finally: + strategy.teardown(ctx) + + +if __name__ == "__main__": + main() diff --git a/examples/rtc/eval_with_real_robot.py 
b/examples/rtc/eval_with_real_robot.py deleted file mode 100644 index 66562749c..000000000 --- a/examples/rtc/eval_with_real_robot.py +++ /dev/null @@ -1,673 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots. - -This script demonstrates: -1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC -2. Consuming actions from the policy while the robot executes -3. Periodically requesting new action chunks in the background using threads -4. 
Managing action buffers and timing for real-time operation - -For simulation environments, see eval_with_simulation.py - -Usage: - # Run RTC with Real robot with RTC - uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=/smolvla_check_rtc_last3 \ - --policy.device=mps \ - --rtc.enabled=true \ - --rtc.execution_horizon=20 \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem58FA0834591 \ - --robot.id=so100_follower \ - --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ - --task="Move green small object into the purple platform" \ - --duration=120 - - # Run RTC with Real robot without RTC - uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=/smolvla_check_rtc_last3 \ - --policy.device=mps \ - --rtc.enabled=false \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem58FA0834591 \ - --robot.id=so100_follower \ - --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ - --task="Move green small object into the purple platform" \ - --duration=120 - - # Run RTC with Real robot with pi0.5 policy - uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=/pi05_check_rtc \ - --policy.device=mps \ - --rtc.enabled=true \ - --rtc.execution_horizon=20 \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem58FA0834591 \ - --robot.id=so100_follower \ - --robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \ - --task="Move green small object into the purple platform" \ - --duration=120 - - # Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy - python examples/rtc/eval_with_real_robot.py \ - --policy.path=lerobot-data-collection/folding_final \ - 
--robot.type=bi_openarm_follower \ - --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \ - --robot.left_arm_config.port=can0 \ - --robot.left_arm_config.side=left \ - --robot.left_arm_config.can_interface=socketcan \ - --robot.left_arm_config.disable_torque_on_disconnect=true \ - --robot.left_arm_config.max_relative_target=8.0 \ - --robot.right_arm_config.port=can1 \ - --robot.right_arm_config.side=right \ - --robot.right_arm_config.can_interface=socketcan \ - --robot.right_arm_config.disable_torque_on_disconnect=true \ - --robot.right_arm_config.max_relative_target=8.0 \ - --task="Fold the T-shirt properly" \ - --fps=30 \ - --duration=2000 \ - --interpolation_multiplier=3 \ - --rtc.enabled=true \ - --rtc.execution_horizon=20 \ - --rtc.max_guidance_weight=5.0 \ - --rtc.prefix_attention_schedule=LINEAR \ - --device=cuda -""" - -import logging -import math -import sys -import time -import traceback -from dataclasses import dataclass, field -from threading import Event, Lock, Thread - -import torch -from torch import Tensor - -from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401 -from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser -from lerobot.policies import get_policy_class, make_pre_post_processors -from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig -from lerobot.processor import ( - NormalizerProcessorStep, - RelativeActionsProcessorStep, - TransitionKey, - create_transition, - make_default_robot_action_processor, - make_default_robot_observation_processor, - to_relative_actions, -) -from lerobot.rl.process import 
ProcessSignalHandler -from lerobot.robots import ( # noqa: F401 - Robot, - RobotConfig, - bi_openarm_follower, - bi_so_follower, - koch_follower, - so_follower, - unitree_g1, -) -from lerobot.robots.utils import make_robot_from_config -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE -from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features -from lerobot.utils.hub import HubMixin -from lerobot.utils.utils import init_logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class RobotWrapper: - def __init__(self, robot: Robot): - self.robot = robot - self.lock = Lock() - - def get_observation(self) -> dict[str, Tensor]: - with self.lock: - return self.robot.get_observation() - - def send_action(self, action: Tensor): - with self.lock: - self.robot.send_action(action) - - def observation_features(self) -> list[str]: - with self.lock: - return self.robot.observation_features - - def action_features(self) -> list[str]: - with self.lock: - return self.robot.action_features - - -@dataclass -class RTCDemoConfig(HubMixin): - """Configuration for RTC demo with action chunking policies and real robots.""" - - # Policy configuration - policy: PreTrainedConfig | None = None - - # Robot configuration - robot: RobotConfig | None = None - - # RTC configuration - rtc: RTCConfig = field( - default_factory=lambda: RTCConfig( - execution_horizon=10, - max_guidance_weight=1.0, - prefix_attention_schedule=RTCAttentionSchedule.EXP, - ) - ) - - # Demo parameters - duration: float = 30.0 # Duration to run the demo (seconds) - fps: float = 10.0 # Action execution frequency (Hz) - interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x) - - # Compute device - device: str | None = None # Device to run on (cuda, cpu, auto) - - # Get new actions horizon. The amount of executed steps after which will be requested new actions. - # It should be higher than inference delay + execution horizon. 
- action_queue_size_to_get_new_actions: int = 30 - - # Task to execute - task: str = field(default="", metadata={"help": "Task to execute"}) - - # Torch compile configuration - use_torch_compile: bool = field( - default=False, - metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"}, - ) - - torch_compile_backend: str = field( - default="inductor", - metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"}, - ) - - torch_compile_mode: str = field( - default="default", - metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"}, - ) - - torch_compile_disable_cudagraphs: bool = field( - default=True, - metadata={ - "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor " - "operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues." - }, - ) - - def __post_init__(self): - # HACK: We parse again the cli args here to get the pretrained path if there was one. - policy_path = parser.get_path_arg("policy") - if policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - else: - raise ValueError("Policy path is required") - - # Validate that robot configuration is provided - if self.robot is None: - raise ValueError("Robot configuration must be provided") - - @classmethod - def __get_path_fields__(cls) -> list[str]: - """This enables the parser to load config from the policy using `--policy.path=local/dir`""" - return ["policy"] - - -def is_image_key(k: str) -> bool: - return k.startswith(OBS_IMAGES) - - -def _reanchor_relative_rtc_prefix( - prev_actions_absolute: Tensor, - current_state: Tensor, - relative_step: RelativeActionsProcessorStep, - normalizer_step: NormalizerProcessorStep | None, - policy_device: torch.device | str, -) -> Tensor: - """Convert absolute leftovers into model-space for relative-action RTC 
policies. - - When a policy uses relative actions, the RTC prefix (leftover actions from - the previous chunk) is stored in absolute space. Before feeding it back to - the policy we need to re-express it relative to the *current* robot state - and then re-normalize. - """ - state = current_state.detach().cpu() - if state.dim() == 1: - state = state.unsqueeze(0) - - action_cpu = prev_actions_absolute.detach().cpu() - mask = relative_step._build_mask(action_cpu.shape[-1]) - relative_actions = to_relative_actions(action_cpu, state, mask) - - transition = create_transition(action=relative_actions) - if normalizer_step is not None: - transition = normalizer_step(transition) - - return transition[TransitionKey.ACTION].to(policy_device) - - -def get_actions( - policy, - robot: RobotWrapper, - robot_observation_processor, - action_queue: ActionQueue, - shutdown_event: Event, - cfg: RTCDemoConfig, -): - """Thread function to request action chunks from the policy. - - Args: - policy: The policy instance (SmolVLA, Pi0, etc.) - robot: The robot instance for getting observations - robot_observation_processor: Processor for raw robot observations - action_queue: Queue to put new action chunks - shutdown_event: Event to signal shutdown - cfg: Demo configuration - """ - try: - logger.info("[GET_ACTIONS] Starting get actions thread") - - latency_tracker = LatencyTracker() # Track latency of action chunks - fps = cfg.fps - time_per_chunk = 1.0 / fps - - # Only keep .pos joints + camera streams if the policy was trained on positions, - # not the full pos/vel/torque state the robot exposes. 
- observation_features_hw = { - key: value - for key, value in robot.observation_features().items() - if key.endswith(".pos") or isinstance(value, tuple) - } - - dataset_features = hw_to_dataset_features(observation_features_hw, "observation") - policy_device = policy.config.device - - # Load preprocessor and postprocessor from pretrained files - # The stats are embedded in the processor .safetensors files - logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}") - - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=None, # Will load from pretrained processor files - preprocessor_overrides={ - "device_processor": {"device": cfg.policy.device}, - }, - ) - - logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats") - - relative_step = next( - (s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled), - None, - ) - normalizer_step = next( - (s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)), - None, - ) - if relative_step is not None: - if relative_step.action_names is None: - cfg_names = getattr(cfg.policy, "action_feature_names", None) - if cfg_names: - relative_step.action_names = list(cfg_names) - else: - relative_step.action_names = [ - k for k in robot.robot.action_features if k.endswith(".pos") - ] - logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix") - - get_actions_threshold = cfg.action_queue_size_to_get_new_actions - - if not cfg.rtc.enabled: - get_actions_threshold = 0 - - while not shutdown_event.is_set(): - if action_queue.qsize() <= get_actions_threshold: - current_time = time.perf_counter() - action_index_before_inference = action_queue.get_action_index() - prev_actions = action_queue.get_left_over() - - inference_latency = latency_tracker.max() - inference_delay = math.ceil(inference_latency / 
time_per_chunk) - - obs = robot.get_observation() - - # Apply robot observation processor - obs_processed = robot_observation_processor(obs) - - obs_with_policy_features = build_dataset_frame( - dataset_features, obs_processed, prefix="observation" - ) - - for name in obs_with_policy_features: - obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) - if "image" in name: - obs_with_policy_features[name] = ( - obs_with_policy_features[name].type(torch.float32) / 255 - ) - obs_with_policy_features[name] = ( - obs_with_policy_features[name].permute(2, 0, 1).contiguous() - ) - obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) - obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) - - obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string! - obs_with_policy_features["robot_type"] = ( - robot.robot.name if hasattr(robot.robot, "name") else "" - ) - - preproceseded_obs = preprocessor(obs_with_policy_features) - - # Re-anchor leftover actions for relative-action policies. - # We need the *postprocessed* (absolute) leftover, not the original - # (normalized/relative) one that get_left_over() returns. 
- if ( - prev_actions is not None - and relative_step is not None - and OBS_STATE in obs_with_policy_features - ): - with action_queue.lock: - if action_queue.queue is not None: - prev_actions_abs = action_queue.queue[action_queue.last_index :].clone() - else: - prev_actions_abs = None - if prev_actions_abs is not None and prev_actions_abs.numel() > 0: - prev_actions = _reanchor_relative_rtc_prefix( - prev_actions_absolute=prev_actions_abs, - current_state=obs_with_policy_features[OBS_STATE], - relative_step=relative_step, - normalizer_step=normalizer_step, - policy_device=policy_device, - ) - - # Generate actions WITH RTC - actions = policy.predict_action_chunk( - preproceseded_obs, - inference_delay=inference_delay, - prev_chunk_left_over=prev_actions, - ) - - # Store original actions (before postprocessing) for RTC - original_actions = actions.squeeze(0).clone() - - postprocessed_actions = postprocessor(actions) - - postprocessed_actions = postprocessed_actions.squeeze(0) - - new_latency = time.perf_counter() - current_time - new_delay = math.ceil(new_latency / time_per_chunk) - latency_tracker.add(new_latency) - - if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay: - logger.warning( - "[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon." - ) - - action_queue.merge( - original_actions, postprocessed_actions, new_delay, action_index_before_inference - ) - else: - # Small sleep to prevent busy waiting - time.sleep(0.1) - - logger.info("[GET_ACTIONS] get actions thread shutting down") - except Exception as e: - logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}") - logger.error(traceback.format_exc()) - sys.exit(1) - - -def actor_control( - robot: RobotWrapper, - robot_action_processor, - action_queue: ActionQueue, - shutdown_event: Event, - cfg: RTCDemoConfig, -): - """Thread function to execute actions on the robot. 
- - Args: - robot: The robot instance - action_queue: Queue to get actions from - shutdown_event: Event to signal shutdown - cfg: Demo configuration - """ - try: - logger.info("[ACTOR] Starting actor thread") - - action_keys = [k for k in robot.action_features() if k.endswith(".pos")] - - action_count = 0 - interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) - action_interval = interpolator.get_control_interval(cfg.fps) - - while not shutdown_event.is_set(): - start_time = time.perf_counter() - - if interpolator.needs_new_action(): - new_action = action_queue.get() - if new_action is not None: - interpolator.add(new_action.cpu()) - - action = interpolator.get() - if action is not None: - action = action.cpu() - action_dict = {key: action[i].item() for i, key in enumerate(action_keys)} - action_processed = robot_action_processor((action_dict, None)) - robot.send_action(action_processed) - action_count += 1 - - dt_s = time.perf_counter() - start_time - time.sleep(max(0, (action_interval - dt_s) - 0.001)) - - logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}") - except Exception as e: - logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}") - logger.error(traceback.format_exc()) - sys.exit(1) - - -def _apply_torch_compile(policy, cfg: RTCDemoConfig): - """Apply torch.compile to the policy's predict_action_chunk method. - - Args: - policy: Policy instance to compile - cfg: Configuration containing torch compile settings - - Returns: - Policy with compiled predict_action_chunk method - """ - - # PI models handle their own compilation - if policy.type == "pi05" or policy.type == "pi0": - return policy - - try: - # Check if torch.compile is available (PyTorch 2.0+) - if not hasattr(torch, "compile"): - logger.warning( - f"torch.compile is not available. Requires PyTorch 2.0+. " - f"Current version: {torch.__version__}. Skipping compilation." 
- ) - return policy - - logger.info("Applying torch.compile to predict_action_chunk...") - logger.info(f" Backend: {cfg.torch_compile_backend}") - logger.info(f" Mode: {cfg.torch_compile_mode}") - logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}") - - # Compile the predict_action_chunk method - # - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t) - compile_kwargs = { - "backend": cfg.torch_compile_backend, - "mode": cfg.torch_compile_mode, - } - - # Disable CUDA graphs if requested (prevents tensor aliasing issues) - if cfg.torch_compile_disable_cudagraphs: - compile_kwargs["options"] = {"triton.cudagraphs": False} - - original_method = policy.predict_action_chunk - compiled_method = torch.compile(original_method, **compile_kwargs) - policy.predict_action_chunk = compiled_method - logger.info("✓ Successfully compiled predict_action_chunk") - - except Exception as e: - logger.error(f"Failed to apply torch.compile: {e}") - logger.warning("Continuing without torch.compile") - - return policy - - -@parser.wrap() -def demo_cli(cfg: RTCDemoConfig): - """Main entry point for RTC demo with draccus configuration.""" - - # Initialize logging - init_logging() - - logger.info(f"Using device: {cfg.device}") - - # Setup signal handler for graceful shutdown - signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False) - shutdown_event = signal_handler.shutdown_event - - policy = None - robot = None - get_actions_thread = None - actor_thread = None - - policy_class = get_policy_class(cfg.policy.type) - - # Load config and set compile_model for pi0/pi05 models - config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) - - if cfg.policy.type == "pi05" or cfg.policy.type == "pi0": - config.compile_model = cfg.use_torch_compile - - if config.use_peft: - from peft import PeftConfig, PeftModel - - peft_pretrained_path = cfg.policy.pretrained_path - peft_config = 
PeftConfig.from_pretrained(peft_pretrained_path) - - policy = policy_class.from_pretrained( - pretrained_name_or_path=peft_config.base_model_name_or_path, config=config - ) - policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config) - else: - policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) - - # Turn on RTC - policy.config.rtc_config = cfg.rtc - - # Init RTC processort, as by default if RTC disabled in the config - # The processor won't be created - policy.init_rtc_processor() - - assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC" - - policy = policy.to(cfg.device) - policy.eval() - - # Apply torch.compile to predict_action_chunk method if enabled - if cfg.use_torch_compile: - policy = _apply_torch_compile(policy, cfg) - - # Create robot - logger.info(f"Initializing robot: {cfg.robot.type}") - robot = make_robot_from_config(cfg.robot) - robot.connect() - robot_wrapper = RobotWrapper(robot) - - # Create robot observation processor - robot_observation_processor = make_default_robot_observation_processor() - robot_action_processor = make_default_robot_action_processor() - - # Create action queue for communication between threads - action_queue = ActionQueue(cfg.rtc) - - # Start chunk requester thread - get_actions_thread = Thread( - target=get_actions, - args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg), - daemon=True, - name="GetActions", - ) - get_actions_thread.start() - logger.info("Started get actions thread") - - # Start action executor thread - actor_thread = Thread( - target=actor_control, - args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg), - daemon=True, - name="Actor", - ) - actor_thread.start() - logger.info("Started actor thread") - - logger.info("Started stop by duration thread") - - # Main thread monitors for duration or shutdown - logger.info(f"Running demo for 
{cfg.duration} seconds...") - start_time = time.time() - - while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration: - time.sleep(10) - - # Log queue status periodically - if int(time.time() - start_time) % 5 == 0: - logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}") - - if time.time() - start_time > cfg.duration: - break - - logger.info("Demo duration reached or shutdown requested") - - # Signal shutdown - shutdown_event.set() - - # Wait for threads to finish - if get_actions_thread and get_actions_thread.is_alive(): - logger.info("Waiting for chunk requester thread to finish...") - get_actions_thread.join() - - if actor_thread and actor_thread.is_alive(): - logger.info("Waiting for action executor thread to finish...") - actor_thread.join() - - # Cleanup robot - if robot: - robot.disconnect() - logger.info("Robot disconnected") - - logger.info("Cleanup completed") - - -if __name__ == "__main__": - demo_cli() - logging.info("RTC demo finished") diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index fb5204997..63def68d0 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -14,13 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +import time + from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.common.control_utils import init_keyboard_listener +from lerobot.common.control_utils import init_keyboard_listener, predict_action from lerobot.configs import FeatureType, PolicyFeature from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features from lerobot.model.kinematics import RobotKinematics from lerobot.policies import make_pre_post_processors from lerobot.policies.act import ACTPolicy +from lerobot.policies.utils import make_robot_action from lerobot.processor import ( RobotProcessorPipeline, make_default_teleop_action_processor, @@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.scripts.lerobot_record import record_loop from lerobot.types import RobotAction, RobotObservation -from lerobot.utils.feature_utils import combine_feature_dicts +from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts +from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import init_rerun +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data NUM_EPISODES = 5 FPS = 30 @@ -49,6 +54,9 @@ HF_DATASET_ID = "/" def main(): + # NOTE: For production policy deployment, use `lerobot-rollout` CLI instead. + # This script provides a self-contained example for educational purposes. 
+ # Create the robot configuration & robot camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} robot_config = SO100FollowerConfig( @@ -143,43 +151,67 @@ def main(): raise ValueError("Robot is not connected!") print("Starting evaluate loop...") + control_interval = 1 / FPS episode_idx = 0 for episode_idx in range(NUM_EPISODES): log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) + # Inline evaluation loop: predict actions and send to robot + timestamp = 0 + start_episode_t = time.perf_counter() + while timestamp < EPISODE_TIME_SEC: + start_loop_t = time.perf_counter() + + if events["exit_early"]: + events["exit_early"] = False + break + + # Get robot observation + obs = robot.get_observation() + obs_processed = robot_joints_to_ee_pose_processor(obs) + observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR) + + # Predict action using the policy + action_tensor = predict_action( + observation=observation_frame, + policy=policy, + device=policy.config.device, + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.device.type == "cuda", + task=TASK_DESCRIPTION, + robot_type=robot.name, + ) + + # Convert policy output to robot action dict + action_values = make_robot_action(action_tensor, dataset.features) + + # Process and send action to robot (EE -> joints via IK) + robot_action_to_send = robot_ee_to_joints_processor((action_values, obs)) + 
robot.send_action(robot_action_to_send) + + # Write to dataset + action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION) + frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION} + dataset.add_frame(frame) + + log_rerun_data(observation=obs_processed, action=action_values) + + dt_s = time.perf_counter() - start_loop_t + sleep_time_s = control_interval - dt_s + if sleep_time_s < 0: + logging.warning( + f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)." + ) + precise_sleep(max(sleep_time_s, 0.0)) + timestamp = time.perf_counter() - start_episode_t # Reset the environment if not stopping or re-recording if not events["stop_recording"] and ( (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] ): log_say("Reset the environment") - record_loop( - robot=robot, - events=events, - fps=FPS, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) + log_say("Waiting for environment reset, press right arrow key when ready...") if events["rerecord_episode"]: log_say("Re-record episode") @@ -190,7 +222,6 @@ def main(): # Save episode dataset.save_episode() - episode_idx += 1 finally: # Clean up log_say("Stop recording") diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index a7ac5bb80..a0b92da3b 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -62,21 +62,20 @@ def main(): follower = SO100Follower(follower_config) leader = SO100Leader(leader_config) - # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf + # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: + # 
https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf follower_kinematics_solver = RobotKinematics( urdf_path="./SO101/so101_new_calib.urdf", target_frame_name="gripper_frame_link", joint_names=list(follower.bus.motors.keys()), ) - - # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf leader_kinematics_solver = RobotKinematics( urdf_path="./SO101/so101_new_calib.urdf", target_frame_name="gripper_frame_link", joint_names=list(leader.bus.motors.keys()), ) - # Build pipeline to convert follower joints to EE observation + # Build pipeline to convert follower joints to EE observation. follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation]( steps=[ ForwardKinematicsJointsToEE( @@ -87,7 +86,7 @@ def main(): to_output=transition_to_observation, ) - # Build pipeline to convert leader joints to EE action + # Build pipeline to convert leader joints to EE action. leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( steps=[ ForwardKinematicsJointsToEE( @@ -98,9 +97,9 @@ def main(): to_output=transition_to_robot_action, ) - # Build pipeline to convert EE action to follower joints + # Build pipeline to convert EE action to follower joints (with safety bounds). ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( - [ + steps=[ EEBoundsAndSafety( end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.10, @@ -115,13 +114,12 @@ def main(): to_output=transition_to_robot_action, ) - # Create the dataset + # Create the dataset, deriving features from the pipelines so the on-disk schema + # matches exactly what the pipelines produce at runtime. 
dataset = LeRobotDataset.create( repo_id=HF_REPO_ID, fps=FPS, features=combine_feature_dicts( - # Run the feature contract of the pipelines - # This tells you how the features would look like after the pipeline steps aggregate_pipeline_dataset_features( pipeline=leader_joints_to_ee, initial_features=create_initial_features(action=leader.action_features), @@ -144,7 +142,7 @@ def main(): # Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() - init_rerun(session_name="recording_phone") + init_rerun(session_name="recording_so100_ee") try: if not leader.is_connected or not follower.is_connected: @@ -160,14 +158,14 @@ def main(): robot=follower, events=events, fps=FPS, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, teleop=leader, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, - teleop_action_processor=leader_joints_to_ee, - robot_action_processor=ee_to_follower_joints, - robot_observation_processor=follower_joints_to_ee, ) # Reset the environment if not stopping or re-recording @@ -179,13 +177,13 @@ def main(): robot=follower, events=events, fps=FPS, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, teleop=leader, control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, - teleop_action_processor=leader_joints_to_ee, - robot_action_processor=ee_to_follower_joints, - robot_observation_processor=follower_joints_to_ee, ) if events["rerecord_episode"]: diff --git a/examples/so100_to_so100_EE/rollout.py b/examples/so100_to_so100_EE/rollout.py new file mode 100644 index 000000000..d608bfab2 --- /dev/null +++ b/examples/so100_to_so100_EE/rollout.py @@ -0,0 +1,134 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a trained EE-space policy on SO100 without recording (base rollout). + +Uses the rollout engine's :class:`BaseStrategy` (autonomous execution, +no dataset) with :class:`SyncInferenceConfig` (inline policy call per +control tick). The custom observation/action processors convert between +joint space (robot hardware) and end-effector space (policy I/O) via +forward/inverse kinematics. +""" + +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.configs import PreTrainedConfig +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + RobotProcessorPipeline, + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context +from lerobot.rollout.inference import SyncInferenceConfig +from lerobot.rollout.strategies import BaseStrategy +from lerobot.types import RobotAction, RobotObservation +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.utils import init_logging + +FPS = 30 +DURATION_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" + + +def main(): + init_logging() + + 
# Robot configuration — the rollout engine will connect it inside build_rollout_context. + camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} + robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, + ) + + # Kinematic solver: we need the motor-name list, so peek at the robot once. + # (The rollout engine owns the connected instance; we only use this for introspection.) + temp_robot = SO100Follower(robot_config) + motor_names = list(temp_robot.bus.motors.keys()) + + # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: + # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf + kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=motor_names, + ) + + # Joint-space observation → EE-space observation (consumed by the policy). + robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + + # EE-space action (produced by the policy) → joint-space action (sent to robot). + robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=motor_names, + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + + # Policy config (full model is loaded inside build_rollout_context). 
+ policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) + policy_config.pretrained_path = HF_MODEL_ID + + cfg = RolloutConfig( + robot=robot_config, + policy=policy_config, + strategy=BaseStrategyConfig(), + inference=SyncInferenceConfig(), + fps=FPS, + duration=DURATION_SEC, + task=TASK_DESCRIPTION, + ) + + signal_handler = ProcessSignalHandler(use_threads=True) + + # Pass the EE kinematic processors via kwargs; the defaults (identity) would + # otherwise skip the joint↔EE conversion and the policy would receive the + # wrong observation/action space. + ctx = build_rollout_context( + cfg, + signal_handler.shutdown_event, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + strategy = BaseStrategy(cfg.strategy) + try: + strategy.setup(ctx) + strategy.run(ctx) + finally: + strategy.teardown(ctx) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 790c7f2d9..d3d0c0ed3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -289,6 +289,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" +lerobot-rollout="lerobot.scripts.lerobot_rollout:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.package-data] diff --git a/src/lerobot/configs/__init__.py b/src/lerobot/configs/__init__.py index 3ddaec1af..ab74c3cd3 100644 --- a/src/lerobot/configs/__init__.py +++ b/src/lerobot/configs/__init__.py @@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies Import them directly: ``from lerobot.configs.train import TrainPipelineConfig`` """ +from .dataset import DatasetRecordConfig from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig from .policies import PreTrainedConfig 
from .types import ( @@ -39,6 +40,7 @@ __all__ = [ "PolicyFeature", "RTCAttentionSchedule", # Config classes + "DatasetRecordConfig", "DatasetConfig", "EvalConfig", "PeftConfig", diff --git a/src/lerobot/configs/dataset.py b/src/lerobot/configs/dataset.py new file mode 100644 index 000000000..e3e17e62b --- /dev/null +++ b/src/lerobot/configs/dataset.py @@ -0,0 +1,80 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``.""" + +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + + +@dataclass +class DatasetRecordConfig: + # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). + repo_id: str = "" + # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") + single_task: str = "" + # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. + root: str | Path | None = None + # Limit the frames per second. + fps: int = 30 + # Number of seconds for data recording for each episode. + episode_time_s: int | float = 60 + # Number of seconds for resetting the environment after each episode. + reset_time_s: int | float = 60 + # Number of episodes to record. 
+ num_episodes: int = 50 + # Encode frames in the dataset into video + video: bool = True + # Upload dataset to Hugging Face hub. + push_to_hub: bool = True + # Upload on private repository on the Hugging Face hub. + private: bool = False + # Add tags to your dataset on the hub. + tags: list[str] | None = None + # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; + # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes + # and threads depends on your system. We recommend 4 threads per camera with 0 processes. + # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. + num_image_writer_processes: int = 0 + # Number of threads writing the frames as png images on disk, per camera. + # Too many threads might cause unstable teleoperation fps due to main thread being blocked. + # Not enough threads might cause low camera fps. + num_image_writer_threads_per_camera: int = 4 + # Number of episodes to record before batch encoding videos + # Set to 1 for immediate encoding (default behavior), or higher for batched encoding + video_encoding_batch_size: int = 1 + # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto', + # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'. + # Use 'auto' to auto-detect the best available hardware encoder. + vcodec: str = "libsvtav1" + # Enable streaming video encoding: encode frames in real-time during capture instead + # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding + streaming_encoding: bool = False + # Maximum number of frames to buffer per camera when using streaming encoding. + # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up. + encoder_queue_maxsize: int = 30 + # Number of threads per encoder instance. 
None = auto (codec default). + # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. + encoder_threads: int | None = None + + def stamp_repo_id(self) -> None: + """Append a date-time tag to ``repo_id`` so each recording session gets a unique name. + + Must be called explicitly at dataset *creation* time — not on resume, + where the existing ``repo_id`` (already stamped) must be preserved. + """ + if self.repo_id: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.repo_id = f"{self.repo_id}_{timestamp}" diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 644ce14db..b6ab0f5f0 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -630,6 +630,8 @@ class LeRobotDataset(torch.utils.data.Dataset): streaming_encoding: bool = False, encoder_queue_maxsize: int = 30, encoder_threads: int | None = None, + video_files_size_in_mb: int | None = None, + data_files_size_in_mb: int | None = None, ) -> "LeRobotDataset": """Create a new LeRobotDataset from scratch for recording data. @@ -677,6 +679,8 @@ class LeRobotDataset(torch.utils.data.Dataset): root=root, use_videos=use_videos, metadata_buffer_size=metadata_buffer_size, + video_files_size_in_mb=video_files_size_in_mb, + data_files_size_in_mb=data_files_size_in_mb, ) obj.repo_id = obj.meta.repo_id obj._requested_root = obj.meta.root diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index e138a84d9..905276642 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator + from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors @@ -21,7 +23,6 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config from .pretrained import PreTrainedPolicy as PreTrainedPolicy -from .rtc import ActionInterpolator as ActionInterpolator from .sac.configuration_sac import SACConfig as SACConfig from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig from .sarm.configuration_sarm import SARMConfig as SARMConfig diff --git a/src/lerobot/policies/rtc/__init__.py b/src/lerobot/policies/rtc/__init__.py index 7a29dcac0..16417b3cd 100644 --- a/src/lerobot/policies/rtc/__init__.py +++ b/src/lerobot/policies/rtc/__init__.py @@ -19,6 +19,7 @@ from .action_queue import ActionQueue from .configuration_rtc import RTCConfig from .latency_tracker import LatencyTracker from .modeling_rtc import RTCProcessor +from .relative import reanchor_relative_rtc_prefix __all__ = [ "ActionInterpolator", @@ -26,4 +27,5 @@ __all__ = [ "LatencyTracker", "RTCConfig", "RTCProcessor", + "reanchor_relative_rtc_prefix", ] diff --git a/src/lerobot/policies/rtc/action_interpolator.py b/src/lerobot/policies/rtc/action_interpolator.py index 222dc33b5..c30481d3b 100644 --- a/src/lerobot/policies/rtc/action_interpolator.py +++ b/src/lerobot/policies/rtc/action_interpolator.py @@ -1,116 +1,4 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility. +from lerobot.utils.action_interpolator import ActionInterpolator -"""Action interpolation for smoother robot control. - -Provides configurable Nx control rate by interpolating between consecutive actions. -Useful with RTC and action-chunking policies to reduce jerkiness. -""" - -from torch import Tensor - - -class ActionInterpolator: - """Interpolates between consecutive actions for smoother control. - - When enabled with multiplier N, produces N actions per policy action - by linearly interpolating between the previous and current action. - - Example with multiplier=3: - prev_action -> [1/3 interpolated, 2/3 interpolated, current_action] - - This effectively multiplies the control rate for smoother motion. - - Usage: - interpolator = ActionInterpolator(multiplier=2) # 2x control rate - - # In control loop: - if interpolator.needs_new_action(): - new_action = queue.get() - if new_action: - interpolator.add(new_action.cpu()) - - action = interpolator.get() - if action: - robot.send_action(action) - """ - - def __init__(self, multiplier: int = 1): - """Initialize the interpolator. - - Args: - multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.) 
- """ - if multiplier < 1: - raise ValueError(f"multiplier must be >= 1, got {multiplier}") - self.multiplier = multiplier - self._prev: Tensor | None = None - self._buffer: list[Tensor] = [] - self._idx = 0 - - @property - def enabled(self) -> bool: - """Whether interpolation is active (multiplier > 1).""" - return self.multiplier > 1 - - def reset(self): - """Reset interpolation state (call between episodes).""" - self._prev = None - self._buffer = [] - self._idx = 0 - - def needs_new_action(self) -> bool: - """Check if a new action is needed from the queue.""" - return self._idx >= len(self._buffer) - - def add(self, action: Tensor) -> None: - """Add a new action and compute interpolated sequence. - - Args: - action: New action tensor from policy/queue (already on CPU). - """ - if self.multiplier > 1 and self._prev is not None: - self._buffer = [] - for i in range(1, self.multiplier + 1): - t = i / self.multiplier - interp = self._prev + t * (action - self._prev) - self._buffer.append(interp) - else: - # First step: no previous action yet, so run at base FPS without interpolation. - self._buffer = [action.clone()] - self._prev = action.clone() - self._idx = 0 - - def get(self) -> Tensor | None: - """Get the next interpolated action. - - Returns: - Next action tensor, or None if buffer is exhausted. - """ - if self._idx >= len(self._buffer): - return None - action = self._buffer[self._idx] - self._idx += 1 - return action - - def get_control_interval(self, fps: float) -> float: - """Get the control interval based on interpolation multiplier. - - Args: - fps: Base frames per second. - - Returns: - Control interval in seconds (divided by multiplier). 
- """ - return 1.0 / (fps * self.multiplier) +__all__ = ["ActionInterpolator"] diff --git a/src/lerobot/policies/rtc/action_queue.py b/src/lerobot/policies/rtc/action_queue.py index dbbdc41df..199257b12 100644 --- a/src/lerobot/policies/rtc/action_queue.py +++ b/src/lerobot/policies/rtc/action_queue.py @@ -92,10 +92,10 @@ class ActionQueue: Returns: int: Number of unconsumed actions. """ - if self.queue is None: - return 0 - length = len(self.queue) - return length - self.last_index + with self.lock: + if self.queue is None: + return 0 + return len(self.queue) - self.last_index def empty(self) -> bool: """Check if the queue is empty. @@ -103,11 +103,10 @@ class ActionQueue: Returns: bool: True if no actions remain, False otherwise. """ - if self.queue is None: - return True - - length = len(self.queue) - return length - self.last_index <= 0 + with self.lock: + if self.queue is None: + return True + return len(self.queue) - self.last_index <= 0 def get_action_index(self) -> int: """Get the current action consumption index. @@ -115,7 +114,8 @@ class ActionQueue: Returns: int: Index of the next action to be consumed. """ - return self.last_index + with self.lock: + return self.last_index def get_left_over(self) -> Tensor | None: """Get leftover original actions for RTC prev_chunk_left_over. diff --git a/src/lerobot/policies/rtc/configuration_rtc.py b/src/lerobot/policies/rtc/configuration_rtc.py index c70fe3de0..3d71edf26 100644 --- a/src/lerobot/policies/rtc/configuration_rtc.py +++ b/src/lerobot/policies/rtc/configuration_rtc.py @@ -35,7 +35,7 @@ class RTCConfig: """ # Infrastructure - enabled: bool = False + enabled: bool = True # Core RTC settings # Todo change to exp diff --git a/src/lerobot/policies/rtc/relative.py b/src/lerobot/policies/rtc/relative.py new file mode 100644 index 000000000..61063f3e2 --- /dev/null +++ b/src/lerobot/policies/rtc/relative.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Relative-action helpers for Real-Time Chunking (RTC).""" + +from __future__ import annotations + +import torch + +from lerobot.processor import ( + NormalizerProcessorStep, + RelativeActionsProcessorStep, + TransitionKey, + create_transition, + to_relative_actions, +) + + +def reanchor_relative_rtc_prefix( + prev_actions_absolute: torch.Tensor, + current_state: torch.Tensor, + relative_step: RelativeActionsProcessorStep, + normalizer_step: NormalizerProcessorStep | None, + policy_device: torch.device | str, +) -> torch.Tensor: + """Convert absolute leftover actions into model-space for relative-action RTC policies. + + When using relative actions, the RTC prefix (previous chunk's unexecuted tail) + is stored in absolute coordinates. Before feeding it back to the policy, this + helper re-expresses those actions relative to the robot's current joint state + and optionally normalizes them so the policy receives correctly scaled inputs. 
+ """ + state = current_state.detach().cpu() + if state.dim() == 1: + state = state.unsqueeze(0) + + action_cpu = prev_actions_absolute.detach().cpu() + mask = relative_step._build_mask(action_cpu.shape[-1]) + relative_actions = to_relative_actions(action_cpu, state, mask) + + transition = create_transition(action=relative_actions) + if normalizer_step is not None: + transition = normalizer_step(transition) + + return transition[TransitionKey.ACTION].to(policy_device) diff --git a/src/lerobot/processor/relative_action_processor.py b/src/lerobot/processor/relative_action_processor.py index d9f97f2c6..e1e65acb1 100644 --- a/src/lerobot/processor/relative_action_processor.py +++ b/src/lerobot/processor/relative_action_processor.py @@ -142,6 +142,10 @@ class RelativeActionsProcessorStep(ProcessorStep): new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask) return new_transition + def get_cached_state(self) -> torch.Tensor | None: + """Return the cached ``observation.state`` used as the reference point for relative/absolute action conversions.""" + return self._last_state + def get_config(self) -> dict[str, Any]: return { "enabled": self.enabled, @@ -182,7 +186,8 @@ class AbsoluteActionsProcessorStep(ProcessorStep): "but relative_step is None. Ensure relative_step is set when constructing the postprocessor." ) - if self.relative_step._last_state is None: + cached_state = self.relative_step.get_cached_state() + if cached_state is None: raise RuntimeError( "AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep " "but no state has been cached. Ensure the preprocessor runs before the postprocessor." 
@@ -194,9 +199,7 @@ class AbsoluteActionsProcessorStep(ProcessorStep): return new_transition mask = self.relative_step._build_mask(action.shape[-1]) - new_transition[TransitionKey.ACTION] = to_absolute_actions( - action, self.relative_step._last_state, mask - ) + new_transition[TransitionKey.ACTION] = to_absolute_actions(action, cached_state, mask) return new_transition def get_config(self) -> dict[str, Any]: diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 588adffac..eab527250 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -76,6 +76,7 @@ from lerobot.transport.utils import ( ) from lerobot.types import TransitionKey from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( @@ -94,7 +95,6 @@ from .gym_manipulator import ( make_robot_env, step_env_and_process_transition, ) -from .process import ProcessSignalHandler from .queue import get_last_item_from_queue # Main entry point diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index d1207421b..14542576d 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -90,6 +90,7 @@ from lerobot.utils.constants import ( TRAINING_STATE_DIR, ) from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.utils import ( @@ -99,7 +100,6 @@ from lerobot.utils.utils import ( from .buffer import ReplayBuffer, concatenate_batch_transitions from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService -from .process import ProcessSignalHandler @parser.wrap() diff --git a/src/lerobot/rollout/__init__.py b/src/lerobot/rollout/__init__.py new file 
mode 100644 index 000000000..a4de8ee6c --- /dev/null +++ b/src/lerobot/rollout/__init__.py @@ -0,0 +1,87 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Policy deployment engine with pluggable rollout strategies.""" + +from lerobot.utils.import_utils import require_package + +require_package("datasets", extra="dataset") + +from .configs import ( + BaseStrategyConfig, + DAggerKeyboardConfig, + DAggerPedalConfig, + DAggerStrategyConfig, + HighlightStrategyConfig, + RolloutConfig, + RolloutStrategyConfig, + SentryStrategyConfig, +) +from .context import ( + DatasetContext, + HardwareContext, + PolicyContext, + ProcessorContext, + RolloutContext, + RuntimeContext, + build_rollout_context, +) +from .inference import ( + InferenceEngine, + InferenceEngineConfig, + RTCInferenceConfig, + RTCInferenceEngine, + SyncInferenceConfig, + SyncInferenceEngine, + create_inference_engine, +) +from .strategies import ( + BaseStrategy, + DAggerStrategy, + HighlightStrategy, + RolloutStrategy, + SentryStrategy, + create_strategy, +) + +__all__ = [ + "BaseStrategy", + "BaseStrategyConfig", + "DAggerKeyboardConfig", + "DAggerPedalConfig", + "DAggerStrategy", + "DAggerStrategyConfig", + "DatasetContext", + "HardwareContext", + "HighlightStrategy", + "HighlightStrategyConfig", + "InferenceEngine", + "InferenceEngineConfig", + "PolicyContext", + "ProcessorContext", + "RTCInferenceConfig", + "RTCInferenceEngine", + 
"RolloutConfig", + "RolloutContext", + "RolloutStrategy", + "RolloutStrategyConfig", + "RuntimeContext", + "SentryStrategy", + "SentryStrategyConfig", + "SyncInferenceConfig", + "SyncInferenceEngine", + "build_rollout_context", + "create_inference_engine", + "create_strategy", +] diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py new file mode 100644 index 000000000..9d019c887 --- /dev/null +++ b/src/lerobot/rollout/configs.py @@ -0,0 +1,323 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration dataclasses for the rollout deployment engine.""" + +from __future__ import annotations + +import abc +import logging +from dataclasses import dataclass, field + +import draccus + +from lerobot.configs import PreTrainedConfig, parser +from lerobot.configs.dataset import DatasetRecordConfig +from lerobot.robots.config import RobotConfig +from lerobot.teleoperators.config import TeleoperatorConfig +from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available + +from .inference import InferenceEngineConfig, SyncInferenceConfig + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Strategy configs (polymorphic dispatch via draccus ChoiceRegistry) +# --------------------------------------------------------------------------- + + +@dataclass +class RolloutStrategyConfig(draccus.ChoiceRegistry, abc.ABC): + """Abstract base for rollout strategy configurations. + + Use ``--strategy.type=`` on the CLI to select a strategy. + """ + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@RolloutStrategyConfig.register_subclass("base") +@dataclass +class BaseStrategyConfig(RolloutStrategyConfig): + """Autonomous rollout with no data recording.""" + + pass + + +@RolloutStrategyConfig.register_subclass("sentry") +@dataclass +class SentryStrategyConfig(RolloutStrategyConfig): + """Continuous autonomous rollout with always-on recording. + + Episode duration is derived from camera resolution, FPS, and + ``target_video_file_size_mb`` so that each saved episode produces a + video file that has crossed the target size. This aligns episode + boundaries with the dataset's video file chunking, so each + ``push_to_hub`` call uploads complete video files rather than + re-uploading a growing file that hasn't crossed the chunk boundary. + """ + + upload_every_n_episodes: int = 5 + # Target video file size in MB for episode rotation. 
Episodes are + # saved once the estimated video duration would exceed this limit. + # Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when set to None. + target_video_file_size_mb: int | None = None + + +@RolloutStrategyConfig.register_subclass("highlight") +@dataclass +class HighlightStrategyConfig(RolloutStrategyConfig): + """Autonomous rollout with on-demand recording via ring buffer. + + A memory-bounded ring buffer continuously captures telemetry. When + the user presses the save key, the buffer contents are flushed to + the dataset and live recording continues until the key is pressed + again. + """ + + ring_buffer_seconds: float = 10.0 + ring_buffer_max_memory_mb: int = 1024 + save_key: str = "s" + push_key: str = "h" + + +@dataclass +class DAggerKeyboardConfig: + """Keyboard key bindings for DAgger controls. + + Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or + special key names (``"space"``). + """ + + pause_resume: str = "space" + correction: str = "tab" + upload: str = "enter" + + +@dataclass +class DAggerPedalConfig: + """Foot pedal configuration for DAgger controls. + + Pedal codes are evdev key code strings (e.g. ``"KEY_A"``). + """ + + device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + pause_resume: str = "KEY_A" + correction: str = "KEY_B" + upload: str = "KEY_C" + + +@RolloutStrategyConfig.register_subclass("dagger") +@dataclass +class DAggerStrategyConfig(RolloutStrategyConfig): + """Human-in-the-loop data collection (DAgger / RaC). + + Alternates between autonomous policy execution and human intervention. + Intervention frames are tagged with ``intervention=True``. + + Input is controlled via either a keyboard or foot pedal, selected by + ``input_device``. Each device exposes three actions: + + 1. **pause_resume** — toggle policy execution on/off. + 2. **correction** — toggle human correction recording. + 3. **upload** — push dataset to hub on demand (corrections-only mode). 
+ + When ``record_autonomous=False`` (default) only human-correction windows + are recorded — each correction becomes its own episode. Set to ``True`` + to record both autonomous and correction frames with size-based episode + rotation (same as Sentry) and background uploading. ``push_to_hub`` is + blocked while a correction is in progress. + """ + + # Number of correction episodes to collect (corrections-only mode). + # When None, falls back to ``--dataset.num_episodes``. + num_episodes: int | None = None + record_autonomous: bool = False + upload_every_n_episodes: int = 5 + # Target video file size in MB for episode rotation (record_autonomous + # mode only). Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when None. + target_video_file_size_mb: int | None = None + input_device: str = "keyboard" + keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig) + pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig) + + def __post_init__(self): + if self.input_device not in ("keyboard", "pedal"): + raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'") + + +# --------------------------------------------------------------------------- +# Top-level rollout config +# --------------------------------------------------------------------------- + + +@dataclass +class RolloutConfig: + """Top-level configuration for the ``lerobot-rollout`` CLI. + + Combines hardware, policy, strategy, and runtime settings. The + ``__post_init__`` method performs fail-fast validation to reject + invalid flag combinations early. 
+ """ + + # Hardware + robot: RobotConfig | None = None + teleop: TeleoperatorConfig | None = None + + # Policy (loaded from --policy.path via __post_init__) + policy: PreTrainedConfig | None = None + + # Strategy (polymorphic: --strategy.type=base|sentry|highlight|dagger) + strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig) + + # Inference backend (polymorphic: --inference.type=sync|rtc) + inference: InferenceEngineConfig = field(default_factory=SyncInferenceConfig) + + # Dataset (required for sentry, highlight, dagger; None for base) + dataset: DatasetRecordConfig | None = None + + # Runtime + fps: float = 30.0 + duration: float = 0.0 # 0 = infinite (24/7 mode) + interpolation_multiplier: int = 1 + device: str | None = None + task: str = "" + display_data: bool = False + # Display data on a remote Rerun server + display_ip: str | None = None + # Port of the remote Rerun server + display_port: int | None = None + # Whether to display compressed images in Rerun + display_compressed_images: bool = False + # Use vocal synthesis to read events + play_sounds: bool = True + resume: bool = False + # Rename map for mapping robot/dataset observation keys to policy keys + rename_map: dict[str, str] = field(default_factory=dict) + + # Hardware teardown + # When True (default), smoothly interpolate the robot back to the joint + # positions captured at startup before disconnecting. Set to False to + # leave the robot in its final achieved pose at shutdown. 
+ return_to_initial_position: bool = True + + # Torch compile + use_torch_compile: bool = False + torch_compile_backend: str = "inductor" + torch_compile_mode: str = "default" + compile_warmup_inferences: int = 2 + + def __post_init__(self): + """Validate config invariants and load the policy config from ``--policy.path``.""" + # --- Strategy-specific validation --- + if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None: + raise ValueError("DAgger strategy requires --teleop.type to be set") + + # TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation. + needs_dataset = isinstance( + self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig) + ) + if needs_dataset and (self.dataset is None or not self.dataset.repo_id): + raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set") + + if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None: + raise ValueError( + "Base strategy does not record data. Use sentry, highlight, or dagger for recording." + ) + + # Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop + if ( + isinstance(self.strategy, SentryStrategyConfig) + and self.dataset is not None + and not self.dataset.streaming_encoding + ): + logger.warning("Sentry mode forces streaming_encoding=True") + self.dataset.streaming_encoding = True + + # Highlight writes frames while the policy is still running, so streaming is mandatory. + if ( + isinstance(self.strategy, HighlightStrategyConfig) + and self.dataset is not None + and not self.dataset.streaming_encoding + ): + logger.warning("Highlight mode forces streaming_encoding=True") + self.dataset.streaming_encoding = True + + # DAgger: streaming is mandatory only when the autonomous phase is also recorded. 
+ if isinstance(self.strategy, DAggerStrategyConfig) and self.dataset is not None: + if self.strategy.record_autonomous and not self.dataset.streaming_encoding: + logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True") + self.dataset.streaming_encoding = True + elif not self.strategy.record_autonomous and not self.dataset.streaming_encoding: + logger.info( + "Streaming encoding is disabled for DAgger corrections-only mode. " + "Consider enabling it for faster episode saving: " + "--dataset.streaming_encoding=true --dataset.encoder_threads=2" + ) + + # DAgger: resolve num_episodes from dataset config when not explicitly set. + if isinstance(self.strategy, DAggerStrategyConfig) and self.strategy.num_episodes is None: + if self.dataset is not None: + self.strategy.num_episodes = self.dataset.num_episodes + logger.info( + "DAgger num_episodes not set — using --dataset.num_episodes=%d", + self.strategy.num_episodes, + ) + else: + raise ValueError( + "DAgger num_episodes must be set either via --strategy.num_episodes or --dataset.num_episodes" + ) + + # --- Policy loading --- + if self.robot is None: + raise ValueError("--robot.type is required for rollout") + + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + if self.policy is None: + raise ValueError("--policy.path is required for rollout") + + # --- Task resolution --- + # When any --dataset.* flag is passed, draccus creates a DatasetRecordConfig with single_task="". + # If the user set the task via the top-level --task flag, propagate it so that all + # downstream consumers (inference engine, dataset frame builders) see it. 
+ if self.dataset is not None and not self.dataset.single_task and self.task: + logger.info("Propagating top-level task '%s' to dataset config", self.task) + self.dataset.single_task = self.task + elif self.dataset is not None and self.dataset.single_task and not self.task: + logger.info("Propagating dataset single_task '%s' to top-level task", self.dataset.single_task) + self.task = self.dataset.single_task + + # --- Device resolution --- + # Resolve device from the policy config when not explicitly set so all + # components (policy.to, preprocessor, inference engine) use the same + # device string instead of inconsistent fallbacks. + if self.device is None or not is_torch_device_available(self.device): + resolved = self.policy.device + if resolved: + self.device = resolved + logger.info("Resolved device from policy config: %s", self.device) + else: + self.device = auto_select_torch_device().type + logger.info("No policy config to resolve device from; auto-selected device: %s", self.device) + + @classmethod + def __get_path_fields__(cls) -> list[str]: + return ["policy"] diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py new file mode 100644 index 000000000..fe58554ab --- /dev/null +++ b/src/lerobot/rollout/context.py @@ -0,0 +1,459 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rollout context: shared state created once before strategy dispatch. 
+ +Grouped into five topical sub-contexts — :class:`RuntimeContext`, +:class:`HardwareContext`, :class:`PolicyContext`, :class:`ProcessorContext`, +and :class:`DatasetContext` — assembled into :class:`RolloutContext`. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from threading import Event + +import torch + +from lerobot.configs import FeatureType, PreTrainedConfig +from lerobot.datasets import ( + LeRobotDataset, + aggregate_pipeline_dataset_features, + create_initial_features, +) +from lerobot.policies import get_policy_class, make_pre_post_processors +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import ( + PolicyProcessorPipeline, + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_processors, + rename_stats, +) +from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep +from lerobot.robots import make_robot_from_config +from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config +from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features + +from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig +from .inference import ( + InferenceEngine, + RTCInferenceConfig, + SyncInferenceConfig, + create_inference_engine, +) +from .robot_wrapper import ThreadSafeRobot + +logger = logging.getLogger(__name__) + + +def _resolve_action_key_order( + policy_action_names: list[str] | None, dataset_action_names: list[str] +) -> list[str]: + """Choose action name ordering for mapping policy tensor outputs to robot action dicts.""" + if not policy_action_names: + return dataset_action_names + policy_action_names = list(policy_action_names) + if len(policy_action_names) != len(dataset_action_names): + logger.warning( + "policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order", + len(policy_action_names), + len(dataset_action_names), + ) + 
return dataset_action_names + if set(dataset_action_names) != set(policy_action_names): + logger.warning("policy.action_feature_names keys don't match dataset; using dataset order") + return dataset_action_names + return policy_action_names + + +# --------------------------------------------------------------------------- +# Sub-contexts +# --------------------------------------------------------------------------- + + +@dataclass +class RuntimeContext: + """Runtime knobs shared with every strategy.""" + + cfg: RolloutConfig + shutdown_event: Event + + +@dataclass +class HardwareContext: + """Connected hardware. + + The raw robot is available via ``robot_wrapper.inner`` when needed + (e.g. for disconnect); strategies should otherwise go through the + thread-safe wrapper. + + ``initial_position`` stores the robot's joint positions at connect + time. Strategies use it to return the robot to a safe pose before + shutting down. + """ + + robot_wrapper: ThreadSafeRobot + teleop: Teleoperator | None + initial_position: dict | None = None + + +@dataclass +class PolicyContext: + """Loaded policy and its inference engine.""" + + policy: PreTrainedPolicy + preprocessor: PolicyProcessorPipeline + postprocessor: PolicyProcessorPipeline + inference: InferenceEngine + + +@dataclass +class ProcessorContext: + """Robot-side pipelines (run outside the policy).""" + + teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction] + robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction] + robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation] + + +@dataclass +class DatasetContext: + """Dataset and feature bookkeeping.""" + + dataset: LeRobotDataset | None + dataset_features: dict = field(default_factory=dict) + hw_features: dict = field(default_factory=dict) + ordered_action_keys: list[str] = field(default_factory=list) + + +@dataclass +class RolloutContext: + """Bundle of 
sub-contexts passed to every rollout strategy. + + Built once by :func:`build_rollout_context` before strategy dispatch. + """ + + runtime: RuntimeContext + hardware: HardwareContext + policy: PolicyContext + processors: ProcessorContext + data: DatasetContext + + +# --------------------------------------------------------------------------- +# Build +# --------------------------------------------------------------------------- + + +def build_rollout_context( + cfg: RolloutConfig, + shutdown_event: Event, + teleop_action_processor: RobotProcessorPipeline | None = None, + robot_action_processor: RobotProcessorPipeline | None = None, + robot_observation_processor: RobotProcessorPipeline | None = None, +) -> RolloutContext: + """Wire up policy, processors, hardware, dataset, and inference engine. + + The order is policy-first / hardware-last so a bad ``--policy.path`` + fails fast without touching the robot. + """ + is_rtc = isinstance(cfg.inference, RTCInferenceConfig) + + # --- 1. Policy (heavy I/O, but no hardware yet) ------------------- + logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path) + policy_config = cfg.policy + policy_class = get_policy_class(policy_config.type) + + full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) + for attr in ("device", "use_amp"): + if hasattr(cfg.policy, attr) and hasattr(full_config, attr): + cli_val = getattr(cfg.policy, attr) + if cli_val is not None: + setattr(full_config, attr, cli_val) + + if hasattr(full_config, "compile_model"): + full_config.compile_model = cfg.use_torch_compile + + if full_config.type == "vqbet" and cfg.device == "mps": + raise NotImplementedError( + "Current implementation of VQBeT does not support `mps` backend. " + "Please use `cpu` or `cuda` backend." 
+ ) + + if full_config.use_peft: + from peft import PeftConfig, PeftModel + + peft_path = cfg.policy.pretrained_path + peft_config = PeftConfig.from_pretrained(peft_path) + policy = policy_class.from_pretrained( + pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config + ) + policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config) + else: + policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config) + + if is_rtc: + policy.config.rtc_config = cfg.inference.rtc + if hasattr(policy, "init_rtc_processor"): + policy.init_rtc_processor() + + policy = policy.to(cfg.device) + policy.eval() + logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device) + + if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"): + try: + if hasattr(torch, "compile"): + compile_kwargs = { + "backend": cfg.torch_compile_backend, + "mode": cfg.torch_compile_mode, + "options": {"triton.cudagraphs": False}, + } + policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs) + logger.info("torch.compile applied to predict_action_chunk") + except Exception as e: + logger.warning("Failed to apply torch.compile: %s", e) + + # --- 2. Robot-side processors (user-supplied or defaults) -------- + if ( + teleop_action_processor is None + or robot_action_processor is None + or robot_observation_processor is None + ): + _t, _r, _o = make_default_processors() + teleop_action_processor = teleop_action_processor or _t + robot_action_processor = robot_action_processor or _r + robot_observation_processor = robot_observation_processor or _o + + # --- 3. Hardware (heaviest side-effect, deferred) ----------------- + logger.info("Connecting robot (%s)...", cfg.robot.type if cfg.robot else "?") + robot = make_robot_from_config(cfg.robot) + robot.connect() + logger.info("Robot connected: %s", robot.name) + + # Store the initial joint positions so we can return to a safe pose on shutdown. 
+ initial_obs = robot.get_observation() + initial_position = {k: v for k, v in initial_obs.items() if k.endswith(".pos")} + logger.info("Captured initial robot position (%d keys)", len(initial_position)) + + robot_wrapper = ThreadSafeRobot(robot) + + teleop = None + if cfg.teleop is not None: + logger.info("Connecting teleoperator (%s)...", cfg.teleop.type if cfg.teleop else "?") + teleop = make_teleoperator_from_config(cfg.teleop) + teleop.connect() + logger.info("Teleoperator connected") + + # TODO(Steven): once Teleoperator motor-control methods are standardised + # (``enable_torque`` / ``disable_torque`` / ``write_goal_positions``), gate + # the DAgger strategy on their presence here and fail fast with a helpful + # message instead of relying on the operator to pre-align the leader by + # hand. See :func:`DAggerStrategy._apply_transition` for the matching + # disabled call sites. + # if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None: + # required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions") + # missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))] + # if missing: + # teleop.disconnect() + # raise ValueError( + # f"DAgger strategy requires a teleoperator with motor control methods " + # f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}" + # ) + + # --- 4. Features + action-key reconciliation --------------------- + # TODO(Steven):Only ``.pos`` joint features are routed to the policy as state and as the + # action target; velocity and torque channels (when present) are kept in + # the raw observation but excluded from the policy-facing tensors. + all_obs_features = robot.observation_features + # ``observation_features`` values are either a tuple (camera shape) or the + # ``float`` type itself used as a sentinel for scalar motor features — + # see ``dict[str, type | tuple]`` annotation on ``Robot.observation_features``. 
+ observation_features_hw = { + k: v + for k, v in all_obs_features.items() + if isinstance(v, tuple) or (v is float and k.endswith(".pos")) + } + action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")} + + # The action side is always needed: sync inference reads action names from + # ``dataset_features[ACTION]`` to map policy tensors back to robot actions. + action_dataset_features = aggregate_pipeline_dataset_features( + pipeline=teleop_action_processor, + initial_features=create_initial_features(action=action_features_hw), + use_videos=cfg.dataset.video if cfg.dataset else True, + ) + # Observation-side aggregation is needed because of build_dataset_frame + observation_dataset_features = aggregate_pipeline_dataset_features( + pipeline=robot_observation_processor, + initial_features=create_initial_features(observation=observation_features_hw), + use_videos=cfg.dataset.video if cfg.dataset else True, + ) + dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features) + hw_features = hw_to_dataset_features(observation_features_hw, "observation") + raw_action_keys = list(action_features_hw.keys()) + policy_action_names = getattr(policy_config, "action_feature_names", None) + ordered_action_keys = _resolve_action_key_order( + list(policy_action_names) if policy_action_names else None, + raw_action_keys, + ) + + # Validate visual features if no rename_map is active + rename_map = cfg.rename_map + if not rename_map: + expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL} + provided_visuals = { + f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple) + } + policy_subset = expected_visuals.issubset(provided_visuals) + hw_subset = provided_visuals.issubset(expected_visuals) + if not (policy_subset or hw_subset): + raise ValueError( + f"Visual feature mismatch between policy and robot hardware.\n" + f"Policy 
expects: {expected_visuals}\n"
+                f"Robot provides: {provided_visuals}"
+            )
+
+    # --- 5. Dataset -------------
+    dataset = None
+    if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
+        logger.info("Setting up dataset (repo_id=%s)...", cfg.dataset.repo_id)
+        if cfg.resume:
+            dataset = LeRobotDataset.resume(
+                cfg.dataset.repo_id,
+                root=cfg.dataset.root,
+                batch_encoding_size=cfg.dataset.video_encoding_batch_size,
+                vcodec=cfg.dataset.vcodec,
+                streaming_encoding=cfg.dataset.streaming_encoding,
+                encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
+                encoder_threads=cfg.dataset.encoder_threads,
+                image_writer_processes=cfg.dataset.num_image_writer_processes,
+                image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
+                * len(robot.cameras if hasattr(robot, "cameras") else []),
+            )
+        else:
+            if isinstance(cfg.strategy, DAggerStrategyConfig):
+                dataset_features["intervention"] = {
+                    "dtype": "bool",
+                    "shape": (1,),
+                    "names": None,
+                }
+
+            repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
+            if not repo_name.startswith("rollout_"):
+                raise ValueError(
+                    "Dataset names for rollout must start with 'rollout_'. "
+                    "Use --dataset.repo_id=<user>/rollout_<name> for policy deployment datasets."
+                )
+            cfg.dataset.stamp_repo_id()
+            target_video_mb = getattr(cfg.strategy, "target_video_file_size_mb", None)
+            dataset = LeRobotDataset.create(
+                cfg.dataset.repo_id,
+                cfg.dataset.fps,
+                root=cfg.dataset.root,
+                robot_type=robot.name,
+                features=dataset_features,
+                use_videos=cfg.dataset.video,
+                image_writer_processes=cfg.dataset.num_image_writer_processes,
+                image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
+                * len(robot.cameras if hasattr(robot, "cameras") else []),
+                batch_encoding_size=cfg.dataset.video_encoding_batch_size,
+                vcodec=cfg.dataset.vcodec,
+                streaming_encoding=cfg.dataset.streaming_encoding,
+                encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
+                encoder_threads=cfg.dataset.encoder_threads,
+                video_files_size_in_mb=target_video_mb,
+            )
+
+    if dataset is not None:
+        logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
+
+    # --- 6. Policy pre/post processors (needs dataset stats if any) ---
+    dataset_stats = None
+    if dataset is not None:
+        dataset_stats = rename_stats(
+            dataset.meta.stats,
+            cfg.rename_map,
+        )
+
+    preprocessor, postprocessor = make_pre_post_processors(
+        policy_cfg=policy_config,
+        pretrained_path=cfg.policy.pretrained_path,
+        dataset_stats=dataset_stats,
+        preprocessor_overrides={
+            "device_processor": {"device": cfg.device},
+            "rename_observations_processor": {"rename_map": cfg.rename_map},
+        },
+    )
+
+    if isinstance(cfg.inference, SyncInferenceConfig) and any(
+        isinstance(step, RelativeActionsProcessorStep) and step.enabled
+        for step in getattr(preprocessor, "steps", ())
+    ):
+        raise NotImplementedError(
+            "SyncInferenceEngine does not support policies with relative actions for now. "
+            "Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
+        )
+
+    # --- 7.
Inference strategy (needs policy + pre/post + hardware) -- + logger.info( + "Creating inference engine (type=%s)...", + cfg.inference.type if hasattr(cfg.inference, "type") else "sync", + ) + task_str = cfg.dataset.single_task if cfg.dataset else cfg.task + inference_strategy = create_inference_engine( + cfg.inference, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + robot_wrapper=robot_wrapper, + hw_features=hw_features, + dataset_features=dataset_features, + ordered_action_keys=ordered_action_keys, + task=task_str, + fps=cfg.fps, + device=cfg.device, + use_torch_compile=cfg.use_torch_compile, + compile_warmup_inferences=cfg.compile_warmup_inferences, + shutdown_event=shutdown_event, + ) + + # --- 8. Assemble --------------------------------------------------- + logger.info("Rollout context assembled successfully") + return RolloutContext( + runtime=RuntimeContext(cfg=cfg, shutdown_event=shutdown_event), + hardware=HardwareContext( + robot_wrapper=robot_wrapper, teleop=teleop, initial_position=initial_position + ), + policy=PolicyContext( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + inference=inference_strategy, + ), + processors=ProcessorContext( + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ), + data=DatasetContext( + dataset=dataset, + dataset_features=dataset_features, + hw_features=hw_features, + ordered_action_keys=ordered_action_keys, + ), + ) diff --git a/src/lerobot/rollout/inference/__init__.py b/src/lerobot/rollout/inference/__init__.py new file mode 100644 index 000000000..b61cb342c --- /dev/null +++ b/src/lerobot/rollout/inference/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference engine package — backend-agnostic action production. + +Concrete backends (``sync``, ``rtc``, ...) expose the same small interface so +rollout strategies never branch on which backend is in use. +""" + +from .base import InferenceEngine +from .factory import ( + InferenceEngineConfig, + RTCInferenceConfig, + SyncInferenceConfig, + create_inference_engine, +) +from .rtc import RTCInferenceEngine +from .sync import SyncInferenceEngine + +__all__ = [ + "InferenceEngine", + "InferenceEngineConfig", + "RTCInferenceConfig", + "RTCInferenceEngine", + "SyncInferenceConfig", + "SyncInferenceEngine", + "create_inference_engine", +] diff --git a/src/lerobot/rollout/inference/base.py b/src/lerobot/rollout/inference/base.py new file mode 100644 index 000000000..f269aa5fe --- /dev/null +++ b/src/lerobot/rollout/inference/base.py @@ -0,0 +1,89 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference engine ABC. 
+ +Rollout strategies consume actions through this small interface so they +do not need to know whether inference happens inline on the control thread +or asynchronously in a background thread (RTC). +""" + +from __future__ import annotations + +import abc + +import torch + + +class InferenceEngine(abc.ABC): + """Abstract backend for producing actions during rollout. + + Subclasses decide whether inference happens inline on the control + thread or asynchronously in a background thread. The contract is + minimal so additional backends can be plugged in without touching + rollout strategies. + + Lifecycle + --------- + ``start`` — prepare the backend (e.g. launch a background thread). + ``stop`` — shut the backend down cleanly. + ``reset`` — clear episode-scoped state (policy hidden state, queues…). + + Action production + ----------------- + ``get_action(obs_frame)`` — return the next action tensor, or + ``None`` if none is available (e.g. async queue empty). Sync + backends always compute from ``obs_frame``; async backends ignore + it (they receive observations via ``notify_observation``). + + Optional hooks + -------------- + ``notify_observation`` / ``pause`` / ``resume`` have a no-op default + so rollout strategies can invoke them unconditionally. + """ + + @abc.abstractmethod + def start(self) -> None: + """Initialise the backend.""" + + @abc.abstractmethod + def stop(self) -> None: + """Tear the backend down.""" + + @abc.abstractmethod + def reset(self) -> None: + """Clear episode-scoped state.""" + + @abc.abstractmethod + def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: + """Return the next action tensor, or ``None`` if unavailable.""" + + def notify_observation(self, obs: dict) -> None: # noqa: B027 + """Publish the latest processed observation. Default: no-op.""" + + def pause(self) -> None: # noqa: B027 + """Pause background inference. Default: no-op.""" + + def resume(self) -> None: # noqa: B027 + """Resume background inference. 
Default: no-op.""" + + @property + def ready(self) -> bool: + """True once the backend can produce actions (e.g. warmup done).""" + return True + + @property + def failed(self) -> bool: + """True if an unrecoverable error occurred in the backend.""" + return False diff --git a/src/lerobot/rollout/inference/factory.py b/src/lerobot/rollout/inference/factory.py new file mode 100644 index 000000000..e600bed63 --- /dev/null +++ b/src/lerobot/rollout/inference/factory.py @@ -0,0 +1,128 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference engine configs and factory. + +Selection is explicit via ``--inference.type=sync|rtc``. Adding a new +backend requires registering its config subclass and dispatching it in +:func:`create_inference_engine`. 
+""" + +from __future__ import annotations + +import abc +import logging +from dataclasses import dataclass, field +from threading import Event + +import draccus + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.processor import PolicyProcessorPipeline + +from ..robot_wrapper import ThreadSafeRobot +from .base import InferenceEngine +from .rtc import RTCInferenceEngine +from .sync import SyncInferenceEngine + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Configs +# --------------------------------------------------------------------------- + + +@dataclass +class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC): + """Abstract base for inference backend configuration. + + Use ``--inference.type=`` on the CLI to select a backend. + """ + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@InferenceEngineConfig.register_subclass("sync") +@dataclass +class SyncInferenceConfig(InferenceEngineConfig): + """Inline synchronous inference (one policy call per control tick).""" + + +@InferenceEngineConfig.register_subclass("rtc") +@dataclass +class RTCInferenceConfig(InferenceEngineConfig): + """Real-Time Chunking: async policy inference in a background thread.""" + + # Eagerly constructed so draccus exposes nested fields directly on the CLI + # (e.g. ``--inference.rtc.execution_horizon=...``). 
+ rtc: RTCConfig = field(default_factory=RTCConfig) + queue_threshold: int = 30 + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_inference_engine( + config: InferenceEngineConfig, + *, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, + robot_wrapper: ThreadSafeRobot, + hw_features: dict, + dataset_features: dict, + ordered_action_keys: list[str], + task: str, + fps: float, + device: str | None, + use_torch_compile: bool = False, + compile_warmup_inferences: int = 2, + shutdown_event: Event | None = None, +) -> InferenceEngine: + """Instantiate the appropriate inference engine from a config object.""" + logger.info("Creating inference engine: %s", config.type) + if isinstance(config, SyncInferenceConfig): + return SyncInferenceEngine( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset_features=dataset_features, + ordered_action_keys=ordered_action_keys, + task=task, + device=device, + robot_type=robot_wrapper.robot_type, + ) + if isinstance(config, RTCInferenceConfig): + return RTCInferenceEngine( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + robot_wrapper=robot_wrapper, + rtc_config=config.rtc, + hw_features=hw_features, + task=task, + fps=fps, + device=device, + use_torch_compile=use_torch_compile, + compile_warmup_inferences=compile_warmup_inferences, + rtc_queue_threshold=config.queue_threshold, + shutdown_event=shutdown_event, + ) + raise ValueError(f"Unknown inference engine type: {type(config).__name__}") diff --git a/src/lerobot/rollout/inference/rtc.py b/src/lerobot/rollout/inference/rtc.py new file mode 100644 index 000000000..0eef62cef --- /dev/null +++ b/src/lerobot/rollout/inference/rtc.py @@ -0,0 +1,360 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Real-Time Chunking inference engine. + +A background thread produces action chunks asynchronously via +:meth:`policy.predict_action_chunk`. The main control loop polls +``get_action`` for the next ready action; observations flow the other +way via ``notify_observation``. +""" + +from __future__ import annotations + +import logging +import math +import time +import traceback +from threading import Event, Lock, Thread +from typing import Any + +import torch + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.rtc import ActionQueue, LatencyTracker, reanchor_relative_rtc_prefix +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.utils import prepare_observation_for_inference +from lerobot.processor import ( + NormalizerProcessorStep, + PolicyProcessorPipeline, + RelativeActionsProcessorStep, +) +from lerobot.utils.feature_utils import build_dataset_frame + +from ..robot_wrapper import ThreadSafeRobot +from .base import InferenceEngine + +logger = logging.getLogger(__name__) + +# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue. +_RTC_IDLE_SLEEP_S: float = 0.01 +# Backoff between transient inference errors (per consecutive failure). +_RTC_ERROR_RETRY_DELAY_S: float = 0.5 +# Consecutive transient errors tolerated before giving up and propagating shutdown. 
+_RTC_MAX_CONSECUTIVE_ERRORS: int = 10 +# Hard timeout for joining the RTC thread on stop(). +_RTC_JOIN_TIMEOUT_S: float = 3.0 + + +# --------------------------------------------------------------------------- +# RTC helpers +# --------------------------------------------------------------------------- + + +def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor: + """Pad or truncate RTC prefix actions to a fixed length for stable compiled inference.""" + if prev_actions.ndim != 2: + raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}") + steps, action_dim = prev_actions.shape + if steps == target_steps: + return prev_actions + if steps > target_steps: + return prev_actions[:target_steps] + padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device) + padded[:steps] = prev_actions + return padded + + +# --------------------------------------------------------------------------- +# RTCInferenceEngine +# --------------------------------------------------------------------------- + + +class RTCInferenceEngine(InferenceEngine): + """Async RTC inference: a background thread produces action chunks. + + ``get_action`` pops the next action from the shared queue (or + returns ``None`` if the queue is empty). The main loop should call + ``notify_observation`` every tick and ``pause``/``resume`` around + human-intervention phases. 
+ """ + + def __init__( + self, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, + robot_wrapper: ThreadSafeRobot, + rtc_config: RTCConfig, + hw_features: dict, + task: str, + fps: float, + device: str | None, + use_torch_compile: bool = False, + compile_warmup_inferences: int = 2, + rtc_queue_threshold: int = 30, + shutdown_event: Event | None = None, + ) -> None: + self._policy = policy + self._preprocessor = preprocessor + self._postprocessor = postprocessor + self._robot = robot_wrapper + self._rtc_config = rtc_config + self._hw_features = hw_features + self._task = task + self._fps = fps + self._device = device or "cpu" + self._use_torch_compile = use_torch_compile + self._compile_warmup_inferences = compile_warmup_inferences + self._rtc_queue_threshold = rtc_queue_threshold + + self._action_queue: ActionQueue | None = None + self._obs_holder: dict[str, Any] = {} + self._obs_lock = Lock() + self._policy_active = Event() + self._compile_warmup_done = Event() + self._shutdown_event = Event() + self._rtc_error = Event() + self._global_shutdown_event = shutdown_event + self._rtc_thread: Thread | None = None + + if not self._use_torch_compile: + self._compile_warmup_done.set() + logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)") + else: + logger.info( + "RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)", + compile_warmup_inferences, + ) + + # Processor introspection for relative-action re-anchoring. 
+ self._relative_step = next( + (s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled), + None, + ) + self._normalizer_step = next( + (s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)), + None, + ) + if self._relative_step is not None: + if self._relative_step.action_names is None: + cfg_names = getattr(policy.config, "action_feature_names", None) + if cfg_names: + self._relative_step.action_names = list(cfg_names) + else: + self._relative_step.action_names = [ + k for k in robot_wrapper.action_features if k.endswith(".pos") + ] + logger.info("Relative actions enabled: RTC prefix will be re-anchored") + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + @property + def ready(self) -> bool: + """True once torch.compile warmup is complete (or immediately if compile is disabled).""" + return self._compile_warmup_done.is_set() + + @property + def failed(self) -> bool: + """True if the RTC background thread exited due to an unrecoverable error.""" + return self._rtc_error.is_set() + + @property + def action_queue(self) -> ActionQueue | None: + """The shared action queue between the RTC thread and the main loop.""" + return self._action_queue + + def start(self) -> None: + """Launch the RTC background thread.""" + self._action_queue = ActionQueue(self._rtc_config) + self._obs_holder = { + "obs": None, + "robot_type": self._robot.robot_type, + } + self._shutdown_event.clear() + self._rtc_thread = Thread( + target=self._rtc_loop, + daemon=True, + name="RTCInference", + ) + self._rtc_thread.start() + logger.info("RTC inference thread started") + + def stop(self) -> None: + """Signal the RTC thread to stop and wait for it.""" + logger.info("Stopping RTC inference thread...") + self._shutdown_event.set() + self._policy_active.clear() + if self._rtc_thread is not None and self._rtc_thread.is_alive(): + 
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S) + if self._rtc_thread.is_alive(): + logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S) + else: + logger.info("RTC inference thread stopped") + self._rtc_thread = None + + def pause(self) -> None: + """Pause the RTC background thread.""" + logger.info("Pausing RTC inference thread") + self._policy_active.clear() + + def resume(self) -> None: + """Resume the RTC background thread.""" + logger.info("Resuming RTC inference thread") + self._policy_active.set() + + def reset(self) -> None: + """Reset the policy, processors, and action queue.""" + logger.info("Resetting RTC inference state (policy + processors + queue)") + self._policy.reset() + self._preprocessor.reset() + self._postprocessor.reset() + if self._action_queue is not None: + self._action_queue.clear() + + # ------------------------------------------------------------------ + # Action production (called from main thread) + # ------------------------------------------------------------------ + + def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: + """Pop the next action from the RTC queue (ignores ``obs_frame``).""" + if self._action_queue is None: + return None + return self._action_queue.get() + + def notify_observation(self, obs: dict) -> None: + """Publish the latest observation for the RTC thread to consume.""" + with self._obs_lock: + self._obs_holder["obs"] = obs + + # ------------------------------------------------------------------ + # RTC: background inference thread + # ------------------------------------------------------------------ + + def _rtc_loop(self) -> None: + """Background thread that generates action chunks via RTC.""" + try: + latency_tracker = LatencyTracker() + time_per_chunk = 1.0 / self._fps + policy_device = torch.device(self._device) + + warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0 + inference_count = 0 + consecutive_errors = 0 + + while 
not self._shutdown_event.is_set(): + if not self._policy_active.is_set(): + time.sleep(_RTC_IDLE_SLEEP_S) + continue + + queue = self._action_queue + with self._obs_lock: + obs = self._obs_holder.get("obs") + if queue is None or obs is None: + time.sleep(_RTC_IDLE_SLEEP_S) + continue + + if queue.qsize() <= self._rtc_queue_threshold: + try: + current_time = time.perf_counter() + idx_before = queue.get_action_index() + prev_actions = queue.get_left_over() + + latency = latency_tracker.max() + delay = math.ceil(latency / time_per_chunk) if latency else 0 + + obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation") + obs_batch = prepare_observation_for_inference( + obs_batch, policy_device, self._task, self._robot.robot_type + ) + obs_batch["task"] = [self._task] + + preprocessed = self._preprocessor(obs_batch) + + if prev_actions is not None and self._relative_step is not None: + # Rebase against the raw cached state so the leftover tail stays in + # the training-time coordinate frame. 
+ raw_state = self._relative_step.get_cached_state() + if raw_state is not None: + prev_abs = queue.get_processed_left_over() + if prev_abs is not None and prev_abs.numel() > 0: + prev_actions = reanchor_relative_rtc_prefix( + prev_actions_absolute=prev_abs, + current_state=raw_state, + relative_step=self._relative_step, + normalizer_step=self._normalizer_step, + policy_device=policy_device, + ) + + if prev_actions is not None: + prev_actions = _normalize_prev_actions_length( + prev_actions, target_steps=self._rtc_config.execution_horizon + ) + + actions = self._policy.predict_action_chunk( + preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions + ) + + original = actions.squeeze(0).clone() + processed = self._postprocessor(actions).squeeze(0) + new_latency = time.perf_counter() - current_time + new_delay = math.ceil(new_latency / time_per_chunk) + + inference_count += 1 + consecutive_errors = 0 + is_warmup = self._use_torch_compile and inference_count <= warmup_required + if is_warmup: + latency_tracker.reset() + else: + latency_tracker.add(new_latency) + + queue.merge(original, processed, new_delay, idx_before) + + if ( + is_warmup + and inference_count >= warmup_required + and not self._compile_warmup_done.is_set() + ): + self._compile_warmup_done.set() + logger.info("Compile warmup complete (%d inferences)", inference_count) + + logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize()) + + except Exception as e: + consecutive_errors += 1 + logger.error( + "RTC inference error (%d/%d): %s", + consecutive_errors, + _RTC_MAX_CONSECUTIVE_ERRORS, + e, + ) + logger.debug(traceback.format_exc()) + if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS: + # Persistent failure: stop retrying and propagate shutdown. 
+ raise + time.sleep(_RTC_ERROR_RETRY_DELAY_S) + else: + time.sleep(_RTC_IDLE_SLEEP_S) + + except Exception as e: + logger.error("Fatal error in RTC thread: %s", e) + logger.error(traceback.format_exc()) + self._rtc_error.set() + # Unblock any warmup waiters so the main loop doesn't spin forever + self._compile_warmup_done.set() + # Signal the top-level shutdown so strategies exit their control loops + if self._global_shutdown_event is not None: + self._global_shutdown_event.set() diff --git a/src/lerobot/rollout/inference/sync.py b/src/lerobot/rollout/inference/sync.py new file mode 100644 index 000000000..2bb05b6ab --- /dev/null +++ b/src/lerobot/rollout/inference/sync.py @@ -0,0 +1,122 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Synchronous inference engine: inline policy call per control tick.""" + +from __future__ import annotations + +import logging +from contextlib import nullcontext +from copy import copy + +import torch + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference +from lerobot.processor import PolicyProcessorPipeline + +from .base import InferenceEngine + +logger = logging.getLogger(__name__) + + +# TODO(Steven): support relative-action policies. 
The per-tick flow refreshes +# ``RelativeActionsProcessorStep._last_state`` every call, so cached chunk +# actions popped on later ticks get reanchored to the *current* robot state and +# absolute targets drift through the chunk. Relative-action policies are +# rejected at context-build time today; RTC postprocesses the whole chunk and +# is unaffected. +# +# Candidate fix: drive the policy via ``predict_action_chunk`` and serve a +# local FIFO of postprocessed actions. Eliminates drift by construction and +# saves per-tick pre/post work, but bypasses ``select_action`` — needs +# fallbacks for SAC (raises), ACT temporal ensembling (ensembler lives in +# ``select_action``), and Diffusion-family (obs-history queues populated as a +# side effect of ``select_action``). + + +class SyncInferenceEngine(InferenceEngine): + """Inline synchronous inference: compute one action per call. + + ``get_action`` runs the full policy pipeline (pre/post-processor + + ``select_action``) on the given observation frame and returns a + CPU action tensor reordered to match the dataset action keys. 
+ """ + + def __init__( + self, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, + dataset_features: dict, + ordered_action_keys: list[str], + task: str, + device: str | None, + robot_type: str, + ) -> None: + self._policy = policy + self._preprocessor = preprocessor + self._postprocessor = postprocessor + self._dataset_features = dataset_features + self._ordered_action_keys = ordered_action_keys + self._task = task + self._device = torch.device(device or "cpu") + self._robot_type = robot_type + logger.info( + "SyncInferenceEngine initialized (device=%s, action_keys=%d)", + self._device, + len(ordered_action_keys), + ) + + def start(self) -> None: + """No background resources to start.""" + logger.info("SyncInferenceEngine started (inline mode — no background thread)") + + def stop(self) -> None: + """No background resources to stop.""" + logger.info("SyncInferenceEngine stopped") + + def reset(self) -> None: + """Reset the policy and pre/post-processors.""" + logger.info("Resetting sync inference state (policy + processors)") + self._policy.reset() + self._preprocessor.reset() + self._postprocessor.reset() + + def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: + """Run the full inference pipeline on ``obs_frame`` and return an action tensor.""" + if obs_frame is None: + return None + # Shallow copy is intentional: the caller (`send_next_action`) builds + # ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the + # tensor/array values are not shared with any other reader. 
+ observation = copy(obs_frame) + autocast_ctx = ( + torch.autocast(device_type=self._device.type) + if self._device.type == "cuda" and self._policy.config.use_amp + else nullcontext() + ) + with torch.inference_mode(), autocast_ctx: + observation = prepare_observation_for_inference( + observation, self._device, self._task, self._robot_type + ) + observation = self._preprocessor(observation) + action = self._policy.select_action(observation) + action = self._postprocessor(action) + action_tensor = action.squeeze(0).cpu() + + # Reorder to match dataset action ordering so the caller can treat + # the returned tensor uniformly across backends. + action_dict = make_robot_action(action_tensor, self._dataset_features) + return torch.tensor([action_dict[k] for k in self._ordered_action_keys]) diff --git a/src/lerobot/rollout/ring_buffer.py b/src/lerobot/rollout/ring_buffer.py new file mode 100644 index 000000000..2c0a06301 --- /dev/null +++ b/src/lerobot/rollout/ring_buffer.py @@ -0,0 +1,112 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Memory-bounded ring buffer for the Highlight Reel rollout strategy.""" + +from __future__ import annotations + +from collections import deque + +import numpy as np +import torch + + +class RolloutRingBuffer: + """Fixed-capacity circular buffer for observation/action frames. 
+ + Stores the last *N* seconds of telemetry in memory, bounded by both + time (``max_frames``) and memory (``max_memory_bytes``). When either + limit is reached the oldest frames are evicted. + + .. note:: + This class is **single-threaded**. ``append``/``drain``/``clear`` + must all be called from the same thread (the rollout main loop). + Concurrent access from a background thread will corrupt + ``_current_bytes`` accounting. + + Parameters + ---------- + max_seconds: + Maximum duration of buffered telemetry. + max_memory_mb: + Hard memory cap in MiB. Frames are evicted when the estimated + total size exceeds this. + fps: + Frames per second — used to convert ``max_seconds`` to a frame + count. + """ + + def __init__(self, max_seconds: float = 30.0, max_memory_mb: int = 2048, fps: float = 30.0) -> None: + self._max_frames = int(max_seconds * fps) + self._max_bytes = int(max_memory_mb * 1024 * 1024) + self._buffer: deque[dict] = deque(maxlen=self._max_frames) + self._current_bytes: int = 0 + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def append(self, frame: dict) -> None: + """Add *frame* to the buffer, evicting the oldest if at capacity.""" + frame_bytes = _estimate_frame_bytes(frame) + + # Evict oldest frames until we are under the memory cap + while self._current_bytes + frame_bytes > self._max_bytes and self._buffer: + evicted = self._buffer.popleft() + self._current_bytes -= _estimate_frame_bytes(evicted) + + self._buffer.append(frame) + self._current_bytes += frame_bytes + + def drain(self) -> list[dict]: + """Return all buffered frames and clear the buffer.""" + frames = list(self._buffer) + self._buffer.clear() + self._current_bytes = 0 + return frames + + def clear(self) -> None: + """Discard all buffered frames.""" + self._buffer.clear() + self._current_bytes = 0 + + def __len__(self) -> int: + return len(self._buffer) + + @property + 
def estimated_bytes(self) -> int: + """Estimated total byte size of all buffered frames.""" + return self._current_bytes + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _estimate_frame_bytes(frame: dict) -> int: + """Rough byte estimate for a single frame dictionary.""" + total = 0 + for v in frame.values(): + if isinstance(v, torch.Tensor): + # ``torch.Tensor`` has no ``nbytes``; compute it explicitly so the + # memory cap is honoured even when frames hold unconverted tensors. + total += v.nelement() * v.element_size() + elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"): + total += v.nbytes + elif isinstance(v, (int, float)): + total += 8 + elif isinstance(v, (str, bytes)): + total += len(v) + return max(total, 1) # avoid zero-size frames diff --git a/src/lerobot/rollout/robot_wrapper.py b/src/lerobot/rollout/robot_wrapper.py new file mode 100644 index 000000000..44f744812 --- /dev/null +++ b/src/lerobot/rollout/robot_wrapper.py @@ -0,0 +1,79 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Thread-safe robot wrapper for concurrent observation/action access."""

from __future__ import annotations

from threading import Lock
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # ``Robot`` is referenced only in (PEP 563 lazy) annotations, so defer the
    # import to type-checking time: this keeps the wrapper importable without
    # pulling in the full robots package and avoids an import-cycle hazard.
    from lerobot.robots import Robot


class ThreadSafeRobot:
    """Lock-protected wrapper around a :class:`Robot` for use with background threads.

    When RTC inference runs in a background thread while the main loop
    executes actions, both threads may access the robot concurrently.
    This wrapper serialises ``get_observation`` and ``send_action`` calls.

    Read-only properties are proxied without the lock since they don't
    mutate hardware state.
    """

    def __init__(self, robot: Robot) -> None:
        self._robot = robot
        self._lock = Lock()

    # -- Lock-protected I/O --------------------------------------------------

    def get_observation(self) -> dict[str, Any]:
        """Read the robot's current observation under the shared lock."""
        with self._lock:
            return self._robot.get_observation()

    def send_action(self, action: dict[str, Any] | Any) -> Any:
        """Send an action to the robot under the shared lock; returns what the robot reports."""
        with self._lock:
            return self._robot.send_action(action)

    # -- Read-only proxies (no lock needed) -----------------------------------

    @property
    def observation_features(self) -> dict:
        return self._robot.observation_features

    @property
    def action_features(self) -> dict:
        return self._robot.action_features

    @property
    def name(self) -> str:
        return self._robot.name

    @property
    def robot_type(self) -> str:
        return self._robot.robot_type

    @property
    def cameras(self):
        # Not every robot exposes cameras; default to an empty mapping.
        return getattr(self._robot, "cameras", {})

    @property
    def is_connected(self) -> bool:
        return self._robot.is_connected

    @property
    def inner(self) -> Robot:
        """Access the underlying robot (e.g. for connect/disconnect)."""
        return self._robot
team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rollout strategies — public API re-exports.""" + +from .base import BaseStrategy +from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action +from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy +from .factory import create_strategy +from .highlight import HighlightStrategy +from .sentry import SentryStrategy + +__all__ = [ + "BaseStrategy", + "DAggerEvents", + "DAggerPhase", + "DAggerStrategy", + "HighlightStrategy", + "RolloutStrategy", + "SentryStrategy", + "create_strategy", + "estimate_max_episode_seconds", + "safe_push_to_hub", + "send_next_action", +] diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py new file mode 100644 index 000000000..e47b65209 --- /dev/null +++ b/src/lerobot/rollout/strategies/base.py @@ -0,0 +1,85 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base rollout strategy: autonomous policy execution with no data recording.""" + +from __future__ import annotations + +import logging +import time + +from lerobot.utils.robot_utils import precise_sleep + +from ..context import RolloutContext +from .core import RolloutStrategy, send_next_action + +logger = logging.getLogger(__name__) + + +class BaseStrategy(RolloutStrategy): + """Autonomous policy rollout with no data recording. + + All actions flow through the ``robot_action_processor`` pipeline + before reaching the robot. + """ + + def setup(self, ctx: RolloutContext) -> None: + """Initialise the inference engine.""" + self._init_engine(ctx) + logger.info("Base strategy ready") + + def run(self, ctx: RolloutContext) -> None: + """Run the autonomous control loop until shutdown or duration expires.""" + engine = self._engine + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + interpolator = self._interpolator + + control_interval = interpolator.get_control_interval(cfg.fps) + + start_time = time.perf_counter() + engine.resume() + logger.info("Base strategy control loop started") + + while not ctx.runtime.shutdown_event.is_set(): + loop_start = time.perf_counter() + + if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: + logger.info("Duration limit reached (%.0fs)", cfg.duration) + break + + obs = robot.get_observation() + obs_processed = self._process_observation_and_notify(ctx.processors, obs) + + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): + continue + + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) + self._log_telemetry(obs_processed, action_dict, ctx.runtime) + + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + else: + logger.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target 
FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) + + def teardown(self, ctx: RolloutContext) -> None: + """Disconnect hardware and stop inference.""" + self._teardown_hardware( + ctx.hardware, + return_to_initial_position=ctx.runtime.cfg.return_to_initial_position, + ) + logger.info("Base strategy teardown complete") diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py new file mode 100644 index 000000000..9c897522f --- /dev/null +++ b/src/lerobot/rollout/strategies/core.py @@ -0,0 +1,304 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Rollout strategy ABC and shared action-dispatch helper.""" + +from __future__ import annotations + +import abc +import logging +import time +from typing import TYPE_CHECKING + +from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB +from lerobot.utils.action_interpolator import ActionInterpolator +from lerobot.utils.constants import OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.visualization_utils import log_rerun_data + +from ..inference import InferenceEngine + +if TYPE_CHECKING: + from ..configs import RolloutStrategyConfig + from ..context import HardwareContext, ProcessorContext, RolloutContext, RuntimeContext + +logger = logging.getLogger(__name__) + + +class RolloutStrategy(abc.ABC): + """Abstract base for rollout execution strategies. + + Each concrete strategy implements a self-contained control loop with + its own recording/interaction semantics. Strategies are mutually + exclusive — only one runs per session. + """ + + def __init__(self, config: RolloutStrategyConfig) -> None: + self.config = config + self._engine: InferenceEngine | None = None + self._interpolator: ActionInterpolator | None = None + self._warmup_flushed: bool = False + self._cached_obs_processed: dict | None = None + + def _init_engine(self, ctx: RolloutContext) -> None: + """Attach the inference engine and action interpolator, then start the backend. + + Creates an :class:`ActionInterpolator` from the config's + ``interpolation_multiplier`` and starts the inference engine. + Call this from ``setup()`` so strategies share identical + initialisation without duplicating code. 
+ """ + self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier) + self._engine = ctx.policy.inference + logger.info("Starting inference engine...") + self._engine.reset() + self._engine.start() + self._warmup_flushed = False + self._cached_obs_processed = None + logger.info("Inference engine started") + + def _process_observation_and_notify(self, processors: ProcessorContext, obs_raw: dict) -> dict: + """Run the observation processor and notify the engine — throttled to policy ticks. + + Callers are responsible for calling ``robot.get_observation()`` every loop + iteration so ``obs_raw`` stays fresh for the action post-processor. This + helper gates only the comparatively expensive bits — the processor pipeline + and ``engine.notify_observation`` — to fire when the interpolator signals + it needs a new action (once per ``interpolation_multiplier`` ticks). On + interpolated ticks the cached ``obs_processed`` is reused. + + With ``interpolation_multiplier == 1`` this is equivalent to the unthrottled + path: ``needs_new_action()`` is True every tick. + + The cache is implicitly invalidated whenever ``interpolator.reset()`` is + called (warmup completion, DAgger phase transitions back to AUTONOMOUS), + because reset makes ``needs_new_action()`` return True on the next call. + """ + if self._cached_obs_processed is None or self._interpolator.needs_new_action(): + obs_processed = processors.robot_observation_processor(obs_raw) + self._engine.notify_observation(obs_processed) + self._cached_obs_processed = obs_processed + return self._cached_obs_processed + + def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool: + """Handle torch.compile warmup phase. + + Returns ``True`` if the caller should ``continue`` (still warming + up). On the first post-warmup iteration the engine and + interpolator are reset so stale warmup state is discarded. 
+ """ + engine = self._engine + interpolator = self._interpolator + if not use_torch_compile: + return False + if not engine.ready: + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + return True + if not self._warmup_flushed: + logger.info("Warmup complete — flushing stale state and resuming engine") + engine.reset() + interpolator.reset() + self._warmup_flushed = True + engine.resume() + return False + + def _teardown_hardware(self, hw: HardwareContext, return_to_initial_position: bool = True) -> None: + """Stop the inference engine, optionally return robot to initial position, and disconnect hardware.""" + if self._engine is not None: + logger.info("Stopping inference engine...") + self._engine.stop() + robot = hw.robot_wrapper.inner + if robot.is_connected: + if return_to_initial_position and hw.initial_position: + logger.info("Returning robot to initial position before shutdown...") + self._return_to_initial_position(hw) + elif not return_to_initial_position: + logger.info( + "Skipping return-to-initial-position (disabled by config); leaving robot in final pose." 
+ ) + logger.info("Disconnecting robot...") + robot.disconnect() + teleop = hw.teleop + if teleop is not None and teleop.is_connected: + logger.info("Disconnecting teleoperator...") + teleop.disconnect() + + @staticmethod + def _return_to_initial_position(hw: HardwareContext, duration_s: float = 3.0, fps: int = 50) -> None: + """Smoothly interpolate the robot back to its initial position.""" + robot = hw.robot_wrapper + target = hw.initial_position + try: + current_obs = robot.get_observation() + current_pos = {k: v for k, v in current_obs.items() if k in target} + steps = max(int(duration_s * fps), 1) + for step in range(1, steps + 1): + t = step / steps + interp = {} + for k in current_pos: + interp[k] = current_pos[k] * (1 - t) + target[k] * t + robot.send_action(interp) + precise_sleep(1 / fps) + except Exception as e: + logger.warning("Could not return to initial position: %s", e) + + @staticmethod + def _log_telemetry( + obs_processed: dict | None, + action_dict: dict | None, + runtime_ctx: RuntimeContext, + ) -> None: + """Log observation/action telemetry to Rerun if display_data is enabled.""" + cfg = runtime_ctx.cfg + if not cfg.display_data: + return + log_rerun_data( + observation=obs_processed, + action=action_dict, + compress_images=cfg.display_compressed_images, + ) + + @abc.abstractmethod + def setup(self, ctx: RolloutContext) -> None: + """Strategy-specific initialisation (keyboard listeners, buffers, etc.).""" + + @abc.abstractmethod + def run(self, ctx: RolloutContext) -> None: + """Main rollout loop. 
Returns when shutdown is requested or duration expires.""" + + @abc.abstractmethod + def teardown(self, ctx: RolloutContext) -> None: + """Cleanup: save dataset, stop threads, disconnect hardware.""" + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def safe_push_to_hub(dataset, tags=None, private=False) -> bool: + """Push dataset to hub, skipping if no episodes have been saved. + + Returns ``True`` if the push was attempted, ``False`` if skipped. + """ + if dataset.num_episodes == 0: + logger.warning("No episodes saved — skipping push to hub") + return False + dataset.push_to_hub(tags=tags, private=private) + return True + + +def estimate_max_episode_seconds( + dataset_features: dict, + fps: float, + target_size_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB, +) -> float: + """Conservatively estimate how many seconds of video will exceed *target_size_mb*. + + Each camera produces its own video file, so the episode duration is + driven by the **slowest** camera to fill ``target_size_mb`` — i.e. + the one with the fewest pixels per frame (lowest bitrate). + + Uses a deliberately **low** bits-per-pixel estimate so the computed + duration is *longer* than reality. By the time the timer fires the + actual video file is guaranteed to have crossed the target size, + which aligns episode boundaries with the dataset's video-file + chunking — each ``push_to_hub`` uploads complete files rather than + re-uploading a still-growing one. + + The estimate ignores codec-specific settings (CRF, preset) on purpose: + we only need a rough lower bound on bitrate, not a precise prediction. + + Falls back to 300 s (5 min) when no video features are present. + """ + # 0.1 bits-per-pixel is a *low* estimate for CRF-30 streaming video of + # robot footage (real-world is typically 0.1 – 0.3 bpp). 
Under- + # estimating the bitrate over-estimates the time → the episode will be + # *larger* than target_size_mb when we save, which is what we want. + conservative_bpp = 0.1 + + # Collect per-camera pixel counts — each camera has its own video file. + camera_pixels = [] + for feat in dataset_features.values(): + if feat.get("dtype") == "video": + shape = feat.get("shape", ()) + + # (H, W, C) — bits-per-pixel is a per-spatial-pixel metric, + # so we exclude the channel dimension from the count. + if len(shape) == 3: + pixels = shape[0] * shape[1] + camera_pixels.append(pixels) + else: + raise ValueError(f"Unexpected video feature shape: {shape}") + + if not camera_pixels: + return 300.0 + + # Use the smallest camera: it produces the lowest bitrate and therefore + # takes the longest to reach the target — the conservative choice. + min_pixels = min(camera_pixels) + bits_per_frame = min_pixels * conservative_bpp + bytes_per_second = (bits_per_frame * fps) / 8 + + # Guard against division by zero just in case + if bytes_per_second <= 0: + return 300.0 + + return (target_size_mb * 1024 * 1024) / bytes_per_second + + +# --------------------------------------------------------------------------- +# Shared action-dispatch helper +# --------------------------------------------------------------------------- + + +def send_next_action( + obs_processed: dict, + obs_raw: dict, + ctx: RolloutContext, + interpolator: ActionInterpolator, +) -> dict | None: + """Dispatch the next action to the robot. + + Pulls the next action tensor from the inference engine, feeds the + interpolator, and sends the interpolated action through the + ``robot_action_processor`` to the robot. Works identically for + sync and async backends — the rollout strategy never needs to branch. + + Returns the action dict that was sent, or ``None`` if no action was + ready (e.g. empty async queue, interpolator not yet primed). 
+ """ + engine = ctx.policy.inference + features = ctx.data.dataset_features + ordered_keys = ctx.data.ordered_action_keys + + if interpolator.needs_new_action(): + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_tensor = engine.get_action(obs_frame) + if action_tensor is not None: + interpolator.add(action_tensor.cpu()) + + interp = interpolator.get() + if interp is None: + return None + + if len(interp) != len(ordered_keys): + raise ValueError(f"Interpolated tensor length ({len(interp)}) != action keys ({len(ordered_keys)})") + action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys)} + processed = ctx.processors.robot_action_processor((action_dict, obs_raw)) + ctx.hardware.robot_wrapper.send_action(processed) + return action_dict diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py new file mode 100644 index 000000000..da4b463fc --- /dev/null +++ b/src/lerobot/rollout/strategies/dagger.py @@ -0,0 +1,767 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DAgger rollout strategy: Human-in-the-Loop data collection. + +Implements the RaC paradigm (Recovery and Correction) for interactive +imitation learning. Alternates between autonomous policy execution and +human intervention via teleoperator. + +Input is controlled via either a keyboard or foot pedal, selected by +the ``input_device`` config field. 
Each device exposes three actions: + + 1. **pause_resume** — Toggle policy execution (AUTONOMOUS <-> PAUSED). + 2. **correction** — Toggle correction recording (PAUSED <-> CORRECTING). + 3. **upload** — Push dataset to hub on demand (corrections-only mode). + ESC (keyboard only) — Stop session. + +Recording modes: + ``record_autonomous=True``: Sentry-like continuous recording with + time-based episode rotation. Both autonomous and correction + frames are recorded; corrections tagged ``intervention=True``. + ``record_autonomous=False``: Only correction windows are recorded. + Each correction (start to stop) becomes one episode. + +Teleoperator expectations: + The user is responsible for keeping the leader arm aligned with the + follower arm at the moment a correction begins. Programmatic motor + handover (``enable_torque`` / ``disable_torque`` / ``write_goal_positions``) + is intentionally not invoked here — see the TODO in + :func:`DAggerStrategy._apply_transition` for the open design decision. 
+""" + +from __future__ import annotations + +import contextlib +import enum +import logging +import os +import sys +import time +from concurrent.futures import Future, ThreadPoolExecutor +from threading import Event, Lock +from typing import Any + +import numpy as np + +from lerobot.common.control_utils import is_headless +from lerobot.datasets import VideoEncodingManager +from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB +from lerobot.teleoperators import Teleoperator +from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.import_utils import _pynput_available +from lerobot.utils.pedal import start_pedal_listener +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import log_say + +from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig +from ..context import RolloutContext +from ..robot_wrapper import ThreadSafeRobot +from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action + +PYNPUT_AVAILABLE = _pynput_available +keyboard = None +if PYNPUT_AVAILABLE: + try: + if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): + logging.info("No DISPLAY set. 
Skipping pynput import.") + PYNPUT_AVAILABLE = False + else: + from pynput import keyboard + except Exception as e: + PYNPUT_AVAILABLE = False + logging.info(f"Could not import pynput: {e}") + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# DAgger state machine +# --------------------------------------------------------------------------- + + +class DAggerPhase(enum.Enum): + """Observable phases of a DAgger episode.""" + + AUTONOMOUS = "autonomous" # Policy driving + PAUSED = "paused" # Engine paused, teleop aligned, awaiting input + CORRECTING = "correcting" # Human driving via teleop, recording interventions + + +# Valid (current_phase, event) -> next_phase +_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = { + (DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED, + (DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS, + (DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING, + (DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED, +} + + +class DAggerEvents: + """Thread-safe container for DAgger input device events. + + The keyboard/pedal threads write transition requests; the main loop + consumes them. + """ + + def __init__(self) -> None: + self._lock = Lock() + self._phase = DAggerPhase.AUTONOMOUS + self._pending_transition: str | None = None + + # Session-level flags + self.stop_recording = Event() + self.upload_requested = Event() + + # -- Thread-safe phase access ------------------------------------------ + + @property + def phase(self) -> DAggerPhase: + """Current phase of the DAgger state machine.""" + with self._lock: + return self._phase + + @phase.setter + def phase(self, value: DAggerPhase) -> None: + with self._lock: + self._phase = value + + def request_transition(self, event: str) -> None: + """Request a phase transition (called from keyboard/pedal threads). 
+ + Only enqueues the request if it corresponds to a valid transition + from the current phase, preventing impossible state changes. + """ + with self._lock: + if (self._phase, event) in _DAGGER_TRANSITIONS: + self._pending_transition = event + + def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None: + """Consume a pending transition (called from main loop).""" + with self._lock: + if self._pending_transition is None: + return None + key = (self._phase, self._pending_transition) + self._pending_transition = None + new_phase = _DAGGER_TRANSITIONS.get(key) + if new_phase is None: + return None + old_phase = self._phase + self._phase = new_phase + return old_phase, new_phase + + def reset(self) -> None: + """Reset all transient state for a fresh session.""" + with self._lock: + self._phase = DAggerPhase.AUTONOMOUS + self._pending_transition = None + self.upload_requested.clear() + + +# --------------------------------------------------------------------------- +# Teleoperator helpers +# --------------------------------------------------------------------------- + + +# TODO(Steven): re-enable programmatic teleop alignment once we decide whether +# to enforce motor-control methods on every Teleoperator. Until then the user +# is responsible for moving the leader arm to the follower's pose at the moment +# a correction begins. +def _teleop_smooth_move_to( + teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50 +) -> None: + """Smoothly move teleop to target position via linear interpolation. + + Requires the teleoperator to support motor control methods + (``enable_torque``, ``write_goal_positions``, ``get_action``). 
+ """ + teleop.enable_torque() + current = teleop.get_action() + steps = max(int(duration_s * fps), 1) + + for step in range(steps + 1): + t = step / steps + interp = {} + for k in current: + if k in target_pos: + interp[k] = current[k] * (1 - t) + target_pos[k] * t + else: + interp[k] = current[k] + teleop.write_goal_positions(interp) + time.sleep(1 / fps) + + +# --------------------------------------------------------------------------- +# Input device handlers +# --------------------------------------------------------------------------- + + +def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig): + """Initialise keyboard listener with DAgger 3-key controls. + + Returns the pynput Listener (or ``None`` in headless mode or when + pynput is unavailable). + """ + if not PYNPUT_AVAILABLE or is_headless(): + logger.warning("Headless environment or pynput unavailable — keyboard controls disabled") + return None + + # Map config key names to pynput Key objects for special keys + special_keys = { + "space": keyboard.Key.space, + "tab": keyboard.Key.tab, + "enter": keyboard.Key.enter, + } + + def _resolve_key(key) -> str | None: + """Resolve a pynput key event to a config-comparable string.""" + if key == keyboard.Key.esc: + return "esc" + for name, pynput_key in special_keys.items(): + if key == pynput_key: + return name + if hasattr(key, "char") and key.char: + return key.char + return None + + # Build mapping: resolved key string -> DAgger event name + key_to_event = { + cfg.pause_resume: "pause_resume", + cfg.correction: "correction", + } + + def on_press(key): + try: + resolved = _resolve_key(key) + if resolved is None: + return + if resolved == "esc": + logger.info("Stop recording...") + events.stop_recording.set() + return + if resolved in key_to_event: + events.request_transition(key_to_event[resolved]) + if resolved == cfg.upload: + events.upload_requested.set() + except Exception as e: + logger.debug("Key error: %s", e) + + listener = 
keyboard.Listener(on_press=on_press) + listener.start() + logger.info( + "DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)", + cfg.pause_resume, + cfg.correction, + cfg.upload, + ) + return listener + + +def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig): + """Initialise foot pedal listener with DAgger 3-pedal controls. + + Returns the pedal listener thread (or ``None`` if evdev is unavailable). + """ + code_to_event = { + cfg.pause_resume: "pause_resume", + cfg.correction: "correction", + } + + def on_press(code: str) -> None: + if code in code_to_event: + events.request_transition(code_to_event[code]) + if code == cfg.upload: + events.upload_requested.set() + + logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path) + return start_pedal_listener(on_press, device_path=cfg.device_path) + + +# --------------------------------------------------------------------------- +# DAgger Strategy +# --------------------------------------------------------------------------- + + +class DAggerStrategy(RolloutStrategy): + """Human-in-the-Loop data collection with intervention tagging. + + State machine:: + + AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED + --(key1)--> AUTONOMOUS + + Recording modes: + ``record_autonomous=True``: Sentry-like continuous recording with + time-based episode rotation. Intervention frames tagged True. + ``record_autonomous=False``: Only correction windows recorded. + Each correction = one episode. Upload on demand via key3. 
+ """ + + config: DAggerStrategyConfig + + def __init__(self, config: DAggerStrategyConfig): + super().__init__(config) + self._listener = None + self._pedal_thread = None + self._events = DAggerEvents() + self._push_executor: ThreadPoolExecutor | None = None + self._pending_push: Future | None = None + self._needs_push = Event() + self._episode_lock = Lock() + + def setup(self, ctx: RolloutContext) -> None: + """Initialise the inference engine and input device listener.""" + self._init_engine(ctx) + self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push") + target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB + self._episode_duration_s = estimate_max_episode_seconds( + ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb + ) + + if self.config.input_device == "keyboard": + self._listener = _init_dagger_keyboard(self._events, self.config.keyboard) + else: + self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal) + + record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only" + logger.info( + "DAgger strategy ready (input=%s, episodes=%d, record=%s, episode_duration=%.0fs)", + self.config.input_device, + self.config.num_episodes, + record_mode, + self._episode_duration_s, + ) + + def run(self, ctx: RolloutContext) -> None: + """Run DAgger episodes with human-in-the-loop intervention.""" + if self.config.record_autonomous: + self._run_continuous(ctx) + else: + self._run_corrections_only(ctx) + + def teardown(self, ctx: RolloutContext) -> None: + """Stop listeners, finalise the dataset, and disconnect hardware.""" + play_sounds = ctx.runtime.cfg.play_sounds + logger.info("Stopping DAgger recording") + log_say("Stopping DAgger recording", play_sounds) + + if self._listener is not None and not is_headless(): + logger.info("Stopping keyboard listener") + self._listener.stop() + + # Flush any queued/running push cleanly + if 
self._push_executor is not None: + logger.info("Shutting down push executor (waiting for pending pushes)...") + self._push_executor.shutdown(wait=True) + self._push_executor = None + + if ctx.data.dataset is not None: + logger.info("Finalizing dataset...") + ctx.data.dataset.finalize() + if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: + logger.info("Pushing final dataset to hub...") + if safe_push_to_hub( + ctx.data.dataset, + tags=ctx.runtime.cfg.dataset.tags, + private=ctx.runtime.cfg.dataset.private, + ): + logger.info("Dataset uploaded to hub") + log_say("Dataset uploaded to hub", play_sounds) + + self._teardown_hardware( + ctx.hardware, + return_to_initial_position=ctx.runtime.cfg.return_to_initial_position, + ) + logger.info("DAgger strategy teardown complete") + + # ------------------------------------------------------------------ + # Continuous recording mode (record_autonomous=True) + # ------------------------------------------------------------------ + + def _run_continuous(self, ctx: RolloutContext) -> None: + """Sentry-like continuous recording with intervention tagging. + + Episodes are auto-rotated every ``episode_time_s`` seconds and + uploaded in the background every ``upload_every_n_episodes`` episodes. + Both autonomous and correction frames are recorded; corrections are + tagged with ``intervention=True``. 
+ """ + engine = self._engine + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + teleop = ctx.hardware.teleop + dataset = ctx.data.dataset + events = self._events + interpolator = self._interpolator + features = ctx.data.dataset_features + + control_interval = interpolator.get_control_interval(cfg.fps) + record_stride = max(1, cfg.interpolation_multiplier) + task_str = cfg.dataset.single_task if cfg.dataset else cfg.task + play_sounds = cfg.play_sounds + + engine.reset() + interpolator.reset() + events.reset() + # TODO(Steven): re-enable once Teleoperator motor-control methods are + # standardised; until then the user pre-aligns the leader by hand. + # teleop.disable_torque() + engine.resume() + + last_action: dict[str, Any] | None = None + record_tick = 0 + start_time = time.perf_counter() + episode_start = time.perf_counter() + episodes_since_push = 0 + episode_duration_s = self._episode_duration_s + logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s) + + with VideoEncodingManager(dataset): + try: + while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set(): + loop_start = time.perf_counter() + + if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: + logger.info("Duration limit reached (%.0fs)", cfg.duration) + break + + # Process transitions + transition = events.consume_transition() + if transition is not None: + old_phase, new_phase = transition + self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop) + last_action = None + + phase = events.phase + obs = robot.get_observation() + + # --- CORRECTING: human teleop control --- + # TODO(Steven): teleop runs at the same FPS as the policy. To + # decouple the two, sample teleop at its native rate and + # interpolate to the control loop's tick rate. 
+ if phase == DAggerPhase.CORRECTING: + obs_processed = ctx.processors.robot_observation_processor(obs) + teleop_action = teleop.get_action() + processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs)) + robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs)) + robot.send_action(robot_action_to_send) + last_action = robot_action_to_send + self._log_telemetry(obs_processed, processed_teleop, ctx.runtime) + if record_tick % record_stride == 0: + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION) + frame = { + **obs_frame, + **action_frame, + "task": task_str, + "intervention": np.array([True], dtype=bool), + } + dataset.add_frame(frame) + record_tick += 1 + + # --- PAUSED: hold position --- + elif phase == DAggerPhase.PAUSED: + if last_action: + robot.send_action(last_action) + + # --- AUTONOMOUS: policy control --- + else: + obs_processed = self._process_observation_and_notify(ctx.processors, obs) + + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): + continue + + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) + if action_dict is not None: + self._log_telemetry(obs_processed, action_dict, ctx.runtime) + last_action = ctx.processors.robot_action_processor((action_dict, obs)) + if record_tick % record_stride == 0: + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) + frame = { + **obs_frame, + **action_frame, + "task": task_str, + "intervention": np.array([False], dtype=bool), + } + dataset.add_frame(frame) + record_tick += 1 + + # Episode rotation derived from the video file-size target. + # Saving is deferred while a correction is ongoing so the + # episode boundary lands on a clean autonomous frame. 
+ elapsed = time.perf_counter() - episode_start + if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING: + with self._episode_lock: + dataset.save_episode() + episodes_since_push += 1 + self._needs_push.set() + logger.info( + "Episode saved (total: %d, elapsed: %.1fs)", + dataset.num_episodes, + elapsed, + ) + log_say(f"Episode {dataset.num_episodes} saved", play_sounds) + + if episodes_since_push >= self.config.upload_every_n_episodes: + self._background_push(dataset, cfg) + episodes_since_push = 0 + + episode_start = time.perf_counter() + + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + else: + logger.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) + + finally: + logger.info("DAgger continuous control loop ended — pausing engine") + engine.pause() + # TODO(Steven): re-enable once Teleoperator motor-control methods + # are standardised across all teleop implementations. + # teleop.disable_torque() + with contextlib.suppress(Exception): + with self._episode_lock: + dataset.save_episode() + self._needs_push.set() + logger.info("Final in-progress episode saved") + + # ------------------------------------------------------------------ + # Corrections-only mode (record_autonomous=False) + # ------------------------------------------------------------------ + + def _run_corrections_only(self, ctx: RolloutContext) -> None: + """Record only human correction windows. Each correction = one episode. + + The policy runs autonomously without recording. When the user + pauses and starts a correction, frames are recorded with + ``intervention=True``. Stopping the correction saves the episode. + The dataset can be uploaded on demand via the upload key/pedal. 
+ """ + engine = self._engine + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + teleop = ctx.hardware.teleop + dataset = ctx.data.dataset + events = self._events + interpolator = self._interpolator + features = ctx.data.dataset_features + + control_interval = interpolator.get_control_interval(cfg.fps) + record_stride = max(1, cfg.interpolation_multiplier) + task_str = cfg.dataset.single_task if cfg.dataset else cfg.task + play_sounds = cfg.play_sounds + + engine.reset() + interpolator.reset() + events.reset() + # TODO(Steven): re-enable once Teleoperator motor-control methods are + # standardised; until then the user pre-aligns the leader by hand. + # teleop.disable_torque() + engine.resume() + + last_action: dict[str, Any] | None = None + start_time = time.perf_counter() + record_tick = 0 + recorded = 0 + logger.info( + "DAgger corrections-only recording started (target: %d episodes)", self.config.num_episodes + ) + + with VideoEncodingManager(dataset): + try: + while ( + recorded < self.config.num_episodes + and not events.stop_recording.is_set() + and not ctx.runtime.shutdown_event.is_set() + ): + loop_start = time.perf_counter() + + if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: + logger.info("Duration limit reached (%.0fs)", cfg.duration) + break + + # Process transitions + transition = events.consume_transition() + if transition is not None: + old_phase, new_phase = transition + self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop) + last_action = None + + # Correction ended -> save episode (blocking if not streaming) + if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED: + with self._episode_lock: + dataset.save_episode() + recorded += 1 + self._needs_push.set() + logger.info( + "Correction %d/%d saved", + recorded, + self.config.num_episodes, + ) + log_say(f"Correction {recorded} saved", play_sounds) + + # On-demand upload + if events.upload_requested.is_set(): + 
events.upload_requested.clear() + logger.info("Upload requested by user") + self._background_push(dataset, cfg) + + phase = events.phase + obs = robot.get_observation() + + # --- CORRECTING: human teleop control + recording --- + # TODO(Steven): teleop runs at the same FPS as the policy. To + # decouple the two, sample teleop at its native rate and + # interpolate to the control loop's tick rate. + if phase == DAggerPhase.CORRECTING: + obs_processed = ctx.processors.robot_observation_processor(obs) + teleop_action = teleop.get_action() + processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs)) + robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs)) + robot.send_action(robot_action_to_send) + last_action = robot_action_to_send + self._log_telemetry(obs_processed, processed_teleop, ctx.runtime) + + if record_tick % record_stride == 0: + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION) + dataset.add_frame( + { + **obs_frame, + **action_frame, + "task": task_str, + "intervention": np.array([True], dtype=bool), + } + ) + record_tick += 1 + + # --- PAUSED: hold position --- + elif phase == DAggerPhase.PAUSED: + if last_action: + robot.send_action(last_action) + + # --- AUTONOMOUS: policy control (no recording) --- + else: + obs_processed = self._process_observation_and_notify(ctx.processors, obs) + + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): + continue + + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) + if action_dict is not None: + self._log_telemetry(obs_processed, action_dict, ctx.runtime) + last_action = ctx.processors.robot_action_processor((action_dict, obs)) + + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + else: + logger.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than 
the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) + + finally: + logger.info("DAgger corrections-only loop ended — pausing engine") + engine.pause() + # TODO(Steven): re-enable once Teleoperator motor-control methods + # are standardised across all teleop implementations. + # teleop.disable_torque() + with contextlib.suppress(Exception): + with self._episode_lock: + dataset.save_episode() + self._needs_push.set() + logger.info("Final in-progress episode saved") + + # ------------------------------------------------------------------ + # State-machine transition side-effects + # ------------------------------------------------------------------ + + @staticmethod + def _apply_transition( + old_phase: DAggerPhase, + new_phase: DAggerPhase, + engine, + interpolator, + robot: ThreadSafeRobot, + teleop: Teleoperator, + ) -> None: + """Execute side-effects for a validated phase transition.""" + logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value) + if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED: + logger.info("Pausing engine — robot holds position") + engine.pause() + obs = robot.get_observation() + _robot_pos = { + k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features + } + # TODO(Steven): once Teleoperator motor-control methods are + # standardised, drive the leader to the follower's pose here so the + # operator does not need to pre-align the arm by hand. Until then + # the user is responsible for the alignment. + # _teleop_smooth_move_to(teleop, _robot_pos, duration_s=2.0, fps=50) + + elif new_phase == DAggerPhase.CORRECTING: + logger.info("Entering correction mode — human teleop control") + # TODO(Steven): re-enable once Teleoperator motor-control methods + # are standardised across all teleop implementations. 
+ # teleop.disable_torque() + + elif new_phase == DAggerPhase.AUTONOMOUS: + logger.info("Resuming autonomous mode — resetting engine and interpolator") + interpolator.reset() + engine.reset() + engine.resume() + + # ------------------------------------------------------------------ + # Background push (shared by both modes) + # ------------------------------------------------------------------ + + def _background_push(self, dataset, cfg) -> None: + """Queue a Hub push on the single-worker executor. + + The executor's max_workers=1 guarantees at most one push runs at + a time; submitted tasks are queued rather than dropped. Pushes + are blocked while the operator is mid-correction to avoid + uploading a partially-recorded episode. + """ + if self._push_executor is None: + return + + if self._events.phase == DAggerPhase.CORRECTING: + logger.info("Skipping push — correction in progress") + return + + if self._pending_push is not None and not self._pending_push.done(): + logger.info("Previous push still in progress; queueing next") + + def _push(): + try: + with self._episode_lock: + if safe_push_to_hub( + dataset, + tags=cfg.dataset.tags if cfg.dataset else None, + private=cfg.dataset.private if cfg.dataset else False, + ): + self._needs_push.clear() + logger.info("Background push to hub complete") + except Exception as e: + logger.error("Background push failed: %s", e) + + self._pending_push = self._push_executor.submit(_push) + logger.info("Background push task submitted") diff --git a/src/lerobot/rollout/strategies/factory.py b/src/lerobot/rollout/strategies/factory.py new file mode 100644 index 000000000..8a9727769 --- /dev/null +++ b/src/lerobot/rollout/strategies/factory.py @@ -0,0 +1,45 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Strategy factory: config type-name → strategy class dispatch.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .base import BaseStrategy +from .core import RolloutStrategy +from .dagger import DAggerStrategy +from .highlight import HighlightStrategy +from .sentry import SentryStrategy + +if TYPE_CHECKING: + from ..configs import RolloutStrategyConfig + + +def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: + """Instantiate the appropriate strategy from a config object. + + Dispatches on ``config.type`` (the name registered via + ``draccus.ChoiceRegistry``). + """ + if config.type == "base": + return BaseStrategy(config) + if config.type == "sentry": + return SentryStrategy(config) + if config.type == "highlight": + return HighlightStrategy(config) + if config.type == "dagger": + return DAggerStrategy(config) + raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger") diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py new file mode 100644 index 000000000..baff70da7 --- /dev/null +++ b/src/lerobot/rollout/strategies/highlight.py @@ -0,0 +1,283 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Highlight Reel strategy: on-demand recording via ring buffer.""" + +from __future__ import annotations + +import contextlib +import logging +import os +import sys +import time +from concurrent.futures import Future, ThreadPoolExecutor +from threading import Event as ThreadingEvent, Lock + +from lerobot.common.control_utils import is_headless +from lerobot.datasets import VideoEncodingManager +from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.import_utils import _pynput_available, require_package +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import log_say + +from ..configs import HighlightStrategyConfig +from ..context import RolloutContext +from ..ring_buffer import RolloutRingBuffer +from .core import RolloutStrategy, safe_push_to_hub, send_next_action + +PYNPUT_AVAILABLE = _pynput_available +keyboard = None +if PYNPUT_AVAILABLE: + try: + if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): + logging.info("No DISPLAY set. Skipping pynput import.") + PYNPUT_AVAILABLE = False + else: + from pynput import keyboard + except Exception as e: + PYNPUT_AVAILABLE = False + logging.info(f"Could not import pynput: {e}") + +logger = logging.getLogger(__name__) + + +class HighlightStrategy(RolloutStrategy): + """Autonomous rollout with on-demand recording via ring buffer. + + The robot runs autonomously while a memory-bounded ring buffer + captures continuous telemetry. When the user presses the save key: + + 1. 
The ring buffer is flushed to the dataset (last *Z* seconds). + 2. Live recording continues until the save key is pressed again. + 3. The episode is saved and the ring buffer resumes capturing. + + Requires ``streaming_encoding=True`` (enforced in config validation) + so that ``dataset.add_frame`` is a non-blocking queue put — flushing + the entire ring buffer in one tick must not stall the control loop. + """ + + config: HighlightStrategyConfig + + def __init__(self, config: HighlightStrategyConfig): + super().__init__(config) + require_package("pynput", extra="pynput-dep") + self._ring: RolloutRingBuffer | None = None + self._listener = None + self._save_requested = ThreadingEvent() + self._recording_live = ThreadingEvent() + self._push_requested = ThreadingEvent() + self._push_executor: ThreadPoolExecutor | None = None + self._pending_push: Future | None = None + self._episode_lock = Lock() + + def setup(self, ctx: RolloutContext) -> None: + """Initialise the inference engine, ring buffer, and keyboard listener.""" + self._init_engine(ctx) + + self._ring = RolloutRingBuffer( + max_seconds=self.config.ring_buffer_seconds, + max_memory_mb=self.config.ring_buffer_max_memory_mb, + fps=ctx.runtime.cfg.fps, + ) + + self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push") + logger.info( + "Ring buffer initialized (max_seconds=%.0f, max_memory=%.0fMB)", + self.config.ring_buffer_seconds, + self.config.ring_buffer_max_memory_mb, + ) + self._setup_keyboard(ctx.runtime.shutdown_event) + logger.info( + "Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')", + self.config.ring_buffer_seconds, + self.config.save_key, + self.config.push_key, + ) + + def run(self, ctx: RolloutContext) -> None: + """Run the autonomous loop, buffering frames and recording on demand.""" + engine = self._engine + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + dataset = ctx.data.dataset + ring = self._ring + interpolator = self._interpolator 
+ features = ctx.data.dataset_features + + control_interval = interpolator.get_control_interval(cfg.fps) + + engine.resume() + play_sounds = cfg.play_sounds + + start_time = time.perf_counter() + task_str = cfg.dataset.single_task if cfg.dataset else cfg.task + logger.info("Highlight strategy recording started (press '%s' to save)", self.config.save_key) + + with VideoEncodingManager(dataset): + try: + while not ctx.runtime.shutdown_event.is_set(): + loop_start = time.perf_counter() + + if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: + logger.info("Duration limit reached (%.0fs)", cfg.duration) + break + + obs = robot.get_observation() + obs_processed = self._process_observation_and_notify(ctx.processors, obs) + + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): + continue + + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) + + if action_dict is not None: + self._log_telemetry(obs_processed, action_dict, ctx.runtime) + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": task_str} + + # NOTE: ``is_set()`` then ``clear()`` is not atomic + # against the keyboard thread setting the flag again + # in between — but that is benign: we lose at most one + # toggle, processed on the next iteration. 
+ if self._save_requested.is_set(): + self._save_requested.clear() + if not self._recording_live.is_set(): + logger.info( + "Flushing ring buffer (%d frames) + starting live recording", + len(ring), + ) + for buffered_frame in ring.drain(): + dataset.add_frame(buffered_frame) + self._recording_live.set() + else: + dataset.add_frame(frame) + with self._episode_lock: + dataset.save_episode() + logger.info("Episode saved (total: %d)", dataset.num_episodes) + log_say( + f"Episode {dataset.num_episodes} saved", + play_sounds, + ) + self._recording_live.clear() + continue # frame already consumed — skip ring.append + + if self._push_requested.is_set(): + self._push_requested.clear() + logger.info("Push requested by user") + self._background_push(dataset, cfg) + + if self._recording_live.is_set(): + dataset.add_frame(frame) + else: + ring.append(frame) + + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + else: + logger.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. 
Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) + + finally: + logger.info("Highlight control loop ended") + if self._recording_live.is_set(): + logger.info("Saving in-progress live episode") + with contextlib.suppress(Exception), self._episode_lock: + dataset.save_episode() + + def teardown(self, ctx: RolloutContext) -> None: + """Stop listeners, finalise the dataset, and disconnect hardware.""" + play_sounds = ctx.runtime.cfg.play_sounds + logger.info("Stopping highlight recording") + log_say("Stopping highlight recording", play_sounds) + + if self._listener is not None: + logger.info("Stopping keyboard listener") + self._listener.stop() + + if self._push_executor is not None: + logger.info("Shutting down push executor (waiting for pending pushes)...") + self._push_executor.shutdown(wait=True) + self._push_executor = None + + if ctx.data.dataset is not None: + logger.info("Finalizing dataset...") + ctx.data.dataset.finalize() + if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: + logger.info("Pushing final dataset to hub...") + if safe_push_to_hub( + ctx.data.dataset, + tags=ctx.runtime.cfg.dataset.tags, + private=ctx.runtime.cfg.dataset.private, + ): + logger.info("Dataset uploaded to hub") + log_say("Dataset uploaded to hub", play_sounds) + + self._teardown_hardware( + ctx.hardware, + return_to_initial_position=ctx.runtime.cfg.return_to_initial_position, + ) + logger.info("Highlight strategy teardown complete") + + def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None: + """Set up keyboard listener for save and push keys.""" + if is_headless(): + logger.warning("Headless environment — highlight keys unavailable") + return + + try: + save_key = self.config.save_key + push_key = self.config.push_key + + def on_press(key): + with contextlib.suppress(Exception): + if hasattr(key, "char") and key.char == save_key: + self._save_requested.set() + elif hasattr(key, "char") and 
key.char == push_key: + self._push_requested.set() + elif key == keyboard.Key.esc: + self._save_requested.clear() + shutdown_event.set() + + self._listener = keyboard.Listener(on_press=on_press) + self._listener.start() + logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key) + except ImportError: + logger.warning("pynput not available — keyboard listener disabled") + + def _background_push(self, dataset, cfg) -> None: + """Queue a Hub push on the single-worker executor.""" + if self._push_executor is None: + return + + if self._pending_push is not None and not self._pending_push.done(): + logger.info("Previous push still in progress; queueing next") + + def _push(): + try: + with self._episode_lock: + if safe_push_to_hub( + dataset, + tags=cfg.dataset.tags if cfg.dataset else None, + private=cfg.dataset.private if cfg.dataset else False, + ): + logger.info("Background push to hub complete") + except Exception as e: + logger.error("Background push failed: %s", e) + + self._pending_push = self._push_executor.submit(_push) + logger.info("Background push task submitted") diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py new file mode 100644 index 000000000..61e38aa68 --- /dev/null +++ b/src/lerobot/rollout/strategies/sentry.py @@ -0,0 +1,231 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sentry rollout strategy: continuous autonomous recording with auto-upload.""" + +from __future__ import annotations + +import contextlib +import logging +import time +from concurrent.futures import Future, ThreadPoolExecutor +from threading import Event, Lock + +from lerobot.datasets import VideoEncodingManager +from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB +from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import log_say + +from ..configs import SentryStrategyConfig +from ..context import RolloutContext +from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action + +logger = logging.getLogger(__name__) + + +class SentryStrategy(RolloutStrategy): + """Continuous autonomous rollout with always-on recording. + + Episode duration is derived from camera resolution, FPS, and + ``DEFAULT_VIDEO_FILE_SIZE_IN_MB`` so that each saved episode + produces a video file that has crossed the chunk-size boundary. + This keeps ``push_to_hub`` efficient — it uploads complete video + files rather than re-uploading a still-growing one. + + The dataset is pushed to the Hub via a bounded single-worker executor + so no push is ever silently dropped and exactly one push runs at a + time. + + Policy state (hidden state, RTC queue) intentionally persists across + episode boundaries — Sentry slices one continuous rollout, the robot + does not reset between slices. + + Requires ``streaming_encoding=True`` (enforced in config validation) + to prevent disk I/O from blocking the control loop. 
+ """ + + config: SentryStrategyConfig + + def __init__(self, config: SentryStrategyConfig): + super().__init__(config) + self._push_executor: ThreadPoolExecutor | None = None + self._pending_push: Future | None = None + self._needs_push = Event() + self._episode_lock = Lock() + + def setup(self, ctx: RolloutContext) -> None: + """Initialise the inference engine and background push executor.""" + self._init_engine(ctx) + self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sentry-push") + target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB + self._episode_duration_s = estimate_max_episode_seconds( + ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb + ) + logger.info( + "Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)", + self._episode_duration_s, + self.config.upload_every_n_episodes, + ) + + def run(self, ctx: RolloutContext) -> None: + """Run the continuous recording loop with automatic episode rotation.""" + engine = self._engine + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + dataset = ctx.data.dataset + interpolator = self._interpolator + features = ctx.data.dataset_features + + control_interval = interpolator.get_control_interval(cfg.fps) + + engine.resume() + play_sounds = cfg.play_sounds + episode_duration_s = self._episode_duration_s + + start_time = time.perf_counter() + episode_start = time.perf_counter() + episodes_since_push = 0 + task_str = cfg.dataset.single_task if cfg.dataset else cfg.task + logger.info("Sentry recording started (episode_duration=%.0fs)", episode_duration_s) + + with VideoEncodingManager(dataset): + try: + while not ctx.runtime.shutdown_event.is_set(): + loop_start = time.perf_counter() + + if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: + logger.info("Duration limit reached (%.0fs)", cfg.duration) + break + + obs = robot.get_observation() + obs_processed = 
self._process_observation_and_notify(ctx.processors, obs) + + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): + continue + + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) + + if action_dict is not None: + self._log_telemetry(obs_processed, action_dict, ctx.runtime) + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": task_str} + # ``add_frame`` writes to the in-progress episode buffer; the + # background pusher only ever touches *finalised* episode + # artifacts on disk. The two operate on disjoint state, so + # ``add_frame`` does not need ``_episode_lock``. + dataset.add_frame(frame) + + # Episode rotation derived from video file-size target. + # The duration is a conservative estimate so the actual + # video has crossed DEFAULT_VIDEO_FILE_SIZE_IN_MB by now, + # keeping push_to_hub efficient (uploads complete files). + elapsed = time.perf_counter() - episode_start + if elapsed >= episode_duration_s: + # ``save_episode`` finalises the in-progress episode and + # flushes it to disk; ``_episode_lock`` serialises this with + # ``push_to_hub`` (run in the background executor) so the + # pusher never reads a half-written episode. 
+ with self._episode_lock: + dataset.save_episode() + episodes_since_push += 1 + self._needs_push.set() + logger.info( + "Episode saved (total: %d, elapsed: %.1fs)", + dataset.num_episodes, + elapsed, + ) + log_say(f"Episode {dataset.num_episodes} saved", play_sounds) + + if episodes_since_push >= self.config.upload_every_n_episodes: + self._background_push(dataset, cfg) + episodes_since_push = 0 + + episode_start = time.perf_counter() + + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + else: + logger.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) + + finally: + logger.info("Sentry control loop ended — saving final episode") + with contextlib.suppress(Exception): + with self._episode_lock: + dataset.save_episode() + self._needs_push.set() + + def teardown(self, ctx: RolloutContext) -> None: + """Flush pending pushes, finalise the dataset, and disconnect hardware.""" + play_sounds = ctx.runtime.cfg.play_sounds + logger.info("Stopping sentry recording") + log_say("Stopping sentry recording", play_sounds) + + # Flush any queued/running push cleanly. 
+ if self._push_executor is not None: + logger.info("Shutting down push executor (waiting for pending pushes)...") + self._push_executor.shutdown(wait=True) + self._push_executor = None + + if ctx.data.dataset is not None: + logger.info("Finalizing dataset...") + ctx.data.dataset.finalize() + if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: + logger.info("Pushing final dataset to hub...") + if safe_push_to_hub( + ctx.data.dataset, + tags=ctx.runtime.cfg.dataset.tags, + private=ctx.runtime.cfg.dataset.private, + ): + logger.info("Dataset uploaded to hub") + log_say("Dataset uploaded to hub", play_sounds) + + self._teardown_hardware( + ctx.hardware, + return_to_initial_position=ctx.runtime.cfg.return_to_initial_position, + ) + logger.info("Sentry strategy teardown complete") + + def _background_push(self, dataset, cfg) -> None: + """Queue a Hub push on the single-worker executor. + + The executor's max_workers=1 guarantees at most one push runs at + a time; submitted tasks are queued rather than dropped. + """ + if self._push_executor is None: + return + + if self._pending_push is not None and not self._pending_push.done(): + logger.info("Previous push still in progress; queueing next") + + def _push(): + try: + with self._episode_lock: + if safe_push_to_hub( + dataset, + tags=cfg.dataset.tags if cfg.dataset else None, + private=cfg.dataset.private if cfg.dataset else False, + ): + self._needs_push.clear() + logger.info("Background push to hub complete") + except Exception as e: + logger.error("Background push failed: %s", e) + + self._pending_push = self._push_executor.submit(_push) + logger.info("Background push task submitted") diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 50b41c69d..129696bd3 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -13,70 +13,62 @@ # limitations under the License. """ -Records a dataset. 
Actions for the robot can be either generated by teleoperation or by a policy. +Records a dataset via teleoperation. This is a pure data-collection +tool — no policy inference. For deploying trained policies, use +``lerobot-rollout`` instead. Requires: pip install 'lerobot[core_scripts]' (includes dataset + hardware + viz extras) Example: ```shell -lerobot-record \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem58760431541 \ - --robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ - --robot.id=black \ - --dataset.repo_id=/ \ - --dataset.num_episodes=2 \ - --dataset.single_task="Grab the cube" \ - --dataset.streaming_encoding=true \ - --dataset.encoder_threads=2 \ +lerobot-record \\ + --robot.type=so100_follower \\ + --robot.port=/dev/tty.usbmodem58760431541 \\ + --robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\ + --robot.id=black \\ + --teleop.type=so100_leader \\ + --teleop.port=/dev/tty.usbmodem58760431551 \\ + --teleop.id=blue \\ + --dataset.repo_id=/ \\ + --dataset.num_episodes=2 \\ + --dataset.single_task="Grab the cube" \\ + --dataset.streaming_encoding=true \\ + --dataset.encoder_threads=2 \\ --display_data=true - # <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. 
\ - # --dataset.vcodec=h264 \ - # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \ - # --teleop.type=so100_leader \ - # --teleop.port=/dev/tty.usbmodem58760431551 \ - # --teleop.id=blue \ - # <- Policy optional if you want to record with a policy \ - # --policy.path=${HF_USER}/my_policy \ ``` Example recording with bimanual so100: ```shell -lerobot-record \ - --robot.type=bi_so_follower \ - --robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \ - --robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \ - --robot.id=bimanual_follower \ +lerobot-record \\ + --robot.type=bi_so_follower \\ + --robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \\ + --robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \\ + --robot.id=bimanual_follower \\ --robot.left_arm_config.cameras='{ wrist: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30}, top: {"type": "opencv", "index_or_path": 3, "width": 640, "height": 480, "fps": 30}, }' --robot.right_arm_config.cameras='{ wrist: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30}, front: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30}, - }' \ - --teleop.type=bi_so_leader \ - --teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \ - --teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \ - --teleop.id=bimanual_leader \ - --display_data=true \ - --dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \ - --dataset.num_episodes=25 \ - --dataset.single_task="Grab and handover the red cube to the other arm" \ - --dataset.streaming_encoding=true \ - # --dataset.vcodec=auto \ + }' \\ + --teleop.type=bi_so_leader \\ + --teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \\ + --teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \\ + --teleop.id=bimanual_leader \\ + --display_data=true \\ + --dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \\ + --dataset.num_episodes=25 \\ 
+ --dataset.single_task="Grab and handover the red cube to the other arm" \\ + --dataset.streaming_encoding=true \\ --dataset.encoder_threads=2 ``` """ import logging import time -from dataclasses import asdict, dataclass, field -from pathlib import Path +from dataclasses import asdict, dataclass from pprint import pformat -from typing import Any - -import torch from lerobot.cameras import CameraConfig # noqa: F401 from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 @@ -86,11 +78,10 @@ from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401 from lerobot.common.control_utils import ( init_keyboard_listener, is_headless, - predict_action, - sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, ) -from lerobot.configs import PreTrainedConfig, parser +from lerobot.configs import parser +from lerobot.configs.dataset import DatasetRecordConfig from lerobot.datasets import ( LeRobotDataset, VideoEncodingManager, @@ -98,21 +89,11 @@ from lerobot.datasets import ( create_initial_features, safe_stop_image_writer, ) -from lerobot.policies import ( - ActionInterpolator, - PreTrainedPolicy, - make_policy, - make_pre_post_processors, - make_robot_action, -) from lerobot.processor import ( - PolicyAction, - PolicyProcessorPipeline, RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors, - rename_stats, ) from lerobot.robots import ( # noqa: F401 Robot, @@ -146,7 +127,6 @@ from lerobot.teleoperators import ( # noqa: F401 ) from lerobot.teleoperators.keyboard import KeyboardTeleop from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep @@ -157,71 +137,12 @@ from lerobot.utils.utils import ( from lerobot.utils.visualization_utils import init_rerun, log_rerun_data 
-@dataclass -class DatasetRecordConfig: - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") - single_task: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. - root: str | Path | None = None - # Limit the frames per second. - fps: int = 30 - # Number of seconds for data recording for each episode. - episode_time_s: int | float = 60 - # Number of seconds for resetting the environment after each episode. - reset_time_s: int | float = 60 - # Number of episodes to record. - num_episodes: int = 50 - # Encode frames in the dataset into video - video: bool = True - # Upload dataset to Hugging Face hub. - push_to_hub: bool = True - # Upload on private repository on the Hugging Face hub. - private: bool = False - # Add tags to your dataset on the hub. - tags: list[str] | None = None - # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; - # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes - # and threads depends on your system. We recommend 4 threads per camera with 0 processes. - # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. - num_image_writer_processes: int = 0 - # Number of threads writing the frames as png images on disk, per camera. - # Too many threads might cause unstable teleoperation fps due to main thread being blocked. - # Not enough threads might cause low camera fps. - num_image_writer_threads_per_camera: int = 4 - # Number of episodes to record before batch encoding videos - # Set to 1 for immediate encoding (default behavior), or higher for batched encoding - video_encoding_batch_size: int = 1 - # Video codec for encoding videos. 
Options: 'h264', 'hevc', 'libsvtav1', 'auto', - # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'. - # Use 'auto' to auto-detect the best available hardware encoder. - vcodec: str = "libsvtav1" - # Enable streaming video encoding: encode frames in real-time during capture instead - # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding - streaming_encoding: bool = False - # Maximum number of frames to buffer per camera when using streaming encoding. - # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up. - encoder_queue_maxsize: int = 30 - # Number of threads per encoder instance. None = auto (codec default). - # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. - encoder_threads: int | None = None - # Rename map for the observation to override the image and state keys - rename_map: dict[str, str] = field(default_factory=dict) - - def __post_init__(self): - if self.single_task is None: - raise ValueError("You need to provide a task as argument in `single_task`.") - - @dataclass class RecordConfig: robot: RobotConfig dataset: DatasetRecordConfig - # Whether to control the robot with a teleoperator + # Teleoperator to control the robot (required) teleop: TeleoperatorConfig | None = None - # Whether to control the robot with a policy - policy: PreTrainedConfig | None = None # Display all cameras on screen display_data: bool = False # Display data on a remote Rerun server @@ -234,27 +155,14 @@ class RecordConfig: play_sounds: bool = True # Resume recording on an existing dataset. 
resume: bool = False - # Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x) - # Only applies when using a policy (not teleop) - interpolation_multiplier: int = 1 def __post_init__(self): - # HACK: We parse again the cli args here to get the pretrained path if there was one. - policy_path = parser.get_path_arg("policy") - - if policy_path: - cli_overrides = parser.get_cli_overrides("policy") - - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - - if self.teleop is None and self.policy is None: - raise ValueError("Choose a policy, a teleoperator or both to control the robot") - - @classmethod - def __get_path_fields__(cls) -> list[str]: - """This enables the parser to load config from the policy using `--policy.path=local/dir`""" - return ["policy"] + if self.teleop is None: + raise ValueError( + "A teleoperator is required for recording. " + "Use --teleop.type=... to specify one. " + "For policy-based deployment, use lerobot-rollout instead." + ) """ --------------- record_loop() data flow -------------------------- @@ -264,18 +172,14 @@ class RecordConfig: V [ robot_observation_processor ] ---> processed_obs V - .-----( ACTION LOGIC )------------------. 
- V V - [ From Teleoperator ] [ From Policy ] - | | - | [teleop.get_action] -> raw_action | [predict_action] - | | | | - | V | V - | [teleop_action_processor] | | - | | | | - '---> processed_teleop_action '---> processed_policy_action - | | - '-------------------------.-------------' + [ Teleoperator ] + | + | [teleop.get_action] -> raw_action + | | + | V + | [teleop_action_processor] + | | + '---> processed_teleop_action V [ robot_action_processor ] --> robot_action_to_send V @@ -303,13 +207,9 @@ def record_loop( ], # runs after robot dataset: LeRobotDataset | None = None, teleop: Teleoperator | list[Teleoperator] | None = None, - policy: PreTrainedPolicy | None = None, - preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None, - postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None, control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, - interpolator: ActionInterpolator | None = None, display_compressed_images: bool = False, ): if dataset is not None and dataset.fps != fps: @@ -340,21 +240,7 @@ def record_loop( "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot." ) - # Reset policy and processor if they are provided - if policy is not None and preprocessor is not None and postprocessor is not None: - policy.reset() - preprocessor.reset() - postprocessor.reset() - - # Reset interpolator if provided - if interpolator is not None: - interpolator.reset() - - # Calculate control interval based on interpolation - use_interpolation = interpolator is not None and interpolator.enabled and policy is not None - control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps - # Pre-compute action key order outside the hot loop — it won't change mid-episode. 
- action_keys = sorted(robot.action_features) if use_interpolation else [] + control_interval = 1 / fps no_action_count = 0 timestamp = 0 @@ -372,63 +258,11 @@ def record_loop( # Applies a pipeline to the raw robot observation, default is IdentityProcessor obs_processed = robot_observation_processor(obs) - if policy is not None or dataset is not None: + if dataset is not None: observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR) - # Track whether this iteration should be recorded to the dataset. - # Interpolated-only iterations send actions to the robot but don't record frames, - # keeping the dataset at the original fps while the robot moves at the higher rate. - is_record_frame = True - - # Get action from either policy or teleop - if policy is not None and preprocessor is not None and postprocessor is not None: - # With interpolation: only call policy when interpolator needs new action - if use_interpolation: - ran_inference = False - - if interpolator.needs_new_action(): - action_values = predict_action( - observation=observation_frame, - policy=policy, - device=get_safe_torch_device(policy.config.device), - preprocessor=preprocessor, - postprocessor=postprocessor, - use_amp=policy.config.use_amp, - task=single_task, - robot_type=robot.robot_type, - ) - act_processed_policy = make_robot_action(action_values, dataset.features) - robot_action_to_send = robot_action_processor((act_processed_policy, obs)) - - action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys]) - interpolator.add(action_tensor) - ran_inference = True - - interp_action = interpolator.get() - if interp_action is not None: - robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)} - action_values = robot_action_to_send - else: - continue - - is_record_frame = ran_inference - else: - action_values = predict_action( - observation=observation_frame, - policy=policy, - 
device=get_safe_torch_device(policy.config.device), - preprocessor=preprocessor, - postprocessor=postprocessor, - use_amp=policy.config.use_amp, - task=single_task, - robot_type=robot.robot_type, - ) - act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features) - # Applies a pipeline to the action, default is IdentityProcessor - robot_action_to_send = robot_action_processor((act_processed_policy, obs)) - action_values = robot_action_to_send - - elif policy is None and isinstance(teleop, Teleoperator): + # Get action from teleop + if isinstance(teleop, Teleoperator): act = teleop.get_action() if robot.name == "unitree_g1": teleop.send_feedback(obs) @@ -438,7 +272,7 @@ def record_loop( action_values = act_processed_teleop robot_action_to_send = robot_action_processor((act_processed_teleop, obs)) - elif policy is None and isinstance(teleop, list): + elif isinstance(teleop, list): arm_action = teleop_arm.get_action() arm_action = {f"arm_{k}": v for k, v in arm_action.items()} keyboard_action = teleop_keyboard.get_action() @@ -451,7 +285,7 @@ def record_loop( no_action_count += 1 if no_action_count == 1 or no_action_count % 10 == 0: logging.warning( - "No policy or teleoperator provided, skipping action generation. " + "No teleoperator provided, skipping action generation. " "This is likely to happen when resetting the environment without a teleop device. " "The robot won't be at its rest position at the start of the next episode." ) @@ -463,8 +297,8 @@ def record_loop( # TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot. 
_sent_action = robot.send_action(robot_action_to_send) - # Write to dataset (only on real policy frames, not interpolated-only iterations) - if dataset is not None and is_record_frame: + # Write to dataset + if dataset is not None: action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION) frame = {**observation_frame, **action_frame, "task": single_task} dataset.add_frame(frame) @@ -488,7 +322,12 @@ def record_loop( @parser.wrap() -def record(cfg: RecordConfig) -> LeRobotDataset: +def record( + cfg: RecordConfig, + teleop_action_processor: RobotProcessorPipeline | None = None, + robot_action_processor: RobotProcessorPipeline | None = None, + robot_observation_processor: RobotProcessorPipeline | None = None, +) -> LeRobotDataset: init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: @@ -502,7 +341,16 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None - teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() + # Fall back to identity pipelines when the caller doesn't supply processors. 
+ if ( + teleop_action_processor is None + or robot_action_processor is None + or robot_observation_processor is None + ): + _t, _r, _o = make_default_processors() + teleop_action_processor = teleop_action_processor or _t + robot_action_processor = robot_action_processor or _r + robot_observation_processor = robot_observation_processor or _o dataset_features = combine_feature_dicts( aggregate_pipeline_dataset_features( @@ -540,8 +388,14 @@ def record(cfg: RecordConfig) -> LeRobotDataset: ) sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features) else: - # Create empty dataset or load existing saved episodes - sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy) + # Reject eval_ prefix — for policy evaluation use lerobot-rollout + repo_name = cfg.dataset.repo_id.split("/", 1)[-1] + if repo_name.startswith("eval_"): + raise ValueError( + "Dataset names starting with 'eval_' are reserved for policy evaluation. " + "lerobot-record is for data collection only. Use lerobot-rollout for policy deployment." 
+ ) + cfg.dataset.stamp_repo_id() dataset = LeRobotDataset.create( cfg.dataset.repo_id, cfg.dataset.fps, @@ -558,30 +412,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset: encoder_threads=cfg.dataset.encoder_threads, ) - # Load pretrained policy - policy = ( - None - if cfg.policy is None - else make_policy(cfg.policy, ds_meta=dataset.meta, rename_map=cfg.dataset.rename_map) - ) - preprocessor = None - postprocessor = None - interpolator = None - if cfg.policy is not None: - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), - preprocessor_overrides={ - "device_processor": {"device": cfg.policy.device}, - "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, - }, - ) - # Create interpolator for smoother policy control - if cfg.interpolation_multiplier > 1: - interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) - logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate") - robot.connect() if teleop is not None: teleop.connect() @@ -605,14 +435,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot_action_processor=robot_action_processor, robot_observation_processor=robot_observation_processor, teleop=teleop, - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, dataset=dataset, control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, display_data=cfg.display_data, - interpolator=interpolator, display_compressed_images=display_compressed_images, ) @@ -660,7 +486,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset: listener.stop() if cfg.dataset.push_to_hub: - dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + if dataset and dataset.num_episodes > 0: + dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + else: + logging.warning("No episodes saved — 
skipping push to hub") log_say("Exiting", cfg.play_sounds) return dataset diff --git a/src/lerobot/scripts/lerobot_rollout.py b/src/lerobot/scripts/lerobot_rollout.py new file mode 100644 index 000000000..6a81563ee --- /dev/null +++ b/src/lerobot/scripts/lerobot_rollout.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Policy deployment engine with pluggable rollout strategies. + +``lerobot-rollout`` is the single CLI for running trained policies on +real robots. 
+ +Strategies +---------- + --strategy.type=base Autonomous rollout, no recording + --strategy.type=sentry Continuous recording with auto-upload + --strategy.type=highlight Ring buffer + keystroke save + --strategy.type=dagger Human-in-the-loop (DAgger / RaC) + +Inference backends +------------------ + --inference.type=sync One policy call per control tick (default) + --inference.type=rtc Real-Time Chunking for slow VLA models + +Usage examples +-------------- +:: + + # Base mode — quick evaluation with sync inference + lerobot-rollout \\ + --strategy.type=base \\ + --policy.path=lerobot/act_koch_real \\ + --robot.type=koch_follower \\ + --robot.port=/dev/ttyACM0 \\ + --task="pick up cube" --duration=30 + + # Base mode — RTC inference for slow VLAs (Pi0, Pi0.5, SmolVLA) + lerobot-rollout \\ + --strategy.type=base \\ + --policy.path=lerobot/pi0_base \\ + --inference.type=rtc \\ + --inference.rtc.execution_horizon=10 \\ + --inference.rtc.max_guidance_weight=10.0 \\ + --robot.type=so100_follower \\ + --robot.port=/dev/ttyACM0 \\ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\ + --task="pick up cube" --duration=60 + + # Sentry mode — continuous recording with periodic upload + lerobot-rollout \\ + --strategy.type=sentry \\ + --strategy.upload_every_n_episodes=5 \\ + --policy.path=lerobot/pi0_base \\ + --inference.type=rtc \\ + --robot.type=so100_follower \\ + --robot.port=/dev/ttyACM0 \\ + --dataset.repo_id=user/rollout_sentry_data \\ + --dataset.single_task="patrol" --duration=3600 + + # Highlight mode — ring buffer, press 's' to save, 'h' to push + lerobot-rollout \\ + --strategy.type=highlight \\ + --strategy.ring_buffer_seconds=30 \\ + --policy.path=lerobot/act_koch_real \\ + --robot.type=koch_follower \\ + --robot.port=/dev/ttyACM0 \\ + --dataset.repo_id=user/rollout_highlight_data \\ + --dataset.single_task="pick up cube" + + # DAgger mode — human-in-the-loop corrections only + lerobot-rollout \\ + 
--strategy.type=dagger \\ + --strategy.num_episodes=20 \\ + --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \\ + --robot.type=bi_openarm_follower \\ + --teleop.type=openarm_mini \\ + --dataset.repo_id=user/rollout_hil_data \\ + --dataset.single_task="Fold the T-shirt" + + # DAgger mode — continuous recording with RTC inference + lerobot-rollout \\ + --strategy.type=dagger \\ + --strategy.record_autonomous=true \\ + --strategy.num_episodes=50 \\ + --inference.type=rtc \\ + --inference.rtc.execution_horizon=10 \\ + --policy.path=user/my_pi0_policy \\ + --robot.type=so100_follower \\ + --robot.port=/dev/ttyACM0 \\ + --teleop.type=so101_leader \\ + --teleop.port=/dev/ttyACM1 \\ + --dataset.repo_id=user/rollout_dagger_rtc_data \\ + --dataset.single_task="Grasp the block" + + # With Rerun visualization and torch.compile + lerobot-rollout \\ + --strategy.type=base \\ + --policy.path=lerobot/act_koch_real \\ + --robot.type=koch_follower \\ + --robot.port=/dev/ttyACM0 \\ + --task="pick up cube" --duration=60 \\ + --display_data=true \\ + --use_torch_compile=true + + # Resume a previous sentry recording session + lerobot-rollout \\ + --strategy.type=sentry \\ + --policy.path=user/my_policy \\ + --robot.type=so100_follower \\ + --robot.port=/dev/ttyACM0 \\ + --dataset.repo_id=user/rollout_sentry_data \\ + --dataset.single_task="patrol" \\ + --resume=true +""" + +import logging + +from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401 +from lerobot.configs import parser +from lerobot.robots import ( # noqa: F401 + Robot, + RobotConfig, + bi_openarm_follower, + bi_so_follower, + earthrover_mini_plus, + hope_jr, + koch_follower, + omx_follower, + openarm_follower, + reachy2, + so_follower, + unitree_g1 as unitree_g1_robot, +) +from lerobot.rollout import RolloutConfig, build_rollout_context, create_strategy +from 
lerobot.teleoperators import ( # noqa: F401 + Teleoperator, + TeleoperatorConfig, + bi_openarm_leader, + bi_so_leader, + homunculus, + koch_leader, + omx_leader, + openarm_leader, + openarm_mini, + reachy2_teleoperator, + so_leader, + unitree_g1, +) +from lerobot.utils.import_utils import register_third_party_plugins +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.utils import init_logging +from lerobot.utils.visualization_utils import init_rerun + +logger = logging.getLogger(__name__) + + +@parser.wrap() +def rollout(cfg: RolloutConfig): + """Main entry point for policy deployment.""" + init_logging() + + if cfg.display_data: + logger.info("Initializing Rerun visualization (ip=%s, port=%s)", cfg.display_ip, cfg.display_port) + init_rerun(session_name="rollout", ip=cfg.display_ip, port=cfg.display_port) + + signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False) + shutdown_event = signal_handler.shutdown_event + + logger.info("Building rollout context...") + ctx = build_rollout_context(cfg, shutdown_event) + + strategy = create_strategy(cfg.strategy) + logger.info("Rollout strategy: %s", cfg.strategy.type) + logger.info( + "Robot: %s | FPS: %.0f | Duration: %s", + cfg.robot.type if cfg.robot else "?", + cfg.fps, + f"{cfg.duration}s" if cfg.duration > 0 else "infinite", + ) + + try: + strategy.setup(ctx) + logger.info("Rollout setup complete, starting rollout...") + strategy.run(ctx) + except KeyboardInterrupt: + logger.info("Interrupted by user") + finally: + strategy.teardown(ctx) + + logger.info("Rollout finished") + + +def main(): + """CLI entry point for ``lerobot-rollout``.""" + register_third_party_plugins() + rollout() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/utils/action_interpolator.py b/src/lerobot/utils/action_interpolator.py new file mode 100644 index 000000000..222dc33b5 --- /dev/null +++ b/src/lerobot/utils/action_interpolator.py @@ -0,0 +1,116 @@ +# Copyright 2025 The HuggingFace 
Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Action interpolation for smoother robot control. + +Provides configurable Nx control rate by interpolating between consecutive actions. +Useful with RTC and action-chunking policies to reduce jerkiness. +""" + +from torch import Tensor + + +class ActionInterpolator: + """Interpolates between consecutive actions for smoother control. + + When enabled with multiplier N, produces N actions per policy action + by linearly interpolating between the previous and current action. + + Example with multiplier=3: + prev_action -> [1/3 interpolated, 2/3 interpolated, current_action] + + This effectively multiplies the control rate for smoother motion. + + Usage: + interpolator = ActionInterpolator(multiplier=2) # 2x control rate + + # In control loop: + if interpolator.needs_new_action(): + new_action = queue.get() + if new_action is not None: + interpolator.add(new_action.cpu()) + + action = interpolator.get() + if action is not None: + robot.send_action(action) + """ + + def __init__(self, multiplier: int = 1): + """Initialize the interpolator. + + Args: + multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
+ """ + if multiplier < 1: + raise ValueError(f"multiplier must be >= 1, got {multiplier}") + self.multiplier = multiplier + self._prev: Tensor | None = None + self._buffer: list[Tensor] = [] + self._idx = 0 + + @property + def enabled(self) -> bool: + """Whether interpolation is active (multiplier > 1).""" + return self.multiplier > 1 + + def reset(self): + """Reset interpolation state (call between episodes).""" + self._prev = None + self._buffer = [] + self._idx = 0 + + def needs_new_action(self) -> bool: + """Check if a new action is needed from the queue.""" + return self._idx >= len(self._buffer) + + def add(self, action: Tensor) -> None: + """Add a new action and compute interpolated sequence. + + Args: + action: New action tensor from policy/queue (already on CPU). + """ + if self.multiplier > 1 and self._prev is not None: + self._buffer = [] + for i in range(1, self.multiplier + 1): + t = i / self.multiplier + interp = self._prev + t * (action - self._prev) + self._buffer.append(interp) + else: + # First step: no previous action yet, so run at base FPS without interpolation. + self._buffer = [action.clone()] + self._prev = action.clone() + self._idx = 0 + + def get(self) -> Tensor | None: + """Get the next interpolated action. + + Returns: + Next action tensor, or None if buffer is exhausted. + """ + if self._idx >= len(self._buffer): + return None + action = self._buffer[self._idx] + self._idx += 1 + return action + + def get_control_interval(self, fps: float) -> float: + """Get the control interval based on interpolation multiplier. + + Args: + fps: Base frames per second. + + Returns: + Control interval in seconds (divided by multiplier). + """ + return 1.0 / (fps * self.multiplier) diff --git a/src/lerobot/utils/pedal.py b/src/lerobot/utils/pedal.py new file mode 100644 index 000000000..88f3db1f9 --- /dev/null +++ b/src/lerobot/utils/pedal.py @@ -0,0 +1,83 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic foot pedal listener using evdev. + +Callers supply a callback receiving the pressed key code (e.g. ``"KEY_A"``) +and an optional device path. The listener runs in a daemon thread and +silently no-ops when :mod:`evdev` is not installed or the device is +unavailable. Strategy-specific key mapping logic lives in the caller. +""" + +from __future__ import annotations + +import logging +import threading +from collections.abc import Callable + +logger = logging.getLogger(__name__) + +DEFAULT_PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + + +def start_pedal_listener( + on_press: Callable[[str], None], + device_path: str = DEFAULT_PEDAL_DEVICE, +) -> threading.Thread | None: + """Spawn a daemon thread that forwards pedal key-press codes to ``on_press``. + + Parameters + ---------- + on_press: + Callback invoked with the pressed key code string (e.g. ``"KEY_A"``) + on each pedal press event. The callback runs in the listener thread + and must be thread-safe. + device_path: + Linux input device path (e.g. ``/dev/input/by-id/...``). + + Returns + ------- + The started daemon :class:`threading.Thread`, or ``None`` when + :mod:`evdev` is not installed (optional dependency; silent no-op). 
+ """ + try: + from evdev import InputDevice, categorize, ecodes + except ImportError: + return None + + def pedal_reader() -> None: + try: + dev = InputDevice(device_path) + logger.info("Pedal connected: %s", dev.name) + for ev in dev.read_loop(): + if ev.type != ecodes.EV_KEY: + continue + key = categorize(ev) + code = key.keycode + if isinstance(code, (list, tuple)): + code = code[0] + if key.keystate != 1: # only key-down events + continue + try: + on_press(code) + except Exception as cb_err: # pragma: no cover - defensive + logger.warning("Pedal callback error: %s", cb_err) + except (FileNotFoundError, PermissionError): + pass + except Exception as e: + logger.warning("Pedal error: %s", e) + + thread = threading.Thread(target=pedal_reader, daemon=True, name="PedalListener") + thread.start() + return thread diff --git a/src/lerobot/rl/process.py b/src/lerobot/utils/process.py similarity index 100% rename from src/lerobot/rl/process.py rename to src/lerobot/utils/process.py diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py index 49efa84d9..26406dea2 100644 --- a/tests/datasets/test_lerobot_dataset.py +++ b/tests/datasets/test_lerobot_dataset.py @@ -416,6 +416,18 @@ def test_create_initial_counts_zero(tmp_path): assert dataset.num_frames == 0 +def test_create_propagates_video_files_size_in_mb(tmp_path): + """video_files_size_in_mb passed to create() is reflected in the dataset metadata.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, + fps=DEFAULT_FPS, + features=SIMPLE_FEATURES, + root=tmp_path / "ds", + video_files_size_in_mb=42.0, + ) + assert dataset.meta.video_files_size_in_mb == 42.0 + + def test_add_frame_works_in_write_mode(tmp_path): """add_frame() succeeds on a dataset created via create().""" dataset = LeRobotDataset.create( diff --git a/tests/policies/rtc/test_action_interpolator.py b/tests/policies/rtc/test_action_interpolator.py index 9a4276df1..3eb239d7e 100644 --- 
a/tests/policies/rtc/test_action_interpolator.py +++ b/tests/policies/rtc/test_action_interpolator.py @@ -17,9 +17,9 @@ import pytest import torch -from lerobot.policies.rtc.action_interpolator import ActionInterpolator from lerobot.policies.rtc.action_queue import ActionQueue from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.utils.action_interpolator import ActionInterpolator # ====================== Fixtures ====================== diff --git a/tests/policies/rtc/test_configuration_rtc.py b/tests/policies/rtc/test_configuration_rtc.py index bb4550eaa..40d171c0c 100644 --- a/tests/policies/rtc/test_configuration_rtc.py +++ b/tests/policies/rtc/test_configuration_rtc.py @@ -26,7 +26,7 @@ def test_rtc_config_default_initialization(): """Test RTCConfig initializes with default values.""" config = RTCConfig() - assert config.enabled is False + assert config.enabled is True assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR assert config.max_guidance_weight == 10.0 assert config.execution_horizon == 10 diff --git a/tests/policies/rtc/test_rtc_relative_actions.py b/tests/policies/rtc/test_rtc_relative_actions.py index fa888ec05..66667ea56 100644 --- a/tests/policies/rtc/test_rtc_relative_actions.py +++ b/tests/policies/rtc/test_rtc_relative_actions.py @@ -22,7 +22,7 @@ from lerobot.configs.types import ( PolicyFeature, RTCAttentionSchedule, ) -from lerobot.processor import TransitionKey, batch_to_transition +from lerobot.processor import TransitionKey, batch_to_transition, create_transition from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep from lerobot.processor.relative_action_processor import ( AbsoluteActionsProcessorStep, @@ -52,6 +52,9 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug _rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py") RTCProcessor = _rtc_mod.RTCProcessor +_rtc_relative_mod = 
_import_rtc_module("lerobot.policies.rtc.relative", "relative.py") +reanchor_relative_rtc_prefix = _rtc_relative_mod.reanchor_relative_rtc_prefix + ACTION_DIM = 6 CHUNK_SIZE = 50 EXECUTION_HORIZON = 10 @@ -187,7 +190,7 @@ class TestRTCDenoiseWithRelativeLeftovers: class TestFullPipelineRelativeRTC: - """End-to-end test of the RTC + relative actions pipeline matching eval_with_real_robot.py flow.""" + """End-to-end test of the RTC + relative actions pipeline matching lerobot-rollout flow.""" def test_preprocessor_caches_state_for_postprocessor(self): """Preprocessor's relative step should cache state so postprocessor can convert back.""" @@ -218,7 +221,9 @@ class TestFullPipelineRelativeRTC: def test_roundtrip_with_identity_normalization(self): """Actions → relative → normalize → [model] → unnormalize → absolute should recover originals. - Using mean=0, std=1 normalization (identity).""" + + Using mean=0, std=1 normalization (identity). + """ relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline() state = torch.randn(1, ACTION_DIM) @@ -240,7 +245,7 @@ class TestFullPipelineRelativeRTC: torch.testing.assert_close(recovered, actions, atol=1e-5, rtol=1e-5) def test_eval_loop_simulation(self): - """Simulate the eval_with_real_robot.py loop with relative actions. + """Simulate the lerobot-rollout loop with relative actions. 
Iteration 1: No leftovers → model generates relative actions → store for RTC Iteration 2: Use leftovers as RTC guidance → model generates new relative actions @@ -400,13 +405,113 @@ class TestStateRebasingApproximation: assert error_excluded < 1e-6, f"Excluded joint should have zero error, got {error_excluded}" +class TestRTCReanchoringWithStateNormalizer: + """RTC re-anchoring under non-identity OBS_STATE normalization.""" + + @staticmethod + def _build_normalizer_with_state_stats(): + """Build a relative-action preprocessor with non-trivial OBS_STATE stats.""" + features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(ACTION_DIM,)), + } + norm_map = { + FeatureType.ACTION: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MEAN_STD, + } + stats = { + ACTION: { + "mean": torch.zeros(ACTION_DIM).numpy(), + "std": (0.5 * torch.ones(ACTION_DIM)).numpy(), + }, + OBS_STATE: { + "mean": (5.0 * torch.ones(ACTION_DIM)).numpy(), + "std": (2.0 * torch.ones(ACTION_DIM)).numpy(), + }, + } + relative_step = RelativeActionsProcessorStep(enabled=True) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + return relative_step, normalizer + + def test_reanchor_with_raw_state_matches_normalize_of_absolute_minus_state(self): + """Reanchoring with the raw cached state yields ``normalize(prev_actions_absolute - raw_state)``.""" + relative_step, normalizer = self._build_normalizer_with_state_stats() + + raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]) + relative_step(batch_to_transition({OBS_STATE: raw_state.clone()})) + + prev_actions_absolute = torch.tensor([[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]] * 5) + + result = reanchor_relative_rtc_prefix( + prev_actions_absolute=prev_actions_absolute, + current_state=relative_step.get_cached_state(), + relative_step=relative_step, + normalizer_step=normalizer, + policy_device="cpu", + ) + + 
expected_relative = to_relative_actions(prev_actions_absolute, raw_state, [True] * ACTION_DIM) + expected = normalizer(create_transition(action=expected_relative))[TransitionKey.ACTION] + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + def test_reanchor_with_normalized_state_produces_wrong_result(self): + """Reanchoring with raw vs. normalized state produces meaningfully different outputs.""" + relative_step, normalizer = self._build_normalizer_with_state_stats() + + raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]) + relative_step(batch_to_transition({OBS_STATE: raw_state.clone()})) + + normalized_obs = normalizer(batch_to_transition({OBS_STATE: raw_state.clone()})) + normalized_state = normalized_obs[TransitionKey.OBSERVATION][OBS_STATE] + assert not torch.allclose(normalized_state, raw_state) + + prev_actions_absolute = torch.tensor([[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]] * 5) + + result_raw = reanchor_relative_rtc_prefix( + prev_actions_absolute=prev_actions_absolute, + current_state=raw_state, + relative_step=relative_step, + normalizer_step=normalizer, + policy_device="cpu", + ) + result_normalized = reanchor_relative_rtc_prefix( + prev_actions_absolute=prev_actions_absolute, + current_state=normalized_state, + relative_step=relative_step, + normalizer_step=normalizer, + policy_device="cpu", + ) + + max_abs_diff = (result_raw - result_normalized).abs().max() + assert max_abs_diff > 0.5, ( + f"Raw and normalized state produced near-identical outputs (max diff {max_abs_diff:.4f}); " + "OBS_STATE stats are too close to identity to be sensitive." 
+ ) + + def test_engine_pipeline_cached_state_is_raw_after_full_preprocess(self): + """``get_cached_state()`` returns raw OBS_STATE after the full preprocessor pipeline runs.""" + relative_step, normalizer = self._build_normalizer_with_state_stats() + + raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]) + + transition = batch_to_transition({OBS_STATE: raw_state.clone()}) + transition = relative_step(transition) + preprocessed = normalizer(transition) + + cached = relative_step.get_cached_state() + torch.testing.assert_close(cached, raw_state, atol=1e-6, rtol=1e-6) + + post_normalize_state = preprocessed[TransitionKey.OBSERVATION][OBS_STATE] + assert not torch.allclose(cached, post_normalize_state, atol=1e-3) + + def _detect_relative_actions(preprocessor) -> bool: - """Mirror of the helper in eval_with_real_robot.py for testing without importing it.""" + """Mirror of the helper in lerobot-rollout for testing without importing it.""" return any(isinstance(step, RelativeActionsProcessorStep) and step.enabled for step in preprocessor.steps) class TestDetectRelativeActions: - """Test the _detect_relative_actions helper logic used by eval_with_real_robot.py.""" + """Test the _detect_relative_actions helper logic used by lerobot-rollout.""" def test_detects_enabled_relative_step(self): class FakePipeline: diff --git a/tests/test_cli_peft.py b/tests/test_cli_peft.py index 5d653ee6b..82f41affa 100644 --- a/tests/test_cli_peft.py +++ b/tests/test_cli_peft.py @@ -24,10 +24,6 @@ def lerobot_train(args): return run_command(cmd="lerobot-train", module="lerobot_train", args=args) -def lerobot_record(args): - return run_command(cmd="lerobot-record", module="lerobot_record", args=args) - - def resolve_model_id_for_peft_training(policy_type): """PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training.""" if policy_type == "smolvla": @@ -155,81 +151,3 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path): 
f"--output_dir={output_dir}", ] ) - - -class DummyRobot: - name = "dummy" - cameras = [] - action_features = {"foo": 1.0, "bar": 2.0} - observation_features = {"obs1": 1.0, "obs2": 2.0} - is_connected = True - - def connect(self, *args): - pass - - def disconnect(self): - pass - - -def dummy_make_robot_from_config(*args, **kwargs): - return DummyRobot() - - -@pytest.mark.parametrize("policy_type", ["smolvla"]) -@skip_if_package_missing("peft") -def test_peft_record_loads_policy(policy_type, tmp_path): - """Train a policy with PEFT and attempt to load it with `lerobot-record`.""" - from peft import PeftModel - - output_dir = tmp_path / f"output_{policy_type}" - model_id = resolve_model_id_for_peft_training(policy_type) - - lerobot_train( - [ - f"--policy.path={model_id}", - "--policy.push_to_hub=false", - "--policy.input_features=null", - "--policy.output_features=null", - "--peft.method=LORA", - "--dataset.repo_id=lerobot/pusht", - "--dataset.episodes=[0, 1]", - "--steps=1", - f"--output_dir={output_dir}", - ] - ) - - policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model" - dataset_dir = tmp_path / "eval_pusht" - single_task = "move the table" - loaded_policy = None - - def dummy_record_loop(*args, **kwargs): - nonlocal loaded_policy - - if "dataset" not in kwargs: - return - - dataset = kwargs["dataset"] - dataset.add_frame({"task": single_task}) - loaded_policy = kwargs["policy"] - - with ( - patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config), - # disable record loop since we're only interested in successful loading of the policy. 
- patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop), - # disable speech output - patch("lerobot.utils.utils.say"), - ): - lerobot_record( - [ - f"--policy.path={policy_dir}", - "--robot.type=so101_follower", - "--robot.port=/dev/null", - "--dataset.repo_id=lerobot/eval_pusht", - f'--dataset.single_task="{single_task}"', - f"--dataset.root={dataset_dir}", - "--dataset.push_to_hub=false", - ] - ) - - assert isinstance(loaded_policy, PeftModel) diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 28e91a149..dd10c0c1c 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -21,8 +21,9 @@ import pytest pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])") +from lerobot.configs.dataset import DatasetRecordConfig from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate -from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record +from lerobot.scripts.lerobot_record import RecordConfig, record from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate from tests.fixtures.constants import DUMMY_REPO_ID diff --git a/tests/test_rollout.py b/tests/test_rollout.py new file mode 100644 index 000000000..5a1ec4703 --- /dev/null +++ b/tests/test_rollout.py @@ -0,0 +1,345 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal tests for the rollout module's public API.""" + +from __future__ import annotations + +import dataclasses +from unittest.mock import MagicMock + +import pytest +import torch + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +# --------------------------------------------------------------------------- +# Import smoke tests +# --------------------------------------------------------------------------- + + +def test_rollout_top_level_imports(): + import lerobot.rollout + + for name in lerobot.rollout.__all__: + assert hasattr(lerobot.rollout, name), f"Missing export: {name}" + + +def test_inference_submodule_imports(): + import lerobot.rollout.inference + + for name in lerobot.rollout.inference.__all__: + assert hasattr(lerobot.rollout.inference, name), f"Missing export: {name}" + + +def test_strategies_submodule_imports(): + import lerobot.rollout.strategies + + for name in lerobot.rollout.strategies.__all__: + assert hasattr(lerobot.rollout.strategies, name), f"Missing export: {name}" + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + + +def test_strategy_config_types(): + from lerobot.rollout import ( + BaseStrategyConfig, + DAggerStrategyConfig, + HighlightStrategyConfig, + SentryStrategyConfig, + ) + + assert BaseStrategyConfig().type == "base" + assert SentryStrategyConfig().type == "sentry" + assert HighlightStrategyConfig().type == "highlight" + assert 
DAggerStrategyConfig().type == "dagger" + + +def test_dagger_config_invalid_input_device(): + from lerobot.rollout import DAggerStrategyConfig + + with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"): + DAggerStrategyConfig(input_device="joystick") + + +def test_dagger_config_defaults(): + from lerobot.rollout import DAggerStrategyConfig + + cfg = DAggerStrategyConfig() + assert cfg.num_episodes is None + assert cfg.record_autonomous is False + assert cfg.input_device == "keyboard" + + +def test_inference_config_types(): + from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig + + assert SyncInferenceConfig().type == "sync" + + rtc = RTCInferenceConfig() + assert rtc.type == "rtc" + assert rtc.queue_threshold == 30 + assert rtc.rtc is not None + + +def test_sentry_config_defaults(): + from lerobot.rollout import SentryStrategyConfig + + cfg = SentryStrategyConfig() + assert cfg.upload_every_n_episodes == 5 + assert cfg.target_video_file_size_mb is None + + +# --------------------------------------------------------------------------- +# RolloutRingBuffer +# --------------------------------------------------------------------------- + + +def test_ring_buffer_append_and_eviction(): + from lerobot.rollout.ring_buffer import RolloutRingBuffer + + buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0) + # max_frames = 5 + for i in range(8): + buf.append({"val": i}) + assert len(buf) == 5 + + +def test_ring_buffer_drain(): + from lerobot.rollout.ring_buffer import RolloutRingBuffer + + buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) + for i in range(3): + buf.append({"val": i}) + frames = buf.drain() + assert len(frames) == 3 + assert len(buf) == 0 + assert buf.estimated_bytes == 0 + + +def test_ring_buffer_clear(): + from lerobot.rollout.ring_buffer import RolloutRingBuffer + + buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) + buf.append({"val": 1}) + buf.clear() 
+ assert len(buf) == 0 + assert buf.estimated_bytes == 0 + + +def test_ring_buffer_tensor_bytes(): + from lerobot.rollout.ring_buffer import RolloutRingBuffer + + buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) + t = torch.zeros(100, dtype=torch.float32) # 400 bytes + buf.append({"tensor": t}) + assert buf.estimated_bytes >= 400 + + +# --------------------------------------------------------------------------- +# ThreadSafeRobot +# --------------------------------------------------------------------------- + + +def test_thread_safe_robot_delegates(): + from lerobot.rollout.robot_wrapper import ThreadSafeRobot + from tests.mocks.mock_robot import MockRobot, MockRobotConfig + + robot = MockRobot(MockRobotConfig(n_motors=3)) + robot.connect() + wrapper = ThreadSafeRobot(robot) + + obs = wrapper.get_observation() + assert "motor_1.pos" in obs + assert "motor_2.pos" in obs + assert "motor_3.pos" in obs + + action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0} + result = wrapper.send_action(action) + assert result == action + + robot.disconnect() + + +def test_thread_safe_robot_properties(): + from lerobot.rollout.robot_wrapper import ThreadSafeRobot + from tests.mocks.mock_robot import MockRobot, MockRobotConfig + + robot = MockRobot(MockRobotConfig(n_motors=3)) + robot.connect() + wrapper = ThreadSafeRobot(robot) + + assert wrapper.name == "mock_robot" + assert "motor_1.pos" in wrapper.observation_features + assert "motor_1.pos" in wrapper.action_features + assert wrapper.is_connected is True + assert wrapper.inner is robot + + robot.disconnect() + + +# --------------------------------------------------------------------------- +# Strategy factory +# --------------------------------------------------------------------------- + + +def test_create_strategy_dispatches(): + from lerobot.rollout import ( + BaseStrategy, + BaseStrategyConfig, + DAggerStrategy, + DAggerStrategyConfig, + SentryStrategy, + SentryStrategyConfig, + 
create_strategy, + ) + + assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy) + assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy) + assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy) + + +def test_create_strategy_unknown_raises(): + from lerobot.rollout import create_strategy + + cfg = MagicMock() + cfg.type = "bogus" + with pytest.raises(ValueError, match="Unknown strategy type"): + create_strategy(cfg) + + +# --------------------------------------------------------------------------- +# Inference factory +# --------------------------------------------------------------------------- + + +def test_create_inference_engine_sync(): + from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine + + engine = create_inference_engine( + SyncInferenceConfig(), + policy=MagicMock(), + preprocessor=MagicMock(), + postprocessor=MagicMock(), + robot_wrapper=MagicMock(robot_type="mock"), + hw_features={}, + dataset_features={}, + ordered_action_keys=["k"], + task="test", + fps=30.0, + device="cpu", + ) + assert isinstance(engine, SyncInferenceEngine) + + +# --------------------------------------------------------------------------- +# Pure functions +# --------------------------------------------------------------------------- + + +def test_estimate_max_episode_seconds_no_video(): + from lerobot.rollout.strategies import estimate_max_episode_seconds + + assert estimate_max_episode_seconds({}, fps=30.0) == 300.0 + + +def test_estimate_max_episode_seconds_with_video(): + from lerobot.rollout.strategies import estimate_max_episode_seconds + + features = {"cam": {"dtype": "video", "shape": (480, 640, 3)}} + result = estimate_max_episode_seconds(features, fps=30.0) + assert result > 0 + # With a real camera, duration should differ from the fallback + assert result != 300.0 + + +def test_safe_push_to_hub(): + from lerobot.rollout.strategies import safe_push_to_hub + + ds = MagicMock() 
+ ds.num_episodes = 0 + assert safe_push_to_hub(ds) is False + ds.push_to_hub.assert_not_called() + + ds.num_episodes = 5 + assert safe_push_to_hub(ds, tags=["test"]) is True + ds.push_to_hub.assert_called_once_with(tags=["test"], private=False) + + +# --------------------------------------------------------------------------- +# DAgger state machine +# --------------------------------------------------------------------------- + + +def test_dagger_full_transition_cycle(): + from lerobot.rollout.strategies import DAggerEvents, DAggerPhase + + events = DAggerEvents() + assert events.phase == DAggerPhase.AUTONOMOUS + + # AUTONOMOUS -> PAUSED + events.request_transition("pause_resume") + old, new = events.consume_transition() + assert (old, new) == (DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED) + + # PAUSED -> CORRECTING + events.request_transition("correction") + old, new = events.consume_transition() + assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.CORRECTING) + + # CORRECTING -> PAUSED + events.request_transition("correction") + old, new = events.consume_transition() + assert (old, new) == (DAggerPhase.CORRECTING, DAggerPhase.PAUSED) + + # PAUSED -> AUTONOMOUS + events.request_transition("pause_resume") + old, new = events.consume_transition() + assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.AUTONOMOUS) + + +def test_dagger_invalid_transition_ignored(): + from lerobot.rollout.strategies import DAggerEvents, DAggerPhase + + events = DAggerEvents() + events.request_transition("correction") # Not valid from AUTONOMOUS + assert events.consume_transition() is None + assert events.phase == DAggerPhase.AUTONOMOUS + + +def test_dagger_events_reset(): + from lerobot.rollout.strategies import DAggerEvents, DAggerPhase + + events = DAggerEvents() + events.request_transition("pause_resume") + events.consume_transition() # -> PAUSED + events.upload_requested.set() + events.reset() + assert events.phase == DAggerPhase.AUTONOMOUS + assert not 
events.upload_requested.is_set() + + +# --------------------------------------------------------------------------- +# Context dataclass +# --------------------------------------------------------------------------- + + +def test_rollout_context_fields(): + from lerobot.rollout import RolloutContext + + field_names = {f.name for f in dataclasses.fields(RolloutContext)} + assert field_names == {"runtime", "hardware", "policy", "processors", "data"} diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py index ce56db173..65b24aac4 100644 --- a/tests/utils/test_process.py +++ b/tests/utils/test_process.py @@ -24,7 +24,7 @@ import pytest pytest.importorskip("grpc") -from lerobot.rl.process import ProcessSignalHandler # noqa: E402 +from lerobot.utils.process import ProcessSignalHandler # noqa: E402 # Fixture to reset shutdown_event_counter and original signal handlers before and after each test