Compare commits

...

40 Commits

Author SHA1 Message Date
fracapuano 1c8f922379 fix: minor things on the aggregation job 2025-11-21 09:30:39 +00:00
fracapuano 2b2ff19366 fix the number of workers to prevent contention 2025-11-21 09:30:39 +00:00
fracapuano c912b1dd03 fix: upload with multiple workers 2025-11-21 09:30:38 +00:00
fracapuano ca1841f5fc add: aggregation util 2025-11-21 09:30:38 +00:00
fracapuano f6755dbf20 add: utils for stabler, large scale upload (ds.push_to_hub may fail) 2025-11-21 09:30:38 +00:00
fracapuano 0846b5704c fix: resources trim 2025-11-21 09:30:38 +00:00
fracapuano f386591be7 fix: jobs for conversion and aggregation 2025-11-21 09:30:38 +00:00
fracapuano f875566e1d add: downloading data utils 2025-11-21 09:30:37 +00:00
fracapuano eaea3806e8 add: util to download behavior data 2025-11-21 09:30:37 +00:00
fracapuano 1ef0f0bb86 remove: unused constants file 2025-11-21 09:30:37 +00:00
fracapuano e70dd620f3 add: final aggregation utils to obtain one dataset only 2025-11-21 09:30:37 +00:00
fracapuano 31274975f0 fix: minor checks 2025-11-21 09:30:37 +00:00
fracapuano edbfa3d3e6 fix: slurm job for parallel conversion on nodes 2025-11-21 09:30:36 +00:00
fracapuano 09e2a55901 fix: add upload to hub option 2025-11-21 09:30:36 +00:00
fracapuano 413c5e01be fix: implement actual conversion for lerobotdataset-v3 compatibility 2025-11-21 09:30:36 +00:00
fracapuano 91a0a4fe7a add: slurm conversion script 2025-11-21 09:30:36 +00:00
fracapuano 7710411d3a remove: unused, useless bespoke dataset format 2025-11-21 09:30:36 +00:00
fracapuano 4a153825ee fix: minor 2025-11-21 09:30:36 +00:00
fracapuano 46606359fc fix: metadata stores the saved 0-based episode index 2025-11-21 09:30:35 +00:00
fracapuano 1d0eb922bd fix: episode index is asserted 0-based in lerobot dataset 2025-11-21 09:30:35 +00:00
fracapuano 1612aa7ac7 fix bug: correctly specify paths 2025-11-21 09:30:35 +00:00
Francesco Capuano c1f5d8f48f fix: add frame idx 2025-11-21 09:30:35 +00:00
Michel Aractingi 14743b896e * refactor behaviour1k_lerobot_dataset.py
* add example scripts to load behaviour 1k data in `load_behaviour1k_dataset.py`
2025-11-21 09:30:35 +00:00
Jade Choghari 624939c71c remove tester 2025-11-21 09:30:34 +00:00
Jade Choghari a276f5b8ac fix style 2025-11-21 09:30:34 +00:00
Jade Choghari 33ff386dbc remove comments 2025-11-21 09:30:34 +00:00
Jade Choghari 50f8cbc392 update changes 2025-11-21 09:30:34 +00:00
Jade Choghari 23999ba40d update
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-11-21 09:30:34 +00:00
Jade Choghari dd4837f06e add
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-11-21 09:30:34 +00:00
Michel Aractingi 9f00d2c3a2 Modify convert_to_lerobot_v3 script for behaviours dataset to take a single task id and create a dataset out of it 2025-11-21 09:30:33 +00:00
Michel Aractingi 950a6fb83d add scripts for convert behavior-1k to datasetv3 2025-11-21 09:30:33 +00:00
Michel Aractingi 0f551df8f4 add absolute_to_reative_idx for remapping indices when a subset of data is loaded (#2490) 2025-11-20 14:05:31 +01:00
Jade Choghari 6e86a69dcd feat(envs): add envs pre-post processor (#2474)
* more changes

* working changes

* more changes

* more fixes

* fix style

* more

* clean

* put axis-1

* more fixes

* more styling fixes:

* iterate on review:

* more changes

* add env processor

* style

* more changes

* add docs

* fix imports

* fix test, add to train

* Update src/lerobot/envs/factory.py

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* iterate on review

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: jade.choghari@huggingface.co <“chogharijade@gmail.com”>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-19 18:36:14 +01:00
Eugene Mironov 8a915c6b6f [RTC] Real Time Chunking for Pi0, Smolvla, Pi0.5 (#1698)
* Add Real-Time Chunking (RTC) support for flow matching models

Implement Real-Time Chunking (RTC) for action chunking policies using flow
matching denoising. RTC enables smooth action transitions between consecutive
chunks by using prefix guidance during denoising.

Key features:
- RTCProcessor class with denoise_step method for RTC guidance
- Tracker system for debug tracking using time-based dictionary storage
- RTCDebugVisualizer with comprehensive visualization utilities
- Integration with SmolVLA policy for flow matching models
- Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP)
- Configurable execution horizon and max guidance weight
- Example scripts for dataset evaluation and real-time control

Technical details:
- Uses autograd-based gradient computation for RTC corrections
- Time-based tracking eliminates duplicate step issues
- Proxy methods in RTCProcessor for cleaner API
- Full integration with LeRobot's policy and dataset systems

Files added/modified:
- src/lerobot/configs/types.py: Add RTCAttentionSchedule enum
- src/lerobot/policies/rtc/: Core RTC implementation
  - configuration_rtc.py: RTC configuration
  - modeling_rtc.py: RTCProcessor with denoise_step
  - debug_handler.py: Tracker for debug information
  - debug_visualizer.py: Visualization utilities
- src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration
- examples/rtc/: Example scripts and evaluation tools

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Fix rtc_config attribute access in SmolVLA

Use getattr() to safely check for rtc_config attribute existence
instead of direct attribute access. This fixes AttributeError when
loading policies without rtc_config in their config.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix rtc_config attribute access in SmolVLA

* Add RTCConfig field to SmolVLAConfig

Add rtc_config as an optional field in SmolVLAConfig to properly
support Real-Time Chunking configuration. This replaces the previous
getattr() workarounds with direct attribute access, making the code
cleaner and more maintainable.

Changes:
- Import RTCConfig in configuration_smolvla.py
- Add rtc_config: RTCConfig | None = None field
- Revert getattr() calls to direct attribute access in modeling_smolvla.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Refactor RTC enabled checks to use _rtc_enabled helper

Add _rtc_enabled() helper method in VLAFlowMatching class to simplify
and clean up RTC enabled checks throughout the code. This reduces
code duplication and improves readability.

Changes:
- Add _rtc_enabled() method in VLAFlowMatching
- Replace verbose rtc_config checks with _rtc_enabled() calls
- Maintain exact same functionality with cleaner code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Rename track_debug method to track

Simplify the method name from track_debug to just track for better
readability and consistency. The method already has clear documentation
about its debug tracking purpose.

Changes:
- Rename RTCProcessor.track_debug() to track()
- Update all call sites in modeling_smolvla.py and modeling_rtc.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Use output_dir for saving all evaluation images

Update eval_dataset.py to save all comparison images to the
configured output_dir instead of the current directory. This provides
better organization and allows users to specify where outputs should be
saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Use output_dir for saving all evaluation images

* Fix logging buffering and enable tracking when RTC config provided

- Add force=True to logging.basicConfig to override existing configuration
- Enable line buffering for stdout/stderr for real-time log output
- Modify init_rtc_processor to create processor when rtc_config exists
  even if RTC is disabled, allowing tracking of denoising data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Refactor SmolVLA plotting to use tracker data instead of local variables

Remove local tracking variables (correction, x1_t, error) from the
denoising loop and instead retrieve plotting data from the RTC tracker
after each denoise step. This makes the code cleaner and uses the
tracker as the single source of truth for debug/visualization data.

Changes:
- Remove initialization of correction, x1_t, error before denoising loop
- After each Euler step, retrieve most recent debug step from tracker
- Extract correction, x1_t, err from debug step for plotting
- Update tracking condition to use is_debug_enabled() method

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Move plotting logic from modeling_smolvla to eval_dataset script

Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Refactor plotting logging

* fixup! Refactor plotting logging

* Improve visualization: separate correction plot and fix axis scaling

Changes:
- Create separate figure for correction data instead of overlaying on v_t
- Add _rescale_axes helper method to properly scale all axes
- Add 10% margin to y-axis for better visualization
- Fix v_t chart vertical compression issue

Benefits:
- Clearer v_t plot without correction overlay
- Better axis scaling with proper margins
- Separate correction figure for focused analysis
- Improved readability of all denoising visualizations

Output files:
- denoising_xt_comparison.png (x_t trajectories)
- denoising_vt_comparison.png (v_t velocity - now cleaner)
- denoising_correction_comparison.png (NEW - separate corrections)
- denoising_x1t_comparison.png (x1_t state with error)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* fixup! Improve visualization: separate correction plot and fix axis scaling

* fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

* fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

* Fix tracking

* Right kwargs for the policy

* Add tests for tracker

* Fix tests

* Drop not required methods

* Add torch compilation for eval_dataset

* delete policies

* Add matplotlib to dev

* fixup! Add matplotlib to dev

* Experiment with late detach

* Debug

* Fix compilation

* Add RTC to PI0

* Pi0

* Pi0 eval dataset

* fixup! Pi0 eval dataset

* Turn off compilation for pi0/pi05

* fixup! Turn off compilation for pi0/pi05

* fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05

* Add workable flow

* Small fixes

* Add more tests

* Add validation at the end

* Update README

* Silent validation

* Fix tests

* Add tests for modeling_rtc

* Add tests for flow matching models with RTC

* fixup! Add tests for flow matching models with RTC

* fixup! fixup! Add tests for flow matching models with RTC

* Add one more test

* fixup! Add one more test

* Fix test to use _rtc_enabled() instead of is_rtc_enabled()

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()

* fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()

* Add RTC initialization tests without config for PI0.5 and SmolVLA

Add test_pi05_rtc_initialization_without_rtc_config and
test_smolvla_rtc_initialization_without_rtc_config to verify that
policies can initialize without RTC config and that _rtc_enabled()
returns False in this case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix PI0.5 init_rtc_processor to use getattr instead of direct model access

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix SmolVLA init_rtc_processor to use getattr instead of direct model access

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

* Fixup eval with real robot

* fixup! Fixup eval with real robot

* fixup! fixup! Fixup eval with real robot

* Extract simulator logic from eval_with real robot and add proper headers to files

* Update images

* Fix tests

* fixup! Fix tests

* add docs for rtc

* enhance doc and add images

* Fix install instructions

---------
Co-authored-by: Ben Zhang <benzhangniu@gmail.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-19 11:19:48 +01:00
Michel Aractingi b464d9f8bc Fix episode filtering bug when requesting a subset of the episodes in a dataset (#2456)
* filter episodes in load_nested_dataset

* nit

* remove test filtering

* move import to module level

* added missing episode indices to the EpisodeAwareSampler in lerobot_train.py;
2025-11-18 17:26:41 +01:00
Michel Aractingi 784cdae55a Fixes in port droid scripts (#2455)
* Fixes in port droid scripts

* revert default mem-per-cpu

* style nit

* fix relative imports

* style nit
2025-11-17 23:42:30 +01:00
Steven Palma d9e74a9d37 chore(dependencies): Bump lerobot to 0.4.2 (#2423) 2025-11-12 13:13:57 +01:00
Steven Palma a5b29d4301 chore(installation): remove libero installation patch (#2416)
* chore(installation): remove libero installation patch

* fix(ci): exclude groot for unbound deps test
2025-11-10 11:51:52 +01:00
Steven Palma a4aa316470 fix(dataset): fix data access bottleneck for faster training (#2408) 2025-11-07 21:54:44 +01:00
Michel Aractingi f6b16f6d97 fix(dataset_tools) Critical bug in modify features (#2342)
* fix bug in `_copy_data_with_feature_changes`

* Update src/lerobot/datasets/dataset_tools.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

* add missing import

---------

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2025-11-04 15:56:41 +01:00
56 changed files with 8533 additions and 141 deletions
+3 -3
@@ -83,11 +83,11 @@ jobs:
           fi
       - name: Remove Tags with Git dependencies
-        # TODO(Steven): Temporary patch to remove libero and pi from PyPi 0.4.0 release due to its reliance on git dependencies.
+        # TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
         run: |
           echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
-          grep -E '@ git\+https|lerobot\[pi\]|lerobot\[libero\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
-          sed -E -i '/@ git\+https|lerobot\[pi\]|lerobot\[libero\]/d' pyproject.toml
+          grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
+          sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
           echo "::info:: Git dependencies removed. Proceeding with build."
       - name: Install build dependencies
+1 -1
@@ -70,7 +70,7 @@ jobs:
           echo "Dependencies unbound:" && cat pyproject.toml
       - name: Install lerobot with all extras
-        run: uv sync --all-extras
+        run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
       - name: Run pytest (all extras)
         run: uv run pytest tests -vv
+1 -1
@@ -186,7 +186,7 @@ For a full list of optional dependencies, see:
 https://pypi.org/project/lerobot/
 > [!NOTE]
-> For lerobot 0.4.0, if you want to install libero or pi tags, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`.
+> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
 >
 > This will be solved in the next patch release
+8 -2
@@ -15,8 +15,6 @@
         title: Train a Robot with RL
       - local: hilserl_sim
         title: Train RL in Simulation
-      - local: async
-        title: Use Async Inference
       - local: multi_gpu_training
         title: Multi GPU training
     title: "Tutorials"
@@ -40,6 +38,12 @@
       - local: groot
         title: NVIDIA GR00T N1.5
     title: "Policies"
+  - sections:
+      - local: async
+        title: Use Async Inference
+      - local: rtc
+        title: Real-Time Chunking (RTC)
+    title: "Inference"
   - sections:
       - local: envhub
         title: Environments from the Hub
@@ -59,6 +63,8 @@
         title: Implement your own processor
       - local: processors_robots_teleop
         title: Processors for Robots and Teleoperators
+      - local: env_processor
+        title: Environment Processors
     title: "Robot Processors"
   - sections:
       - local: so101
+418
@@ -0,0 +1,418 @@
# Environment Processors
Environment processors are a critical layer in LeRobot's data processing architecture that handles **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
## Why Environment Processors?
When working with different robot environments (LIBERO, MetaWorld, Aloha, etc.), each environment often has unique data formats, coordinate systems, and conventions that need standardization **before** policy processing. Without environment processors, these transformations would be:
1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
## The Processing Pipeline
Here's how data flows through the complete processing pipeline during evaluation:
```python
# In lerobot_eval.py rollout() function:
# 1. Raw environment observation (numpy arrays, various formats)
raw_observation = env.step(action)
# 2. Convert numpy to torch, normalize images [0,1]
observation = preprocess_observation(raw_observation)
# 3. Add task metadata (for multi-task environments)
observation = add_envs_task(env, observation)
# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
# - Flatten robot states
# - Rotate images to match dataset conventions
# - Handle environment-specific coordinate systems
observation = env_preprocessor(observation)
# 5. POLICY-SPECIFIC preprocessing
# - Normalize with dataset statistics
# - Add batch dimensions
# - Move to GPU
# - Tokenize language instructions
observation = preprocessor(observation)
# 6. Policy inference
action = policy.select_action(observation)
# 7. POLICY-SPECIFIC postprocessing
# - Unnormalize actions
# - Remove batch dimensions
action = postprocessor(action)
# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
# - Convert action formats if needed
# - Apply environment-specific constraints
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# 9. Execute in environment
env.step(action)
```
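The layering above can be sketched with plain callables. This is illustrative only: `compose`, `env_preprocessor`, and `policy_preprocessor` below are stand-ins for LeRobot's pipeline objects, and the step bodies are toy examples, not the library's actual API.

```python
# Illustrative only: plain callables stand in for LeRobot's
# PolicyProcessorPipeline objects; the step bodies are toy examples.
def compose(*steps):
    """Chain processing steps left to right, like a processor pipeline."""
    def run(data):
        for step in steps:
            data = step(data)
        return data
    return run

# Environment-specific step: standardize the raw observation key
env_preprocessor = compose(
    lambda obs: {"observation.state": obs["raw_state"]}
)

# Policy-specific step: toy normalization standing in for dataset stats
policy_preprocessor = compose(
    lambda obs: {**obs, "observation.state": [x / 10.0 for x in obs["observation.state"]]}
)

# Full flow: env layer first, then policy layer
full = compose(env_preprocessor, policy_preprocessor)
print(full({"raw_state": [5.0, 10.0]}))  # {'observation.state': [0.5, 1.0]}
```

The key property this mirrors: either layer can be swapped independently, because each only sees dictionaries in a standard shape.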
## The Benefits
### 1. **Separation of Concerns**
Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
```python
# ❌ Before: Mixed concerns
class LiberoVLAPolicy:
    def preprocess(self, obs):
        # Environment-specific: Flatten robot state (shouldn't be in policy!)
        state = self._flatten_robot_state(obs["robot_state"])
        # Policy-specific: Normalize with dataset stats
        state = self.normalizer(state)
        return state

# ✅ After: Clear separation
# Environment processor: Handles LIBERO's nested robot state
env_preprocessor = LiberoProcessorStep()  # Flattens robot_state
# Policy processor: Handles model requirements
policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
```
### 2. **Flexibility and Reusability**
The same policy can work with different environment processors, and the same environment processor can work with different policies:
```python
# Use SmolVLA policy with LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
# Or use ACT policy with the same LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
```
### 3. **Easier Experimentation**
Want to try different state representations for LIBERO? Just create a new processor:
```python
# Original: 8D state (pos + quat→axisangle + gripper)
@ProcessorStepRegistry.register("libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
    def _process_observation(self, obs):
        robot_state = obs["observation.robot_state"]
        eef_pos = robot_state["eef"]["pos"]                         # 3D
        eef_axisangle = quat2axisangle(robot_state["eef"]["quat"])  # 3D
        gripper = robot_state["gripper"]["qpos"]                    # 2D
        state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1)  # 8D
        return state

# Experiment: Add velocity for better control
@ProcessorStepRegistry.register("libero_velocity_processor")
class LiberoVelocityProcessorStep(ObservationProcessorStep):
    def _process_observation(self, obs):
        # Include velocities for 14D state
        robot_state = obs["observation.robot_state"]
        eef_pos = robot_state["eef"]["pos"]                         # 3D
        eef_axisangle = quat2axisangle(robot_state["eef"]["quat"])  # 3D
        eef_vel = robot_state["eef"]["vel"]                         # 3D (NEW)
        gripper_pos = robot_state["gripper"]["qpos"]                # 2D
        gripper_vel = robot_state["gripper"]["qvel"]                # 3D (NEW)
        state = torch.cat([eef_pos, eef_axisangle, eef_vel,
                           gripper_pos, gripper_vel], dim=-1)       # 14D
        return state
```
### 4. **Cleaner Environment Code**
Environments expose **all available data** without needing to know what downstream models will use:
```python
# LIBERO environment exposes full robot state
observation = {
    "pixels": {"image": img, "image2": img2},
    "robot_state": {
        "eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
        "gripper": {"qpos": ..., "qvel": ...},
        "joints": {"pos": ..., "vel": ...},
    },
}

# Environment processor decides what to use
# Policy processor handles model-specific transformations
```
## Using Environment Processors
### Factory Function
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
```python
from lerobot.envs.factory import make_env_pre_post_processors
from lerobot.envs.configs import LiberoEnv, PushtEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
```
### Implementation in `envs/factory.py`
```python
def make_env_pre_post_processors(
    env_cfg: EnvConfig,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
    """
    Create preprocessor and postprocessor pipelines for environment observations.

    Args:
        env_cfg: The configuration of the environment.

    Returns:
        A tuple containing:
        - preprocessor: Pipeline that processes environment observations
        - postprocessor: Pipeline that processes environment outputs
    """
    # For LIBERO environments, add the LiberoProcessorStep to the preprocessor
    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
        preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
    else:
        # For all other environments, return an identity preprocessor
        preprocessor = PolicyProcessorPipeline(steps=[])

    # Postprocessor is currently identity for all environments
    # Future: could add environment-specific action transformations
    postprocessor = PolicyProcessorPipeline(steps=[])
    return preprocessor, postprocessor
```
### Integration in Evaluation
In `lerobot_eval.py`, the environment processors are created once and used throughout:
```python
def eval_main(cfg: EvalPipelineConfig):
    # Create environment
    envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)

    # Create policy
    policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)

    # Create policy processors
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=cfg.policy,
        pretrained_path=cfg.policy.pretrained_path,
    )

    # Create environment processors (NEW!)
    env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)

    # Run evaluation with both processor types
    eval_policy_all(
        envs=envs,
        policy=policy,
        env_preprocessor=env_preprocessor,    # Environment-specific
        env_postprocessor=env_postprocessor,  # Environment-specific
        preprocessor=preprocessor,            # Policy-specific
        postprocessor=postprocessor,          # Policy-specific
        n_episodes=cfg.eval.n_episodes,
    )
```
## Example: LIBERO Environment Processor
The `LiberoProcessorStep` demonstrates a real-world environment processor:
```python
from lerobot.processor.pipeline import ObservationProcessorStep

@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
    """
    Processes LIBERO observations into the LeRobot format.

    **State Processing:**
    - Extracts end-effector position (3D)
    - Converts quaternion to axis-angle representation (3D)
    - Extracts gripper joint positions (2D)
    - Concatenates into 8D state vector

    **Image Processing:**
    - Rotates images 180° to match HuggingFaceVLA/libero convention
    """

    def _process_observation(self, observation):
        processed_obs = observation.copy()

        # Process images: flip 180° for camera convention
        for key in list(processed_obs.keys()):
            if key.startswith("observation.images."):
                img = processed_obs[key]
                img = torch.flip(img, dims=[2, 3])  # Flip H and W
                processed_obs[key] = img

        # Process robot_state: flatten to 8D vector
        if "observation.robot_state" in processed_obs:
            robot_state = processed_obs.pop("observation.robot_state")
            eef_pos = robot_state["eef"]["pos"]            # (B, 3)
            eef_quat = robot_state["eef"]["quat"]          # (B, 4)
            gripper_qpos = robot_state["gripper"]["qpos"]  # (B, 2)

            # Convert quaternion to axis-angle
            eef_axisangle = self._quat2axisangle(eef_quat)  # (B, 3)

            # Concatenate into single state vector
            state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
            state = state.float()
            processed_obs["observation.state"] = state

        return processed_obs
```
### Why These Transformations?
1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
- Selects the relevant components (pos, quat, gripper)
- Converts quaternion to axis-angle (more suitable for learning)
- Flattens to a single 8D vector that policies expect
3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
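For reference, the quaternion-to-axis-angle conversion mentioned above can be sketched in plain Python. This is a standard conversion assuming `(x, y, z, w)` component ordering; the actual `_quat2axisangle` helper in LeRobot may differ in ordering and numerical details.

```python
import math

def quat_to_axisangle(quat):
    """Convert a unit quaternion (x, y, z, w) into an axis-angle 3-vector.

    The returned vector points along the rotation axis and its norm is the
    rotation angle in radians. Ordering and clamping follow common
    robosuite-style conventions and may differ from LeRobot's helper.
    """
    x, y, z, w = quat
    w = max(-1.0, min(1.0, w))  # guard against floating-point drift
    angle = 2.0 * math.acos(w)
    s = math.sqrt(max(1.0 - w * w, 0.0))
    if s < 1e-8:  # near-zero rotation: axis is arbitrary
        return [0.0, 0.0, 0.0]
    return [angle * x / s, angle * y / s, angle * z / s]

# 90° rotation about the z-axis
print(quat_to_axisangle((0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4))))
# ≈ [0.0, 0.0, 1.5708]
```

Axis-angle avoids the quaternion's double-cover ambiguity (q and −q encode the same rotation), which is one reason it tends to be friendlier as a learning target.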
## Adding Environment Processors for New Environments
To add environment processors for a new environment:
### 1. Create the Processor Step
```python
# In src/lerobot/processor/env_processor.py
@dataclass
@ProcessorStepRegistry.register(name="myenv_processor")
class MyEnvProcessorStep(ObservationProcessorStep):
    """Process observations from MyEnv."""

    def _process_observation(self, observation):
        processed = observation.copy()

        # Your environment-specific transformations
        if "myenv.specific.state" in processed:
            state = processed.pop("myenv.specific.state")
            # Transform to standard format
            processed["observation.state"] = self._transform_state(state)

        return processed
```
### 2. Update the Factory
```python
# In src/lerobot/envs/factory.py
def make_env_pre_post_processors(env_cfg: EnvConfig):
    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
        preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
    elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
        preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
    else:
        preprocessor = PolicyProcessorPipeline(steps=[])

    postprocessor = PolicyProcessorPipeline(steps=[])
    return preprocessor, postprocessor
```
### 3. Use in Evaluation
No changes needed! The evaluation script automatically uses the appropriate processor:
```bash
# --env.type=myenv automatically selects MyEnvProcessorStep
lerobot-eval \
    --policy.path=lerobot/my_policy \
    --env.type=myenv \
    --eval.n_episodes=10
```
## Future: Environment Postprocessors
Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
### Action Space Transformations
```python
@dataclass
class MyEnvActionPostprocessor(ProcessorStep):
    """Convert policy actions to environment-specific format."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition["action"]

        # Example: Convert from Cartesian to joint space
        if self.action_space == "joint":
            action = self.ik_solver(action)

        # Example: Apply environment-specific safety limits
        action = torch.clamp(action, self.min_action, self.max_action)

        transition["action"] = action
        return transition
```
### Coordinate System Conversions
```python
@dataclass
class CoordinateTransformPostprocessor(ProcessorStep):
    """Transform actions between coordinate systems."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition["action"]
        # Example: Policy outputs in world frame, env expects base frame
        action = self.world_to_base_transform(action)
        transition["action"] = action
        return transition
```
## Best Practices
1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
5. **Test independently**: Environment processors should be testable without loading full policies or environments.
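Testing independently can be sketched concretely: a processor step operates on plain dictionaries, so it can be exercised without loading a policy or simulator. `MyEnvProcessorStep`, its key names, and the `_transform_state` body below are hypothetical stand-ins from this guide, re-implemented standalone for the test:

```python
# Minimal sketch: an environment processor step is testable with a plain dict.
# MyEnvProcessorStep and the "myenv.specific.state" key are the hypothetical
# examples from this guide, re-implemented standalone here.

class MyEnvProcessorStep:
    def _transform_state(self, state):
        # Hypothetical transformation: reverse the state vector
        return list(reversed(state))

    def _process_observation(self, observation):
        processed = observation.copy()
        if "myenv.specific.state" in processed:
            state = processed.pop("myenv.specific.state")
            processed["observation.state"] = self._transform_state(state)
        return processed


step = MyEnvProcessorStep()
obs = {"myenv.specific.state": [1.0, 2.0, 3.0], "observation.images.front": "img"}
out = step._process_observation(obs)

assert "myenv.specific.state" not in out  # raw key is consumed
assert out["observation.state"] == [3.0, 2.0, 1.0]  # standardized key is produced
assert obs["myenv.specific.state"] == [1.0, 2.0, 3.0]  # input dict is untouched
print("processor step OK")
```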
## Summary
Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
- ✅ Enables easy experimentation with different state representations
- ✅ Allows policies to work seamlessly across different environments
- ✅ Keeps environment code focused on simulation/hardware interface
- ✅ Makes processor pipelines more maintainable and debuggable
- ✅ Follows the single responsibility principle
The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.
@@ -82,7 +82,7 @@ For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> -For lerobot 0.4.0, if you want to install libero or pi, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`
> +For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
### Troubleshooting
@@ -28,11 +28,6 @@ LIBERO is now part of our **multi-eval supported simulation**, meaning you can b
To Install LIBERO, after following LeRobot official instructions, just do:
`pip install -e ".[libero]"`
> [!NOTE]
> For lerobot 0.4.0, if you want to install libero tag, you will have to do: `pip install "lerobot[libero]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
### Single-suite evaluation
Evaluate a policy on one LIBERO suite:
@@ -0,0 +1,188 @@
# Real-Time Chunking (RTC)
Real-Time Chunking (RTC) is an inference-time method that allows large, flow-matching based robotic policies, such as [Pi0](./pi0), [Pi0.5](./pi05), and [SmolVLA](./smolvla), to produce smooth, continuous, and reactive motion despite having high inference latency.
These policies generate chunks of future actions (e.g., 50 steps at a time) instead of single actions.
Because the models are large, producing each chunk takes longer than the time it takes the robot to execute it.
Naively executing chunks leads to problems such as pauses, jerky transitions, or sudden changes in strategy whenever the next chunk arrives late or disagrees with the previously executed actions.
RTC solves this by asynchronously generating the next chunk while the robot continues executing the current one, and by guiding the new chunk so it aligns smoothly with the portion of the previous chunk that has already been executed.
## How RTC Works (simplified)
RTC lets the robot think ahead while it's still moving. While the robot is carrying out one chunk of actions, RTC starts generating the next chunk early.
But since the robot has already moved a bit by the time the new chunk is ready, RTC has to make sure the new chunk still lines up smoothly with what the robot is currently doing.
To do this, RTC treats the beginning of the new chunk like an inpainting or “fill-in-the-gaps” problem:
it gently adjusts the first part of the new chunk so it blends naturally with the robot's ongoing motion. The result: no pauses, no sudden jumps.
In technical terms, RTC adds a guidance term to the flow-matching denoising process that forces the overlapping timesteps of the new chunk to stay close to the executed portion of the previous chunk, typically using a soft transition mask.
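That guidance idea can be sketched numerically. The mask shape, weights, and update rule below are simplified illustrative assumptions, not LeRobot's actual RTC code:

```python
import numpy as np

# Illustrative sketch of RTC-style guidance during denoising.
# x_t: current noisy estimate of the new chunk; prev: previous chunk's actions.
horizon, action_dim = 50, 6
execution_horizon = 10
guidance_weight = 10.0

rng = np.random.default_rng(0)
x_t = rng.normal(size=(horizon, action_dim))
prev = rng.normal(size=(horizon, action_dim))

# Soft transition mask: 1 at step 0, decaying exponentially across the
# overlap region, and exactly 0 beyond it.
mask = np.zeros(horizon)
mask[:execution_horizon] = np.exp(-np.arange(execution_horizon) / (execution_horizon / 3))

# Guidance nudges the overlapping timesteps of the new chunk toward the
# previous chunk; timesteps past the overlap are left free to react.
velocity_correction = guidance_weight * mask[:, None] * (prev - x_t)
x_guided = x_t + 0.1 * velocity_correction  # one small denoising step

overlap_err_before = np.abs(prev[:execution_horizon] - x_t[:execution_horizon]).mean()
overlap_err_after = np.abs(prev[:execution_horizon] - x_guided[:execution_horizon]).mean()
print(overlap_err_after < overlap_err_before)  # prints: True
```

With this step size the first timestep (mask = 1) lands exactly on the previous chunk, while later timesteps are only partially pulled toward it.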
## Quick Start
### Installation
RTC is built into LeRobot. Just install the policy dependencies you need:
```bash
# For Pi0 or Pi0.5
pip install -e ".[pi]"
# For SmolVLA
pip install -e ".[smolvla]"
```
### Using RTC with Pi0
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python
from lerobot.policies.pi0 import PI0Policy, PI0Config
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.action_queue import ActionQueue

# Load Pi0 with RTC enabled
policy_cfg = PI0Config()

# Enable RTC
policy_cfg.rtc_config = RTCConfig(
    enabled=True,
    execution_horizon=10,  # How many steps to blend with previous chunk
    max_guidance_weight=10.0,  # How strongly to enforce consistency
    prefix_attention_schedule=RTCAttentionSchedule.EXP,  # Exponential blend
)

# Load the policy
policy = PI0Policy.from_pretrained("lerobot/pi0_base", policy_cfg=policy_cfg, device="cuda")

# Now use predict_action_chunk with RTC parameters
inference_delay = 4  # Steps of inference latency; estimate this from your policy's measured latency

# Initialize the action queue
action_queue = ActionQueue(policy_cfg.rtc_config)

# Start in a separate thread with the following function
def get_actions():
    while True:
        if should_get_actions:
            prev_actions = action_queue.get_left_over()
            obs = get_robot_observations(robot)

            # Generate actions WITH RTC
            actions = policy.predict_action_chunk(
                obs,
                inference_delay=inference_delay,
                prev_chunk_left_over=prev_actions,
            )
            action_queue.merge(actions, actions, inference_delay)

for step in range(num_steps):
    # Execute the next action from the queue
    action = action_queue.get()
    execute_actions(action)
```
## Key Parameters
`RTCConfig` has the following parameters to tune:
**`execution_horizon`**: How many timesteps from the previous chunk to maintain consistency with. Higher values mean smoother transitions but potentially less reactivity.
Typical values: 8-12 steps
```python
RTCConfig(execution_horizon=10)
```
**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This hyperparameter balances the smoothness of transitions against the reactivity of the policy. For 10-step flow matching (SmolVLA, Pi0, Pi0.5), a value of 10.0 works well.
**`prefix_attention_schedule`**: How to weight consistency across the overlap region.
- `LINEAR`: Linear decay from inference_delay to execution_horizon
- `EXP`: Exponential decay (recommended for getting started)
- `ONES`: Full weight across entire execution_horizon
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
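The four schedules can be sketched as weight vectors over the chunk. The shapes below are illustrative; the exact formulas in LeRobot's implementation may differ:

```python
import numpy as np

# Illustrative weight schedules over the execution horizon (not LeRobot's
# exact formulas): weight 1.0 means full consistency with the previous
# chunk, 0.0 means the new chunk is unconstrained at that timestep.
def prefix_weights(schedule: str, inference_delay: int, execution_horizon: int) -> np.ndarray:
    t = np.arange(execution_horizon)
    if schedule == "ONES":
        return np.ones(execution_horizon)
    if schedule == "ZEROS":
        # Binary: full weight up to inference_delay, then zero
        return (t < inference_delay).astype(float)
    if schedule == "LINEAR":
        # Full weight up to inference_delay, then linear decay
        w = np.clip((execution_horizon - t) / max(execution_horizon - inference_delay, 1), 0.0, 1.0)
        w[:inference_delay] = 1.0
        return w
    if schedule == "EXP":
        # Exponential decay past inference_delay
        return np.exp(-np.maximum(t - inference_delay, 0) / 2.0)
    raise ValueError(schedule)

for name in ("ONES", "ZEROS", "LINEAR", "EXP"):
    print(name, np.round(prefix_weights(name, inference_delay=4, execution_horizon=10), 2))
```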
**`inference_delay`**: How many timesteps of inference latency your system has. This is passed to `predict_action_chunk()` rather than the config, since it may vary at runtime.
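A back-of-the-envelope way to pick `inference_delay` is to time one chunk generation and convert it to control steps. The 150 ms latency below is a made-up stand-in; measure your own `predict_action_chunk` call with the same observation pipeline you use at runtime:

```python
import math
import time

fps = 30  # control frequency of the robot (steps per second)

# Time one chunk generation; the sleep is a made-up stand-in for a call
# like policy.predict_action_chunk(obs).
start = time.perf_counter()
time.sleep(0.15)  # hypothetical 150 ms of inference
measured_latency_s = time.perf_counter() - start

# Round up: the robot executes this many steps while the next chunk is
# being generated, and the new chunk must not start in the past.
inference_delay = math.ceil(measured_latency_s * fps)
print(inference_delay)  # ~5 steps for ~150 ms at 30 fps
```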
## Testing RTC Offline
Before running on a real robot, test RTC with dataset samples to visualize how it works:
```bash
python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--rtc.max_guidance_weight=10.0 \
--device=cuda
```
The script generates a visualization of the denoising process, comparing standard generation (left) with RTC (right). In the RTC plots, you can see how the first few steps (blue/purple lines) are guided to match the red ground truth trajectory (previous chunk's tail), ensuring a smooth transition between chunks.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/flow_matching.png"
alt="Denoising steps with and without RTC"
width="100%"
/>
</p>
## Testing RTC with a Real Robot
```bash
python examples/rtc/eval_with_real_robot.py \
--policy.path=${HF_USERNAME}/policy_repo_id \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120 \
--device=cuda
```
## How It Differs from the Async Inference in LeRobot
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
| Aspect | Async Inference | RTC |
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
| **Best Used** | Large models with high inference latency | Flow-matching based policies |
**Use both together** for maximum smoothness and reactivity!
## Advanced: Debug Tracking
RTC includes built-in debug tracking to help you understand what's happening during inference:
```python
# Enable debug tracking
policy_cfg.rtc_config.debug = True
policy_cfg.rtc_config.debug_maxlen = 100
# After inference, access debug data
debug_data = policy.rtc_processor.get_debug_data()
# Visualize denoising steps, corrections, etc.
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
visualizer = RTCDebugVisualizer()
# ... create plots
```
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
## References
- [Smooth-As-Butter Robot Policies](https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html) - Excellent technical explanation with real robot results
- [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/research/real_time_chunking) - Original paper and research
- [Kinetix RTC Implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix) - Reference implementation from Physical Intelligence
@@ -0,0 +1,29 @@
#!/bin/bash
#SBATCH -J b1k-aggregate
#SBATCH -p hopper-cpu
#SBATCH --qos=high
#SBATCH -c 2
#SBATCH -t 20:00:00
#SBATCH --mem=4G
#SBATCH -D /admin/home/francesco_capuano/lerobot
#SBATCH -o /admin/home/francesco_capuano/lerobot/examples/behavior_1k/logs/%x-%j.out
#SBATCH -e /admin/home/francesco_capuano/lerobot/examples/behavior_1k/logs/%x-%j.err
set -euo pipefail
set -x
export PYTHONUNBUFFERED=1
export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK:-1}
source "$HOME/.bashrc" 2>/dev/null || true
if ! command -v conda >/dev/null 2>&1; then
source "$HOME/miniconda3/etc/profile.d/conda.sh" 2>/dev/null || true
source "$HOME/anaconda3/etc/profile.d/conda.sh" 2>/dev/null || true
fi
conda activate lerobot
python examples/behavior_1k/aggregate_tasks_datasets.py \
--task-datasets-dir /fsx/francesco_capuano/behavior1k-v3 \
--aggregated-root /fsx/francesco_capuano/behavior1k-v3/behavior1k \
--num-tasks 50 \
--hf-user fracapuano \
--push-to-hub
@@ -0,0 +1,100 @@
"""Aggregate multiple task-specific LeRobot datasets into a single combined dataset."""
import argparse
import os
from pathlib import Path
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def main():
parser = argparse.ArgumentParser(
description="Aggregate multiple task-specific datasets into a single LeRobot dataset"
)
parser.add_argument(
"--task-datasets-dir",
type=str,
required=True,
help="Directory containing individual task datasets (e.g., /path/to/behavior1k/)",
)
parser.add_argument(
"--aggregated-root",
type=str,
required=True,
help="Path where the aggregated dataset will be written",
)
parser.add_argument(
"--num-tasks",
type=int,
default=50,
help="Number of tasks to aggregate (default: 50)",
)
parser.add_argument(
"--task-start-idx",
type=int,
default=0,
help="Starting task index (default: 0)",
)
parser.add_argument(
"--hf-user",
type=str,
default=None,
help="HuggingFace username for repo IDs (defaults to HF_USER env var or 'lerobot')",
)
parser.add_argument(
"--aggregated-repo-id",
type=str,
default=None,
help="Repository ID for the aggregated dataset (defaults to {hf_user}/behavior1k)",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the aggregated dataset to the Hugging Face Hub",
)
args = parser.parse_args()
# Determine HF user
hf_user = args.hf_user or os.environ.get("HF_USER", "lerobot")
# Set default aggregated repo ID if not provided
aggregated_repo_id = args.aggregated_repo_id or f"{hf_user}/behavior1k"
# Generate task indices
task_indices = range(args.task_start_idx, args.task_start_idx + args.num_tasks)
# Generate repo IDs for individual tasks
repo_ids = [f"{hf_user}/behavior1k-task{i:04d}" for i in task_indices]
# Generate local paths for individual task datasets
task_datasets_dir = Path(args.task_datasets_dir)
roots = [task_datasets_dir / f"behavior1k-task{i:04d}" for i in task_indices]
# Aggregated dataset path
aggregated_root = Path(args.aggregated_root)
print(f"🔹 Aggregating {args.num_tasks} task datasets")
print(f"Task datasets directory: {task_datasets_dir}")
print(f"Aggregated output: {aggregated_root}")
print(f"Aggregated repo ID: {aggregated_repo_id}")
aggregate_datasets(
repo_ids=repo_ids,
roots=roots,
aggr_repo_id=aggregated_repo_id,
aggr_root=aggregated_root,
)
print("✅ Aggregation complete")
if args.push_to_hub:
print(f"📤 Pushing aggregated dataset to {aggregated_repo_id}")
ds = LeRobotDataset(repo_id=aggregated_repo_id, root=aggregated_root)
ds.push_to_hub()
print("✅ Successfully pushed to hub")
if __name__ == "__main__":
main()
@@ -0,0 +1,38 @@
#!/bin/bash
#SBATCH -J b1k-convert
#SBATCH -p hopper-cpu # pick your partition
#SBATCH --qos=high
#SBATCH --array=0-49%8 # 50 tasks, max 8 running concurrently (conversion is I/O bound)
#SBATCH -c 1 # CPUs per conversion (tune as needed)
#SBATCH -t 2:00:00 # Time per conversion
#SBATCH --mem=3G # ~1.75GB for task 0, ~doubled for safety
#SBATCH -D /admin/home/francesco_capuano/lerobot
#SBATCH -o /admin/home/francesco_capuano/lerobot/examples/behavior_1k/logs/%x-%A_%a.out
#SBATCH -e /admin/home/francesco_capuano/lerobot/examples/behavior_1k/logs/%x-%A_%a.err
set -euo pipefail
set -x
export PYTHONUNBUFFERED=1
export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK:-1} # avoid BLAS oversubscription
DATA_PATH="/fsx/francesco_capuano/behavior1k-2025-v21"
BASE_OUT="/fsx/francesco_capuano/behavior1k-v3"
mkdir -p "$BASE_OUT" logs
i="${SLURM_ARRAY_TASK_ID}"
OUT_DIR="$(printf "%s/behavior1k-task%04d" "$BASE_OUT" "$i")"
# activate your env if needed
source "$HOME/.bashrc" 2>/dev/null || true
if ! command -v conda >/dev/null 2>&1; then
source "$HOME/miniconda3/etc/profile.d/conda.sh" 2>/dev/null || true
source "$HOME/anaconda3/etc/profile.d/conda.sh" 2>/dev/null || true
fi
conda activate lerobot
python examples/behavior_1k/convert_to_lerobot_v3.py \
--data-path "$DATA_PATH" \
--new-repo "$OUT_DIR" \
--task-id "$i" \
--force-conversion \
--push-to-hub
@@ -0,0 +1,667 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Behavior Dataset to LeRobotDataset v3.0 format"""
import argparse
import json
import logging
import os
import shutil
from pathlib import Path
import jsonlines
import pandas as pd
import pyarrow as pa
import tqdm
from datasets import Dataset, Features, Image
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_FEATURES,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
LEGACY_EPISODES_PATH,
LEGACY_EPISODES_STATS_PATH,
LEGACY_TASKS_PATH,
cast_stats_to_numpy,
flatten_dict,
get_file_size_in_mb,
get_parquet_file_size_in_mb,
get_parquet_num_frames,
load_info,
update_chunk_file_indices,
write_episodes,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
from lerobot.utils.utils import init_logging
# Script to convert one single task to LeRobotDataset v3.0
# TASK = 1
NEW_ROOT = Path("/fsx/jade_choghari/tmp/bb")
def fix_episode_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Performs several fixes to an underlying dataframe to make it LeRobotDataset-v3 compatible"""
    # Inject per-episode frame_index if missing (0..N-1 within each episode)
    if "frame_index" not in df.columns:
        df["frame_index"] = range(len(df))
    # Remove variable-length task_info feature (NOTE(fracapuano): change to padding at some point?)
    if "observation.task_info" in df.columns:
        df = df.drop(columns=["observation.task_info"])
    # NOTE(fracapuano): tasks are ordered (and there is one task per file/dataset)
    if "task_index" in df.columns:
        df["task_index"] = 0
    return df
def get_total_episodes_task(local_dir: Path, task_id: int, task_ranges: dict, step) -> int:
    """Calculates the total number of episodes for a single, specified task."""
    # Simply load the episodes for the task and count them.
    episodes = legacy_load_episodes_task(
        local_dir=local_dir, task_id=task_id, task_ranges=task_ranges, step=step
    )
    return len(episodes)
NUM_CAMERAS = 9
def get_total_frames_task(local_dir, meta_path, task_id: int, task_ranges: dict, step: int) -> int:
    episodes_metadata = legacy_load_episodes_task(
        local_dir=local_dir, task_id=task_id, task_ranges=task_ranges, step=step
    )
    total_frames = 0
    # "length" is the per-episode frame count
    for ep in episodes_metadata.values():
        total_frames += int(ep["length"])
    return total_frames
def convert_info(
    root, new_root, data_file_size_in_mb, video_file_size_in_mb, meta_path, task_id: int, task_ranges, step
):
    info = load_info(root)
    features = {**info["features"], **DEFAULT_FEATURES}
    # Variable-length task_info is not supported in LeRobotDataset v3.0
    del features["observation.task_info"]
    info["codebase_version"] = "v3.0"
    info["features"] = features
    del info["total_videos"]
    info["data_files_size_in_mb"] = data_file_size_in_mb
    info["video_files_size_in_mb"] = video_file_size_in_mb
    info["data_path"] = DEFAULT_DATA_PATH
    info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
    info["fps"] = int(info["fps"])
    for key in info["features"]:
        if info["features"][key]["dtype"] == "video":
            # already has fps in video_info
            continue
        info["features"][key]["fps"] = info["fps"]
    info["total_episodes"] = get_total_episodes_task(root, task_id, task_ranges, step)
    info["total_videos"] = info["total_episodes"] * NUM_CAMERAS
    info["total_frames"] = get_total_frames_task(root, meta_path, task_id, task_ranges, step)
    info["total_tasks"] = 1
    write_info(info, new_root)
def load_jsonlines(fpath: Path) -> list[dict]:
    with jsonlines.open(fpath, "r") as reader:
        return list(reader)


def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
    tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
    # Map task_index -> task string, sorted by task_index
    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
    task_to_task_index = {task: task_index for task_index, task in tasks.items()}
    return tasks, task_to_task_index
def convert_tasks(root, new_root, task_id: int):
tasks, _ = legacy_load_tasks(root)
if task_id not in tasks:
raise ValueError(f"Task ID {task_id} not found in tasks (available: {list(tasks.keys())})")
tasks = {task_id: tasks[task_id]}
# Tasks are ordered with 0..ntasks-1 in the converted dataset
task_indices = range(len(tasks.keys()))
task_strings = tasks.values()
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
write_tasks(df_tasks, new_root)
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
# TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
dataframes = []
for file in paths_to_cat:
df = pd.read_parquet(file)
df = fix_episode_dataframe(df)
dataframes.append(df)
# Concatenate all DataFrames along rows
concatenated_df = pd.concat(dataframes, ignore_index=True)
path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
if len(image_keys) > 0:
schema = pa.Schema.from_pandas(concatenated_df)
features = Features.from_arrow_schema(schema)
for key in image_keys:
features[key] = Image()
schema = features.arrow_schema
else:
schema = None
concatenated_df.to_parquet(path, index=False, schema=schema)
def get_image_keys(root):
info = load_info(root)
features = info["features"]
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
return image_keys
def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int, task_index: int):
task_dir_name = f"task-{task_index:04d}"
data_dir = root / "data" / task_dir_name
ep_paths = sorted(data_dir.glob("*.parquet"))
image_keys = get_image_keys(root)
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
num_frames = 0
paths_to_cat = []
episodes_metadata = []
logging.info(f"Converting data files from {len(ep_paths)} episodes")
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)
ep_metadata = {
"episode_index": ep_idx,
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
"dataset_from_index": num_frames,
"dataset_to_index": num_frames + ep_num_frames,
}
size_in_mb += ep_size_in_mb
num_frames += ep_num_frames
episodes_metadata.append(ep_metadata)
# write 0-based episode index instead of custom episode index (otherwise breaks compatibility with LeRobotDataset)
tmp_df = pd.read_parquet(ep_path)
tmp_df["episode_index"] = ep_idx
tmp_df.to_parquet(ep_path)
ep_idx += 1
if size_in_mb < data_file_size_in_mb:
paths_to_cat.append(ep_path)
continue
if paths_to_cat:
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
# Reset for the next file
size_in_mb = ep_size_in_mb
paths_to_cat = [ep_path]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
# Write remaining data if any
if paths_to_cat:
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
return episodes_metadata
def convert_videos_of_camera(
root: Path, new_root: Path, video_key: str, video_file_size_in_mb: int, task_index: int
):
# Access old paths to mp4
# videos_dir = root / "videos"
# ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))
task_dir_name = f"task-{task_index:04d}"
videos_dir = root / "videos" / task_dir_name / video_key
ep_paths = sorted(videos_dir.glob("*.mp4"))
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
duration_in_s = 0.0
paths_to_cat = []
episodes_metadata = []
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Check if adding this episode would exceed the limit
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
# Size limit would be exceeded, save current accumulation WITHOUT this episode
concatenate_video_files(
paths_to_cat,
new_root
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
)
# Update episodes metadata for the file we just saved
for i, _ in enumerate(paths_to_cat):
past_ep_idx = ep_idx - len(paths_to_cat) + i
episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx
episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx
# Move to next file and start fresh with current episode
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
size_in_mb = 0
duration_in_s = 0.0
paths_to_cat = []
# Add current episode metadata
ep_metadata = {
"episode_index": ep_idx,
f"videos/{video_key}/chunk_index": chunk_idx, # Will be updated when file is saved
f"videos/{video_key}/file_index": file_idx, # Will be updated when file is saved
f"videos/{video_key}/from_timestamp": duration_in_s,
f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
}
episodes_metadata.append(ep_metadata)
# Add current episode to accumulation
paths_to_cat.append(ep_path)
size_in_mb += ep_size_in_mb
duration_in_s += ep_duration_in_s
ep_idx += 1
# Write remaining videos if any
if paths_to_cat:
concatenate_video_files(
paths_to_cat,
new_root
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
)
# Update episodes metadata for the final file
for i, _ in enumerate(paths_to_cat):
past_ep_idx = ep_idx - len(paths_to_cat) + i
episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx
episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx
return episodes_metadata
def get_video_keys(root):
info = load_info(root)
features = info["features"]
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
return video_keys
def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int, task_id: int):
logging.info(f"Converting videos from {root} to {new_root}")
video_keys = get_video_keys(root)
if len(video_keys) == 0:
return None
video_keys = sorted(video_keys)
eps_metadata_per_cam = []
for camera in video_keys:
eps_metadata = convert_videos_of_camera(root, new_root, camera, video_file_size_in_mb, task_id)
eps_metadata_per_cam.append(eps_metadata)
num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
if len(set(num_eps_per_cam)) != 1:
raise ValueError(f"All cams dont have same number of episodes ({num_eps_per_cam}).")
episodes_metadata = []
num_cameras = len(video_keys)
num_episodes = num_eps_per_cam[0]
for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert videos"):
# Sanity check
ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
ep_ids += [ep_idx]
if len(set(ep_ids)) != 1:
raise ValueError(f"All episode indices need to match ({ep_ids}).")
ep_dict = {}
for cam_idx in range(num_cameras):
ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
episodes_metadata.append(ep_dict)
return episodes_metadata
def infer_task_episode_ranges(episodes_jsonl_path: Path) -> dict:
"""
Parse the Behavior-1K episodes.jsonl metadata and infer contiguous episode ranges per unique task.
Returns a dict:
{ task_id: { "task_string": ..., "ep_start": ..., "ep_end": ... } }
"""
task_ranges = {}
task_id = 0
current_task_str = None
ep_start = None
ep_end = None
with open(episodes_jsonl_path) as f:
for line in f:
if not line.strip():
continue
ep = json.loads(line)
ep_idx = ep["episode_index"]
task_str = ep["tasks"][0] if ep["tasks"] else "UNKNOWN"
if current_task_str is None:
current_task_str = task_str
ep_start = ep_idx
ep_end = ep_idx
elif task_str == current_task_str:
ep_end = ep_idx
else:
# close previous task group
task_ranges[task_id] = {
"task_string": current_task_str,
"ep_start": ep_start,
"ep_end": ep_end,
}
task_id += 1
# start new one
current_task_str = task_str
ep_start = ep_idx
ep_end = ep_idx
# store last task
if current_task_str is not None:
task_ranges[task_id] = {
"task_string": current_task_str,
"ep_start": ep_start,
"ep_end": ep_end,
}
return task_ranges
def legacy_load_episodes_task(local_dir: Path, task_id: int, task_ranges: dict, step: int = 10) -> dict:
"""
Load only the episodes belonging to a specific task, inferred automatically from episode ranges.
Args:
local_dir (Path): Root path containing legacy meta/episodes.jsonl
task_id (int): Which task to load (key from the inferred task_ranges dict)
task_ranges (dict): Mapping from infer_task_episode_ranges()
step (int): Episode index step (Behavior-1K = 10)
"""
all_episodes = legacy_load_episodes(local_dir)
# get the range for this task
if task_id not in task_ranges:
raise ValueError(f"Task id {task_id} not found in task_ranges")
ep_start = task_ranges[task_id]["ep_start"]
ep_end = task_ranges[task_id]["ep_end"]
task_episode_indices = range(ep_start, ep_end + step, step)
return {i: all_episodes[i] for i in task_episode_indices if i in all_episodes}
def legacy_load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def legacy_load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def legacy_load_episodes_stats_task(local_dir: Path, task_id: int, task_ranges: dict, step: int = 10) -> dict:
all_stats = legacy_load_episodes_stats(local_dir)
if task_id not in task_ranges:
raise ValueError(f"Task id {task_id} not found in task_ranges")
ep_start = task_ranges[task_id]["ep_start"]
ep_end = task_ranges[task_id]["ep_end"]
task_episode_indices = range(ep_start, ep_end + step, step)
return {i: all_stats[i] for i in task_episode_indices if i in all_stats}
def generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
):
num_episodes = len(episodes_metadata)
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
episodes_stats_vals = list(episodes_stats.values())
episodes_stats_keys = list(episodes_stats.keys())
for i in range(num_episodes):
ep_legacy_metadata = episodes_legacy_metadata_vals[i]
ep_metadata = episodes_metadata[i]
ep_stats = episodes_stats_vals[i]
ep_ids_set = {
ep_legacy_metadata["episode_index"],
ep_metadata["episode_index"],
episodes_stats_keys[i],
}
if episodes_videos is None:
ep_video = {}
else:
ep_video = episodes_videos[i]
ep_ids_set.add(ep_video["episode_index"])
ep_dict = {
**ep_legacy_metadata,
**ep_video,
**ep_metadata,
**flatten_dict({"stats": ep_stats}),
}
# enforce contiguous indexing 0..n-1, but also stores the legacy episode index
ep_dict["episode_index"] = i
yield ep_dict
def convert_episodes_metadata(
root, new_root, episodes_metadata, task_id: int, task_ranges, episodes_video_metadata=None
):
logging.info(f"Converting episodes metadata from {root} to {new_root}")
# filter by task
episodes_legacy_metadata = legacy_load_episodes_task(root, task_id=task_id, task_ranges=task_ranges)
episodes_stats = legacy_load_episodes_stats_task(root, task_id=task_id, task_ranges=task_ranges)
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
if episodes_video_metadata is not None:
num_eps_set.add(len(episodes_video_metadata))
if len(num_eps_set) != 1:
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
# Single file approach: set meta indices to 0 for all rows and write once
ds_episodes = Dataset.from_generator(
lambda: generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
)
)
num_eps = len(ds_episodes)
# NOTE(fracapuano): for the size of the average dataset this is fine!
ds_episodes = ds_episodes.add_column("meta/episodes/chunk_index", [0] * num_eps)
ds_episodes = ds_episodes.add_column("meta/episodes/file_index", [0] * num_eps)
write_episodes(ds_episodes, new_root)
stats = aggregate_stats(list(episodes_stats.values()))
write_stats(stats, new_root)
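`aggregate_stats` is imported from LeRobot and handles the real schema (including min/max and nested feature keys). As a rough sketch of what count-weighted aggregation involves, assuming illustrative per-episode entries with only `mean`, `std`, and `count` fields:

```python
import math

def aggregate_stats_sketch(per_episode: list[dict]) -> dict:
    """Combine per-episode stats into dataset-level stats using
    count-weighted moments (illustrative schema, not LeRobot's exact one)."""
    total = sum(s["count"] for s in per_episode)
    mean = sum(s["mean"] * s["count"] for s in per_episode) / total
    # E[x^2] = var + mean^2, combined with the same count weights
    second = sum((s["std"] ** 2 + s["mean"] ** 2) * s["count"] for s in per_episode) / total
    std = math.sqrt(max(second - mean**2, 0.0))
    return {"mean": mean, "std": std, "count": total}

eps = [{"mean": 0.0, "std": 0.0, "count": 2}, {"mean": 2.0, "std": 0.0, "count": 2}]
print(aggregate_stats_sketch(eps))  # {'mean': 1.0, 'std': 1.0, 'count': 4}
```

Weighting by episode length matters here: averaging unweighted per-episode means would bias the dataset-level stats toward short episodes.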
def convert_dataset_local(
data_path: Path,
new_repo: Path,
task_id: int,
data_file_size_in_mb: int = DEFAULT_DATA_FILE_SIZE_IN_MB,
video_file_size_in_mb: int = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
force_conversion: bool = False,
):
"""
Convert a local dataset to v3.x format, task-by-task, without using the Hugging Face Hub.
Args:
data_path (Path): path to local dataset root (e.g. /fsx/.../2025-challenge-demos)
new_repo (Path): path where converted dataset will be written (e.g. /fsx/.../behavior1k_v3)
task_id (int): which task to convert (index)
data_file_size_in_mb (int): max size per data chunk
video_file_size_in_mb (int): max size per video chunk
force_conversion (bool): overwrite existing conversion if True
"""
root = Path(data_path)
new_root = Path(new_repo)
# Clean up if needed
if new_root.exists() and force_conversion:
shutil.rmtree(new_root)
new_root.mkdir(parents=True, exist_ok=True)
print(f"🔹 Starting conversion for task {task_id}")
print(f"Input root: {root}")
print(f"Output root: {new_root}")
# Infer task episode ranges
episodes_meta_path = root / "meta" / "episodes.jsonl"
task_ranges = infer_task_episode_ranges(episodes_meta_path)
convert_info(
root,
new_root,
data_file_size_in_mb,
video_file_size_in_mb,
episodes_meta_path,
task_id,
task_ranges,
step=10,
)
convert_tasks(root, new_root, task_id)
episodes_metadata = convert_data(root, new_root, data_file_size_in_mb, task_index=task_id)
episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb, task_id=task_id)
convert_episodes_metadata(
root,
new_root,
episodes_metadata,
task_id=task_id,
task_ranges=task_ranges,
episodes_video_metadata=episodes_videos_metadata,
)
print(f"✅ Conversion complete for task {task_id}")
print(f"Converted dataset written to: {new_root}")
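`infer_task_episode_ranges` reads `meta/episodes.jsonl` to map each task to its episode span. A hypothetical minimal version, assuming each JSONL row carries an `episode_index` int and a `tasks` list (the schema is an assumption for illustration):

```python
from collections import defaultdict

def infer_task_episode_ranges_sketch(rows: list[dict]) -> dict:
    """Group episode indices by task name and return inclusive (start, end)
    ranges, assuming per-task indices are contiguous (illustrative schema)."""
    by_task = defaultdict(list)
    for row in rows:
        for task in row["tasks"]:
            by_task[task].append(row["episode_index"])
    return {task: (min(idx), max(idx)) for task, idx in by_task.items()}

rows = [
    {"episode_index": 0, "tasks": ["pick"]},
    {"episode_index": 1, "tasks": ["pick"]},
    {"episode_index": 2, "tasks": ["place"]},
]
print(infer_task_episode_ranges_sketch(rows))  # {'pick': (0, 1), 'place': (2, 2)}
```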
if __name__ == "__main__":
import argparse
from pathlib import Path
init_logging()
parser = argparse.ArgumentParser(
description="Convert Behavior-1K tasks to LeRobot v3 format (local only)"
)
parser.add_argument(
"--data-path",
type=str,
required=True,
help="Path to the local Behavior-1K dataset (e.g. /fsx/francesco_capuano/.cache/behavior-1k/2025-challenge-demos)",
)
parser.add_argument(
"--new-repo",
type=str,
required=True,
help="Path to the output directory for the converted dataset",
)
parser.add_argument(
"--task-id",
type=int,
required=True,
help="Task index to convert (e.g. 0, 1, 2, ...)",
)
parser.add_argument(
"--data-file-size-in-mb",
type=int,
default=DEFAULT_DATA_FILE_SIZE_IN_MB,
help=f"Maximum size per data chunk (default: {DEFAULT_DATA_FILE_SIZE_IN_MB})",
)
parser.add_argument(
"--video-file-size-in-mb",
type=int,
default=DEFAULT_VIDEO_FILE_SIZE_IN_MB,
help=f"Maximum size per video chunk (default: {DEFAULT_VIDEO_FILE_SIZE_IN_MB})",
)
parser.add_argument(
"--force-conversion",
action="store_true",
help="Force overwrite of existing conversion output if present.",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the (converted) dataset to the hub.",
)
args = parser.parse_args()
if args.push_to_hub:
HF_USER = os.environ.get("HF_USER")
if HF_USER is None:
raise ValueError(
"HF_USER environment variable is not set! Set it before converting and pushing to the hub."
)
convert_dataset_local(
data_path=Path(args.data_path),
new_repo=Path(args.new_repo),
task_id=args.task_id,
data_file_size_in_mb=args.data_file_size_in_mb,
video_file_size_in_mb=args.video_file_size_in_mb,
force_conversion=args.force_conversion,
)
if args.push_to_hub:
ds = LeRobotDataset(repo_id=f"{HF_USER}/behavior1k-task{args.task_id:04d}", root=args.new_repo)
ds.push_to_hub()
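The commit history notes that `ds.push_to_hub` may fail at this scale. A generic retry-with-backoff wrapper (not part of this script; a sketch of one way to guard the call):

```python
import time

def with_retries(fn, attempts: int = 3, backoff_s: float = 2.0):
    """Call fn(), retrying with linearly growing backoff between attempts;
    re-raise the last exception if every attempt fails."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == attempts:
                raise
            time.sleep(backoff_s * attempt)

# e.g. with_retries(ds.push_to_hub)
```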
@@ -0,0 +1,27 @@
#!/bin/bash
#SBATCH -J b1k-download
#SBATCH -p hopper-cpu
#SBATCH --qos=high
#SBATCH -c 32                     # CPUs for the download (tune as needed)
#SBATCH -t 20:00:00               # Time limit for the download
#SBATCH -D /admin/home/francesco_capuano/lerobot
#SBATCH -o /admin/home/francesco_capuano/lerobot/examples/behavior_1k/logs/%x-%A.out
#SBATCH -e /admin/home/francesco_capuano/lerobot/examples/behavior_1k/logs/%x-%A.err
set -euo pipefail
set -x
export PYTHONUNBUFFERED=1
export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK:-1}
# activate your env if needed
source "$HOME/.bashrc" 2>/dev/null || true
if ! command -v conda >/dev/null 2>&1; then
source "$HOME/miniconda3/etc/profile.d/conda.sh" 2>/dev/null || true
source "$HOME/anaconda3/etc/profile.d/conda.sh" 2>/dev/null || true
fi
conda activate lerobot
python examples/behavior_1k/download_data.py \
--repo-id "behavior-1k/2025-challenge-demos" \
--local-dir "/fsx/francesco_capuano/behavior1k-2025-v21" \
--max-workers 32
@@ -0,0 +1,26 @@
import shutil
from huggingface_hub import snapshot_download
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--repo-id", type=str, required=True)
parser.add_argument("--max-workers", type=int, default=8)
parser.add_argument("--local-dir", type=str, required=True)
parser.add_argument("--force-download", action="store_true")
args = parser.parse_args()
if args.force_download:
shutil.rmtree(args.local_dir, ignore_errors=True)
snapshot_download(
repo_id=args.repo_id,
repo_type="dataset",
force_download=args.force_download,
max_workers=args.max_workers,
local_dir=args.local_dir,
ignore_patterns=["annotations/*"], # NOTE(fracapuano): Dropping textual annotations right now
)
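After a large `snapshot_download` it can be worth sanity-checking the local copy. A small helper (not in the script) using only `pathlib`:

```python
from pathlib import Path

def summarize_snapshot(local_dir) -> tuple[int, int]:
    """Return (file_count, total_bytes) for everything under local_dir."""
    files = [p for p in Path(local_dir).rglob("*") if p.is_file()]
    return len(files), sum(p.stat().st_size for p in files)
```

Comparing the file count against the repo's expected layout is a cheap way to catch a partially failed download before launching the conversion jobs.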
@@ -0,0 +1,41 @@
#!/bin/bash
#SBATCH -J b1k-upload
#SBATCH -p hopper-cpu
#SBATCH --qos=high
#SBATCH -c 1
#SBATCH -t 48:00:00
#SBATCH --mem=4G
#SBATCH --array=0-49%2
#SBATCH -D /admin/home/francesco_capuano/lerobot
#SBATCH -o /admin/home/francesco_capuano/lerobot/examples/behavior_1k/logs/%x-%A_%a.out
#SBATCH -e /admin/home/francesco_capuano/lerobot/examples/behavior_1k/logs/%x-%A_%a.err
set -euo pipefail
set -x
export PYTHONUNBUFFERED=1
export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK:-1}
source "$HOME/.bashrc" 2>/dev/null || true
if ! command -v conda >/dev/null 2>&1; then
source "$HOME/miniconda3/etc/profile.d/conda.sh" 2>/dev/null || true
source "$HOME/anaconda3/etc/profile.d/conda.sh" 2>/dev/null || true
fi
conda activate lerobot
# The SLURM_ARRAY_TASK_ID will be used as the task-id
TASK_ID=${SLURM_ARRAY_TASK_ID}
# Configuration
ROOT_PATH="/fsx/francesco_capuano/behavior1k-v3"
HF_USER="fracapuano"
# Limit upload workers to reduce network contention (default in HF Hub is 4)
# For I/O-bound uploads, 2-4 workers per task is optimal
NUM_WORKERS=2
echo "Task ${TASK_ID}: uploading with ${NUM_WORKERS} workers from ${ROOT_PATH}"
python examples/behavior_1k/upload_folders.py \
--task-id ${TASK_ID} \
--root-path ${ROOT_PATH} \
--hf-user ${HF_USER} \
--num-workers ${NUM_WORKERS}
@@ -0,0 +1,108 @@
import argparse
from pathlib import Path
from huggingface_hub import HfApi, upload_large_folder
def main():
parser = argparse.ArgumentParser(
description="Upload a folder to Hugging Face Hub using upload_large_folder"
)
parser.add_argument(
"--folder-path",
type=str,
required=False,
help="Path to the folder to upload (used if task-id is not provided)",
)
parser.add_argument(
"--repo-id",
type=str,
required=False,
help="Repository ID on Hugging Face Hub (e.g., 'username/repo-name'). If task-id is provided, will be constructed as '{hf-user}/behavior1k-task{task_id:04d}'",
)
parser.add_argument(
"--task-id",
type=int,
required=False,
help="Task index to upload (e.g., 0, 1, 2, ...). When provided, folder-path is constructed from root-path.",
)
parser.add_argument(
"--root-path",
type=str,
required=False,
help="Root path containing task folders (e.g., /fsx/user/behavior1k-v3). Used with --task-id to construct folder path.",
)
parser.add_argument(
"--hf-user",
type=str,
default=None,
help="Hugging Face username for constructing repo-id with task-id (default: from HF_USER env var or 'fracapuano')",
)
parser.add_argument(
"--create-repo", action="store_true", help="Create the repository if it doesn't exist"
)
parser.add_argument(
"--num-workers",
type=int,
default=2,
help="Number of parallel workers for upload (default: 2). For I/O-bound uploads, use 1-4 to avoid network contention.",
)
args = parser.parse_args()
# Construct folder path and repo ID based on task-id or use provided values
if args.task_id is not None:
if not args.root_path:
raise ValueError("--root-path is required when --task-id is provided")
task_folder_name = f"behavior1k-task{args.task_id:04d}"
folder_path = Path(args.root_path) / task_folder_name
# Resolve the user as documented: --hf-user flag, then HF_USER env var, then "fracapuano"
import os
hf_user = args.hf_user or os.environ.get("HF_USER", "fracapuano")
repo_id = f"{hf_user}/{task_folder_name}"
print(f"Task mode: uploading task {args.task_id}")
else:
if not args.folder_path or not args.repo_id:
raise ValueError(
"Either --task-id with --root-path, or both --folder-path and --repo-id must be provided"
)
folder_path = Path(args.folder_path)
repo_id = args.repo_id
# Validate folder path
if not folder_path.exists():
raise ValueError(f"Folder path does not exist: {folder_path}")
if not folder_path.is_dir():
raise ValueError(f"Path is not a directory: {folder_path}")
print(f"Uploading folder: {folder_path}")
print(f"Repository: {repo_id}")
# Create repository if requested
if args.create_repo:
api = HfApi()
print(f"Creating repository {repo_id}...")
try:
api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
print("Repository created or already exists. Updating its contents")
except Exception as e:
print(f"Warning: Could not create repository: {e}")
# Upload the folder
print(f"Starting upload with {args.num_workers} parallel workers...")
try:
upload_large_folder(
folder_path=str(folder_path),
repo_id=repo_id,
repo_type="dataset",
num_workers=args.num_workers,
)
print("✓ Upload completed successfully!")
# upload_large_folder returns None, so point at the repo page rather than a commit URL
print(f"Dataset available at: https://huggingface.co/datasets/{repo_id}")
except Exception as e:
print(f"✗ Upload failed: {e}")
raise
if __name__ == "__main__":
main()
@@ -15,16 +15,12 @@
# limitations under the License.
import argparse
import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
from port_droid import DROID_SHARDS
class AggregateDatasets(PipelineStep):
@@ -38,6 +34,11 @@ class AggregateDatasets(PipelineStep):
self.aggr_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
init_logging()
# Since aggregate_datasets already handles parallel processing internally,
@@ -20,7 +20,7 @@ from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from port_droid import DROID_SHARDS
class PortDroidShards(PipelineStep):
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
from port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging
@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
def make_upload_executor(
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
):
kwargs = {
"pipeline": [
UploadDataset(repo_id),
UploadDataset(repo_id, private=private),
],
"logging_dir": str(logs_dir / job_name),
}
@@ -267,6 +267,12 @@ def main():
default="1950M",
help="Memory per cpu that each worker will use.",
)
parser.add_argument(
"--private",
action="store_true",
default=False,
help="Whether to create a private repository.",
)
init_logging()
@@ -0,0 +1,951 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluate Real-Time Chunking (RTC) performance on dataset samples.
This script takes two random samples from a dataset:
- Uses actions from the first sample as previous chunk
- Generates new actions for the second sample with and without RTC
It compares action predictions with and without RTC on dataset samples,
measuring consistency and ground truth alignment.
Usage:
# Basic usage with smolvla policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--rtc.prefix_attention_schedule=EXP \
--seed=10
# Basic usage with pi0.5 policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--device=mps \
--seed=10
# Basic usage with pi0.5 policy with cuda device
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
# Basic usage with pi0 policy with cuda device
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
uv run python examples/rtc/eval_dataset.py \
--policy.path=lipsop/reuben_pi0 \
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
# With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
--torch_compile_mode=reduce-overhead
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \
--torch_compile_disable_cudagraphs=false
"""
import gc
import logging
import os
import random
from dataclasses import dataclass, field
import numpy as np
import torch
try:
import matplotlib.pyplot as plt
MATPLOTLIB_AVAILABLE = True
except ImportError:
MATPLOTLIB_AVAILABLE = False
plt = None
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.factory import resolve_delta_timestamps
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
def set_seed(seed: int):
"""Set random seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
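`set_seed` pins every RNG the stack touches (Python, NumPy, Torch, CUDA, MPS). The reproducibility guarantee it relies on can be seen with the stdlib RNG alone:

```python
import random

def sample_after_seed(seed: int, n: int = 3) -> list[float]:
    """Reseed the stdlib RNG and draw n floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

# Identical seeds yield identical draws; different seeds (almost surely) do not
assert sample_after_seed(10) == sample_after_seed(10)
assert sample_after_seed(10) != sample_after_seed(11)
```

The same property is what makes the no-RTC and RTC runs below comparable: reseeding before each run means both start from the same noise.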
def _check_matplotlib_available():
"""Check if matplotlib is available, raise helpful error if not."""
if not MATPLOTLIB_AVAILABLE:
raise ImportError(
"matplotlib is required for RTC debug visualizations. "
"Please install it by running:\n"
" uv pip install matplotlib"
)
@dataclass
class RTCEvalConfig(HubMixin):
"""Configuration for RTC evaluation."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Dataset configuration
dataset: DatasetConfig = field(default_factory=DatasetConfig)
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
enabled=True,
execution_horizon=20,
max_guidance_weight=10.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=True,
debug_maxlen=1000,
)
)
# Device configuration
device: str | None = field(
default=None,
metadata={"help": "Device to run on (cuda, cpu, mps, auto)"},
)
# Output configuration
output_dir: str = field(
default="rtc_debug_output",
metadata={"help": "Directory to save debug visualizations"},
)
# Seed configuration
seed: int = field(
default=42,
metadata={"help": "Random seed for reproducibility"},
)
inference_delay: int = field(
default=4,
metadata={"help": "Inference delay for RTC"},
)
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# Parse policy path
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required (--policy.path)")
# Auto-detect device if not specified
if self.device is None or self.device == "auto":
if torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
logging.info(f"Auto-detected device: {self.device}")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
class RTCEvaluator:
"""Evaluator for RTC on dataset samples."""
def __init__(self, cfg: RTCEvalConfig):
self.cfg = cfg
self.device = cfg.device
# Load dataset with proper delta_timestamps based on policy configuration
# Calculate delta_timestamps using the same logic as make_dataset factory
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
# Get dataset metadata to extract FPS
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id)
# Calculate delta_timestamps from policy's delta_indices
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
# Create dataset with calculated delta_timestamps
self.dataset = LeRobotDataset(
cfg.dataset.repo_id,
delta_timestamps=delta_timestamps,
)
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
# Create preprocessor/postprocessor
self.preprocessor, self.postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
preprocessor_overrides={
"device_processor": {"device": self.device},
},
)
logging.info("=" * 80)
logging.info("Ready to run evaluation with sequential policy loading:")
logging.info(" 1. policy_prev_chunk - Generate reference chunk, then destroy")
logging.info(" 2. policy_no_rtc - Generate without RTC, then destroy")
logging.info(" 3. policy_rtc - Generate with RTC, then destroy")
logging.info(" Note: Only one policy in memory at a time for efficient memory usage")
logging.info("=" * 80)
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
"""Initialize a single policy instance with specified RTC configuration.
Args:
name: Name identifier for logging purposes
rtc_enabled: Whether to enable RTC for this policy
rtc_debug: Whether to enable debug tracking for this policy
Returns:
Configured policy instance with optional torch.compile applied
"""
logging.info(f"Initializing {name}...")
# Load policy from pretrained
policy_class = get_policy_class(self.cfg.policy.type)
config = PreTrainedConfig.from_pretrained(self.cfg.policy.pretrained_path)
if self.cfg.policy.type in ("pi05", "pi0"):
config.compile_model = self.cfg.use_torch_compile
policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=config)
policy = policy.to(self.device)
policy.eval()
# Configure RTC
rtc_config = RTCConfig(
enabled=rtc_enabled,
execution_horizon=self.cfg.rtc.execution_horizon,
max_guidance_weight=self.cfg.rtc.max_guidance_weight,
prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
debug=rtc_debug,
debug_maxlen=self.cfg.rtc.debug_maxlen,
)
policy.config.rtc_config = rtc_config
policy.init_rtc_processor()
logging.info(f" RTC enabled: {rtc_enabled}")
logging.info(f" RTC debug: {rtc_debug}")
logging.info(f" Policy config: {config}")
# Apply torch.compile to predict_action_chunk method if enabled
if self.cfg.use_torch_compile:
policy = self._apply_torch_compile(policy, name)
logging.info(f"{name} initialized successfully")
return policy
def _apply_torch_compile(self, policy, policy_name: str):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
policy_name: Name for logging purposes
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type in ("pi05", "pi0"):
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logging.warning(
f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
logging.info(f" Disable CUDA graphs: {self.cfg.torch_compile_disable_cudagraphs}")
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
# Compile the predict_action_chunk method
# - Debug tracker is excluded from compilation via @torch._dynamo.disable
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": self.cfg.torch_compile_backend,
"mode": self.cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if self.cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
except Exception as e:
logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}")
logging.warning(f" [{policy_name}] Continuing without torch.compile")
return policy
def _destroy_policy(self, policy, policy_name: str):
"""Explicitly destroy a policy and free all associated memory.
This method performs aggressive cleanup to ensure maximum memory is freed,
which is critical for large models (e.g., VLAs with billions of parameters).
Args:
policy: Policy instance to destroy
policy_name: Name for logging purposes
"""
logging.info(f" Destroying {policy_name} and freeing memory...")
try:
# Step 1: Move policy to CPU to free GPU/MPS memory
policy.cpu()
# Step 2: Delete the policy object
del policy
# Step 3: Force garbage collection to reclaim memory immediately
gc.collect()
# Step 4: Clear device-specific caches
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Ensure all operations complete
if torch.backends.mps.is_available():
torch.mps.empty_cache()
logging.info(f"{policy_name} destroyed and memory freed")
except Exception as e:
logging.warning(f" Warning: Error during {policy_name} cleanup: {e}")
def run_evaluation(self):
"""Run evaluation on two random dataset samples using three separate policies.
Note: Policies are deinitialized after each step to free memory. Three instances
of a large model (e.g., a VLA with billions of parameters) cannot fit in memory
simultaneously, so by deleting and garbage-collecting after each step we ensure
only one policy is loaded at a time.
"""
# Create output directory
os.makedirs(self.cfg.output_dir, exist_ok=True)
logging.info(f"Output directory: {self.cfg.output_dir}")
logging.info("=" * 80)
logging.info("Starting RTC evaluation")
logging.info(f"Inference delay: {self.cfg.inference_delay}")
logging.info("=" * 80)
# Load two random samples from dataset
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
loader_iter = iter(data_loader)
first_sample = next(loader_iter)
second_sample = next(loader_iter)
preprocessed_first_sample = self.preprocessor(first_sample)
preprocessed_second_sample = self.preprocessor(second_sample)
# ============================================================================
# Step 1: Generate previous chunk using policy_prev_chunk
# ============================================================================
# This policy is only used to generate the reference chunk and then freed
logging.info("=" * 80)
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
logging.info("=" * 80)
# Initialize policy 1
policy_prev_chunk_policy = self._init_policy(
name="policy_prev_chunk",
rtc_enabled=False,
rtc_debug=False,
)
with torch.no_grad():
prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
preprocessed_first_sample,
)[:, :25, :].squeeze(0)
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
# Destroy policy_prev_chunk to free memory for large models
self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
# ============================================================================
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
# ============================================================================
logging.info("=" * 80)
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
logging.info("=" * 80)
set_seed(self.cfg.seed)
# Initialize policy 2
policy_no_rtc_policy = self._init_policy(
name="policy_no_rtc",
rtc_enabled=False,
rtc_debug=True,
)
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
noise_clone = noise.clone()
policy_no_rtc_policy.rtc_processor.reset_tracker()
with torch.no_grad():
no_rtc_actions = policy_no_rtc_policy.predict_action_chunk(
preprocessed_second_sample,
noise=noise,
)
no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
logging.info(f" Generated no_rtc_actions shape: {no_rtc_actions.shape}")
# Destroy policy_no_rtc to free memory before loading policy_rtc
self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
# ============================================================================
# Step 3: Generate actions WITH RTC using policy_rtc
# ============================================================================
logging.info("=" * 80)
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
logging.info("=" * 80)
set_seed(self.cfg.seed)
# Initialize policy 3
policy_rtc_policy = self._init_policy(
name="policy_rtc",
rtc_enabled=True,
rtc_debug=True,
)
policy_rtc_policy.rtc_processor.reset_tracker()
with torch.no_grad():
rtc_actions = policy_rtc_policy.predict_action_chunk(
preprocessed_second_sample,
noise=noise_clone,
inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=self.cfg.rtc.execution_horizon,
)
rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
logging.info(f" Generated rtc_actions shape: {rtc_actions.shape}")
# Save num_steps before destroying policy (needed for plotting)
try:
num_steps = policy_rtc_policy.config.num_steps
except AttributeError:
num_steps = policy_rtc_policy.config.num_inference_steps
logging.warning(f" Using num_inference_steps: {num_steps} instead of num_steps")
# Destroy policy_rtc after final use
self._destroy_policy(policy_rtc_policy, "policy_rtc")
# Plot and save results
logging.info("=" * 80)
logging.info("Plotting results...")
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
# Plot final actions comparison
logging.info("=" * 80)
logging.info("Plotting final actions comparison...")
self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over)
logging.info("=" * 80)
logging.info("Evaluation completed successfully")
def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
"""Plot final action predictions comparison on a single chart.
Args:
rtc_actions: Final actions from RTC policy
no_rtc_actions: Final actions from non-RTC policy
prev_chunk_left_over: Previous chunk used as ground truth
"""
_check_matplotlib_available()
# Remove batch dimension if present
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
no_rtc_actions_plot = (
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
)
prev_chunk_plot = prev_chunk_left_over.cpu()
# Create figure with 6 subplots (one per action dimension)
fig, axes = plt.subplots(6, 1, figsize=(16, 12))
fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
# Plot each action dimension
for dim_idx, ax in enumerate(axes):
# Plot previous chunk (ground truth) in red
RTCDebugVisualizer.plot_waypoints(
[ax],
prev_chunk_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="red",
label="Previous Chunk (Ground Truth)",
linewidth=2.5,
alpha=0.8,
)
# Plot no-RTC actions in blue
RTCDebugVisualizer.plot_waypoints(
[ax],
no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="blue",
label="No RTC",
linewidth=2,
alpha=0.7,
)
# Plot RTC actions in green
RTCDebugVisualizer.plot_waypoints(
[ax],
rtc_actions_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="green",
label="RTC",
linewidth=2,
alpha=0.7,
)
# Add vertical lines for inference delay and execution horizon
inference_delay = self.cfg.inference_delay
execution_horizon = self.cfg.rtc.execution_horizon
if inference_delay > 0:
ax.axvline(
x=inference_delay - 1,
color="orange",
linestyle="--",
alpha=0.5,
label=f"Inference Delay ({inference_delay})",
)
if execution_horizon > 0:
ax.axvline(
x=execution_horizon,
color="purple",
linestyle="--",
alpha=0.5,
label=f"Execution Horizon ({execution_horizon})",
)
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
# Set x-axis ticks to show all integer values
max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
axes[-1].set_xlabel("Step", fontsize=10)
# Collect legend handles and labels from first subplot
handles, labels = axes[0].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right)
fig.legend(
unique_handles,
unique_labels,
loc="center right",
fontsize=9,
bbox_to_anchor=(1.0, 0.5),
framealpha=0.9,
)
# Save figure
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right
fig.savefig(output_path, dpi=150, bbox_inches="tight")
logging.info(f"Saved final actions comparison to {output_path}")
plt.close(fig)
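The duplicate-label filtering used when building the figure legend above can be exercised on its own. This is a minimal sketch; `dedup_legend_entries` is a hypothetical helper name, not part of the script:

```python
# Keep only the first handle seen for each label, preserving order.
# Mirrors the seen/unique_handles/unique_labels loop in the legend code above.
def dedup_legend_entries(handles, labels):
    seen = set()
    unique_handles, unique_labels = [], []
    for handle, label in zip(handles, labels):
        if label not in seen:
            seen.add(label)
            unique_handles.append(handle)
            unique_labels.append(label)
    return unique_handles, unique_labels

handles = ["h1", "h2", "h3", "h4"]
labels = ["RTC", "No RTC", "RTC", "Ground truth"]
h, l = dedup_legend_entries(handles, labels)
print(l)  # ['RTC', 'No RTC', 'Ground truth']
```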
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
_check_matplotlib_available()
# Create side-by-side figures for denoising visualization
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)")
fig_x1t, axs_x1t = self._create_figure(
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
)
self._plot_denoising_steps_from_tracker(
rtc_tracked_steps,
axs_xt[:, 1], # Right column for x_t
axs_vt[:, 1], # Right column for v_t
axs_corr[:, 1], # Right column for correction
axs_x1t[:, 1], # Right column for x1_t
num_steps,
add_labels=True, # Add labels for RTC (right column)
)
self._plot_denoising_steps_from_tracker(
no_rtc_tracked_steps,
axs_xt[:, 0], # Left column for x_t
axs_vt[:, 0], # Left column for v_t
axs_corr[:, 0], # Left column for correction
axs_x1t[:, 0], # Left column for x1_t
num_steps,
add_labels=False, # No labels for No RTC (left column)
)
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps)
# Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x1_t axes
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x_t axes (no labels for left column)
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
)
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
)
# Add legends outside the plot area for each figure
self._add_figure_legend(fig_xt, axs_xt)
self._add_figure_legend(fig_vt, axs_vt)
self._add_figure_legend(fig_corr, axs_corr)
self._add_figure_legend(fig_x1t, axs_x1t)
# Save denoising plots
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png"))
self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))
def _create_figure(self, title):
fig, axs = plt.subplots(6, 2, figsize=(24, 12))
fig.suptitle(title, fontsize=16)
for ax in axs[:, 0]:
ax.set_title("No RTC (N/A)" if ax == axs[0, 0] else "", fontsize=12)
for ax in axs[:, 1]:
ax.set_title("RTC" if ax == axs[0, 1] else "", fontsize=12)
return fig, axs
def _add_figure_legend(self, fig, axs):
"""Add a legend outside the plot area on the right side.
Args:
fig: Matplotlib figure to add legend to
axs: Array of axes to collect legend handles from
"""
# Collect all handles and labels from the first row of axes (right column)
handles, labels = axs[0, 1].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right, close to charts)
if unique_handles:
fig.legend(
unique_handles,
unique_labels,
loc="center left",
fontsize=8,
bbox_to_anchor=(0.87, 0.5),
framealpha=0.9,
ncol=1,
)
def _save_figure(self, fig, path):
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right
fig.savefig(path, dpi=150, bbox_inches="tight")
logging.info(f"Saved figure to {path}")
plt.close(fig)
def _plot_denoising_steps_from_tracker(
self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True
):
"""Plot denoising steps from tracker data.
Args:
tracked_steps: List of DebugStep objects containing debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
corr_axs: Matplotlib axes for correction plots (array of 6 axes)
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
num_steps: Total number of denoising steps for colormap
add_labels: Whether to add legend labels for the plots
"""
logging.info("=" * 80)
logging.info(f"Plotting {len(tracked_steps)} steps")
debug_steps = tracked_steps
if not debug_steps:
return
# Define colors for different denoise steps (using a colormap)
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
for step_idx, debug_step in enumerate(debug_steps):
color = colors[step_idx % len(colors)]
label = f"Step {step_idx}" if add_labels else None
# Plot x_t
if debug_step.x_t is not None:
RTCDebugVisualizer.plot_waypoints(
xt_axs, debug_step.x_t, start_from=0, color=color, label=label
)
# Plot v_t
if debug_step.v_t is not None:
RTCDebugVisualizer.plot_waypoints(
vt_axs, debug_step.v_t, start_from=0, color=color, label=label
)
# Plot correction on separate axes
if debug_step.correction is not None:
RTCDebugVisualizer.plot_waypoints(
corr_axs,
debug_step.correction,
start_from=0,
color=color,
label=label,
)
# Plot x1_t (predicted state)
if x1t_axs is not None and debug_step.x1_t is not None:
x1t_label = f"x1_t Step {step_idx}" if add_labels else None
RTCDebugVisualizer.plot_waypoints(
x1t_axs,
debug_step.x1_t,
start_from=0,
color=color,
label=x1t_label,
)
# Plot error in orange dashed
if x1t_axs is not None and debug_step.err is not None:
error_chunk = (
debug_step.err[0].cpu().numpy()
if len(debug_step.err.shape) == 3
else debug_step.err.cpu().numpy()
)
num_dims = min(error_chunk.shape[-1], 6)
error_label = f"error Step {step_idx}" if add_labels else None
for j in range(num_dims):
x1t_axs[j].plot(
np.arange(0, error_chunk.shape[0]),
error_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
label=error_label,
)
# Recalculate axis limits after plotting to ensure proper scaling
self._rescale_axes(xt_axs)
self._rescale_axes(vt_axs)
self._rescale_axes(corr_axs)
self._rescale_axes(x1t_axs)
def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps):
"""Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison.
Args:
no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column)
num_steps: Total number of denoising steps for colormap
"""
debug_steps = no_rtc_tracked_steps
if not debug_steps:
return
# Plot only the final x_t step as orange dashed line
final_step = debug_steps[-1]
logging.info("Plotting final no-RTC x_t step as orange dashed reference")
if final_step.x_t is not None:
x_t_chunk = (
final_step.x_t[0].cpu().numpy()
if len(final_step.x_t.shape) == 3
else final_step.x_t.cpu().numpy()
)
num_dims = min(x_t_chunk.shape[-1], 6)
for j in range(num_dims):
xt_axs[j].plot(
np.arange(0, x_t_chunk.shape[0]),
x_t_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
linewidth=2,
label="No RTC (final)" if j == 0 else "",
)
def _rescale_axes(self, axes):
"""Rescale axes to show all data with proper margins.
Args:
axes: Array of matplotlib axes to rescale
"""
for ax in axes:
ax.relim()
ax.autoscale_view()
# Add 10% margin to y-axis for better visualization
ylim = ax.get_ylim()
y_range = ylim[1] - ylim[0]
if y_range > 0: # Avoid division by zero
margin = y_range * 0.1
ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
# Set x-axis ticks to show all integer values
xlim = ax.get_xlim()
max_len = int(xlim[1]) + 1
if max_len > 0:
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
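The tick-spacing rule in `_rescale_axes` (and in the comparison plot) targets roughly 20 ticks regardless of chunk length. A standalone sketch, with `tick_positions` as a hypothetical helper:

```python
# Same rule as ax.set_xticks(range(0, max_len, max(1, max_len // 20))):
# for short chunks every integer gets a tick, for long chunks ~20 ticks total.
def tick_positions(max_len: int) -> list[int]:
    step = max(1, max_len // 20)
    return list(range(0, max_len, step))

print(tick_positions(5))         # [0, 1, 2, 3, 4]
print(len(tick_positions(100)))  # 20
```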
@parser.wrap()
def main(cfg: RTCEvalConfig):
"""Main entry point for RTC evaluation."""
# Set random seed for reproducibility
set_seed(cfg.seed)
init_logging()
logging.info("=" * 80)
logging.info("RTC Dataset Evaluation")
logging.info(f"Config: {cfg}")
logging.info("=" * 80)
evaluator = RTCEvaluator(cfg)
evaluator.run_evaluation()
if __name__ == "__main__":
main()
+549
@@ -0,0 +1,549 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation
For simulation environments, see eval_with_simulation.py
Usage:
# Run with a real robot, RTC enabled
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run with a real robot, RTC disabled
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run with a real robot using the pi0.5 policy, RTC enabled
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
so100_follower,
so101_follower,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Threshold on the action queue size below which new actions are requested.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse the CLI args again here to get the pretrained path if one was provided.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
obs = robot.get_observation()
# Apply robot observation processor
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
preprocessed_obs = preprocessor(obs_with_policy_features)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions is too small; it should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
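The delay bookkeeping inside `get_actions` converts the worst observed inference latency into a whole number of control steps at the demo's `fps`. A minimal sketch of that arithmetic, with `inference_delay_steps` as a hypothetical helper:

```python
import math

# Same computation as in get_actions: time_per_chunk = 1 / fps, and the
# delay is the latency rounded UP to whole control steps (ceil), so the
# merged chunk never assumes actions that could not have executed yet.
def inference_delay_steps(latency_s: float, fps: float) -> int:
    time_per_chunk = 1.0 / fps
    return math.ceil(latency_s / time_per_chunk)

print(inference_delay_steps(0.35, 10.0))  # 4 (0.35 s at 10 Hz spans 4 steps)
```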
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot.
Args:
robot: The robot instance
robot_action_processor: Processor for raw robot actions
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_count = 0
action_interval = 1.0 / cfg.fps
while not shutdown_event.is_set():
start_time = time.perf_counter()
# Try to get an action from the queue with timeout
action = action_queue.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
time.sleep(max(0, (action_interval - dt_s) - 0.001))
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
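The loop pacing in `actor_control` sleeps for the remainder of the control period, minus a 1 ms margin, and never a negative duration. A standalone sketch (`pacing_sleep` is a hypothetical helper name):

```python
# Mirrors time.sleep(max(0, (action_interval - dt_s) - 0.001)) above:
# if sending the action took longer than the period, skip sleeping entirely.
def pacing_sleep(action_interval: float, dt_s: float) -> float:
    return max(0.0, (action_interval - dt_s) - 0.001)

print(round(pacing_sleep(0.1, 0.03), 3))  # 0.069
print(pacing_sleep(0.1, 0.25))            # 0.0
```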
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type in ("pi0", "pi05"):
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
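The keyword assembly in `_apply_torch_compile` can be isolated for clarity: CUDA graphs are opted out via the inductor `options` dict only when requested. A sketch with `build_compile_kwargs` as a hypothetical helper (plain dict-building, no torch required):

```python
# Builds the kwargs passed to torch.compile above. Disabling
# "triton.cudagraphs" avoids tensor aliasing from the in-place
# denoising update (x_t += dt * v_t).
def build_compile_kwargs(backend: str, mode: str, disable_cudagraphs: bool) -> dict:
    kwargs = {"backend": backend, "mode": mode}
    if disable_cudagraphs:
        kwargs["options"] = {"triton.cudagraphs": False}
    return kwargs

print(build_compile_kwargs("inductor", "default", True))
```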
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with draccus configuration."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type in ("pi0", "pi05"):
config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init the RTC processor: by default it is not created
# when RTC is disabled in the config
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")
+1 -1
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.1"
version = "0.4.2"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
+7
@@ -43,3 +43,10 @@ class NormalizationMode(str, Enum):
class PolicyFeature:
type: FeatureType
shape: tuple[int, ...]
class RTCAttentionSchedule(str, Enum):
ZEROS = "ZEROS"
ONES = "ONES"
LINEAR = "LINEAR"
EXP = "EXP"
+14 -18
@@ -39,6 +39,7 @@ from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -962,28 +963,23 @@ def _copy_data_with_feature_changes(
remove_features: list[str] | None = None,
) -> None:
"""Copy data while adding or removing features."""
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.meta.root)
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
# Map file paths to episode indices to extract chunk/file indices
file_to_episodes: dict[Path, set[int]] = {}
for ep_idx in range(dataset.meta.total_episodes):
file_path = dataset.meta.get_data_file_path(ep_idx)
if file_path not in file_to_episodes:
file_to_episodes[file_path] = set()
file_to_episodes[file_path].add(ep_idx)
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
frame_idx = 0
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
# Get chunk_idx and file_idx from the source file's first episode
episodes_in_file = file_to_episodes[src_path]
first_ep_idx = min(episodes_in_file)
src_ep = dataset.meta.episodes[first_ep_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
relative_path = src_path.relative_to(dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
if remove_features:
df = df.drop(columns=remove_features, errors="ignore")
@@ -1009,7 +1005,7 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_slice
frame_idx = end_idx
# Write using the preserved chunk_idx and file_idx from source
# Write using the same chunk/file structure as source
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
+42 -9
@@ -712,6 +712,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
# Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded
# Build a mapping: absolute_index -> relative_index_in_filtered_dataset
self._absolute_to_relative_idx = None
if self.episodes is not None:
self._absolute_to_relative_idx = {
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
}
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -830,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -847,10 +856,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Determine requested episodes
if self.episodes is None:
# Requesting all episodes - check if we have all episodes from metadata
requested_episodes = set(range(self.meta.total_episodes))
else:
# Requesting specific episodes
requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data
@@ -932,7 +939,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_timestamps = {}
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
timestamps = self.hf_dataset[relative_indices]["timestamp"]
else:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
@@ -940,11 +951,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset[q_idx][key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
"""
Query dataset for indices across keys, skipping video keys.
Tries column-first [key][indices] for speed, falls back to row-first.
Args:
query_indices: Dict mapping keys to index lists to retrieve
Returns:
Dict with stacked tensors of queried data (video keys excluded)
"""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self.meta.video_keys:
continue
# Map absolute indices to relative indices if needed
relative_indices = (
q_idx
if self._absolute_to_relative_idx is None
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
)
try:
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
return result
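The try/except above prefers column-first access (`dataset[key][indices]`) and falls back to row-first (`dataset[indices][key]`). A toy illustration with a hypothetical store that only supports row indexing:

```python
# Illustrative stand-in for the HF dataset: this one only supports row-first
# indexing, so the column-first attempt fails and the fallback is exercised.
class RowStore:
    def __init__(self, rows):
        self.rows = rows

    def __getitem__(self, idx):
        if isinstance(idx, list):  # row-first: gather rows, keep all columns
            return {k: [self.rows[i][k] for i in idx] for k in self.rows[0]}
        raise KeyError(idx)  # no column-first access

ds = RowStore([{"timestamp": 0.0}, {"timestamp": 0.033}, {"timestamp": 0.066}])
indices = [0, 2]
try:
    values = ds["timestamp"][indices]  # column-first attempt
except (KeyError, TypeError, IndexError):
    values = ds[indices]["timestamp"]  # row-first fallback
print(values)
```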
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -1483,6 +1515,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj._absolute_to_relative_idx = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.writer = None
obj.latest_episode = None
+18 -4
@@ -28,6 +28,7 @@ import numpy as np
import packaging.version
import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
@@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to memory-mapped pyarrow tables in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
@@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
return datasets
# When no filtering is needed, Dataset.from_parquet uses memory-mapped loading for efficiency;
# the pyarrow filtering path below materializes the filtered table in memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
+57 -9
@@ -21,7 +21,22 @@ import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import (
ACTION,
LIBERO_KEY_EEF_MAT,
LIBERO_KEY_EEF_POS,
LIBERO_KEY_EEF_QUAT,
LIBERO_KEY_GRIPPER_QPOS,
LIBERO_KEY_GRIPPER_QVEL,
LIBERO_KEY_JOINTS_POS,
LIBERO_KEY_JOINTS_VEL,
LIBERO_KEY_PIXELS_AGENTVIEW,
LIBERO_KEY_PIXELS_EYE_IN_HAND,
OBS_ENV_STATE,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
)
@dataclass
@@ -246,28 +261,61 @@ class LiberoEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features["pixels/agentview_image"] = PolicyFeature(
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(3,),
)
self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(4,),
)
self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(3, 3),
)
self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
+39
@@ -14,12 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -33,6 +37,41 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
This function creates processor pipelines that transform raw environment
observations and actions. By default, it returns identity processors that do nothing.
For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
"""
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
return preprocessor, postprocessor
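The function above returns identity pipelines by default and only appends a step for LIBERO. A minimal stand-in for the pipeline idea (not the lerobot `PolicyProcessorPipeline` API, just the pattern):

```python
# Minimal pipeline sketch: steps are applied in order; an empty step list
# is the identity. Names are illustrative.
class MiniPipeline:
    def __init__(self, steps=None):
        self.steps = list(steps or [])

    def __call__(self, obs):
        for step in self.steps:
            obs = step(obs)
        return obs

identity = MiniPipeline()                                   # most environments
libero = MiniPipeline([lambda o: {**o, "libero": True}])    # env-specific step

print(identity({"x": 1}), libero({"x": 1}))
```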
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
+69 -21
@@ -28,7 +28,6 @@ import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.transform_utils import quat2axisangle
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
@@ -175,11 +174,36 @@ class LiberoEnv(gym.Env):
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"agent_pos": spaces.Box(
low=AGENT_POS_LOW,
high=AGENT_POS_HIGH,
shape=(OBS_STATE_DIM,),
dtype=np.float64,
"robot_state": spaces.Dict(
{
"eef": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
"quat": spaces.Box(
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
),
"mat": spaces.Box(
low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
),
}
),
"gripper": spaces.Dict(
{
"qpos": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
"qvel": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
}
),
"joints": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
}
),
}
),
}
)
@@ -191,6 +215,7 @@ class LiberoEnv(gym.Env):
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
image = image[::-1, ::-1] # flip both H and W for visualization
return image
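The `image[::-1, ::-1]` idiom used here (and in `_format_raw_obs` below) flips both spatial axes of an H x W x C array, which is equivalent to a 180-degree rotation while leaving the channel axis untouched. A quick check:

```python
import numpy as np

# Tiny 2x2 "RGB" array; [::-1, ::-1] flips H and W but not channels.
img = np.arange(12).reshape(2, 2, 3)
flipped = img[::-1, ::-1]

# Flipping both spatial axes equals rotating by 180 degrees.
assert np.array_equal(flipped, np.rot90(img, 2))
print(flipped[0, 0].tolist())
```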
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
@@ -212,23 +237,48 @@ class LiberoEnv(gym.Env):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
state = np.concatenate(
(
raw_obs["robot0_eef_pos"],
quat2axisangle(raw_obs["robot0_eef_quat"]),
raw_obs["robot0_gripper_qpos"],
)
)
agent_pos = state
eef_pos = raw_obs.get("robot0_eef_pos")
eef_quat = raw_obs.get("robot0_eef_quat")
# rotation matrix from controller
eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
gripper_qpos = raw_obs.get("robot0_gripper_qpos")
gripper_qvel = raw_obs.get("robot0_gripper_qvel")
joint_pos = raw_obs.get("robot0_joint_pos")
joint_vel = raw_obs.get("robot0_joint_vel")
obs = {
"pixels": images,
"robot_state": {
"eef": {
"pos": eef_pos, # (3,)
"quat": eef_quat, # (4,)
"mat": eef_mat, # (3, 3)
},
"gripper": {
"qpos": gripper_qpos, # (2,)
"qvel": gripper_qvel, # (2,)
},
"joints": {
"pos": joint_pos, # (7,)
"vel": joint_vel, # (7,)
},
},
}
if self.obs_type == "pixels":
return {"pixels": images.copy()}
if self.obs_type == "pixels_agent_pos":
return {
"pixels": images.copy(),
"agent_pos": agent_pos,
}
# Validate required fields are present
if eef_pos is None or eef_quat is None or gripper_qpos is None:
raise ValueError(
f"Missing required robot state fields in raw observation. "
f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
f"gripper_qpos={gripper_qpos is not None}"
)
return obs
raise NotImplementedError(
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
@@ -355,12 +405,10 @@ def create_libero_envs(
print(f"Restricting to task_ids={task_ids_filter}")
out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names:
suite = _get_suite(suite_name)
total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter)
if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
+20 -6
@@ -29,10 +29,22 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape
def _convert_nested_dict(d):
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _convert_nested_dict(v)
elif isinstance(v, np.ndarray):
result[k] = torch.from_numpy(v)
else:
result[k] = v
return result
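`_convert_nested_dict` recursively walks a nested observation dict and converts only `np.ndarray` leaves to tensors. The same recursion pattern, shown as a self-contained analogue that converts list leaves to tuples (pure Python, so it runs without torch):

```python
# Stand-alone analogue of the recursive conversion above: dicts are recursed
# into, leaves of the target type are converted, everything else passes through.
def convert_nested(d, leaf_type=list, convert=tuple):
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out[k] = convert_nested(v, leaf_type, convert)
        elif isinstance(v, leaf_type):
            out[k] = convert(v)
        else:
            out[k] = v
    return out

obs = {"eef": {"pos": [0.1, 0.2, 0.3]}, "gripper": {"qpos": [0.0, 0.0]}, "step": 7}
print(convert_nested(obs))
```

The structure of the dict is preserved exactly; only the leaves change type, which is what lets the nested `robot_state` observation pass through unchanged in shape.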
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
@@ -78,12 +90,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[OBS_ENV_STATE] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "agent_pos" in observations:
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "robot_state" in observations:
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
return return_observations
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -47,6 +48,9 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
+83 -15
@@ -19,11 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI0 PyTorch model."""
def __init__(self, config: PI0Config):
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
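The `current_timestep=expanded_time` default argument in `denoise_step_partial_call` exists because closures capture loop variables by reference; a default argument freezes the value at definition time. The pitfall in miniature:

```python
# Closures capture loop variables by reference: every lambda below sees the
# final value of t. Binding via a default argument freezes each iteration's value.
late = [lambda x: x + t for t in range(3)]        # all see t == 2 after the loop
fixed = [lambda x, t=t: x + t for t in range(3)]  # each freezes its own t

print([f(0) for f in late])
print([f(0) for f in fixed])
```

This is the same late-binding issue flagged by ruff's B023, which the diff comment references.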
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI0 model
self.model = PI0Pytorch(config)
self.init_rtc_processor()
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create the processor whenever a config is provided; even when RTC is
# disabled, the processor can still track denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
state = self.prepare_state(batch)
# Sample actions using the model
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
# Sample actions using the model (pass through RTC kwargs)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
@PreTrainedConfig.register_subclass("pi05")
@@ -46,6 +47,9 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
+83 -14
@@ -19,11 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model."""
def __init__(self, config: PI05Config):
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
def sample_actions(
self,
images,
img_masks,
tokens,
masks,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
num_steps = self.config.num_inference_steps
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI05 model
self.model = PI05Pytorch(config)
self.init_rtc_processor()
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create the processor whenever a config is provided; even when RTC is
# disabled, the processor can still track denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks)
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
+38
@@ -0,0 +1,38 @@
# Real-Time Chunking (RTC)
This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.
**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).
---
## Citation
If you use Real-Time Chunking in your work, please cite:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{black2025realtimeexecutionactionchunking,
title={Real-Time Execution of Action Chunking Flow Policies},
author={Kevin Black and Manuel Y. Galliker and Sergey Levine},
year={2025},
eprint={2506.07339},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2506.07339},
}
```
---
## License
This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.
+219
@@ -0,0 +1,219 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action queue management for Real-Time Chunking (RTC).
This module provides ActionQueue, a thread-safe queue for managing action chunks
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
handling action merging and leftover tracking.
"""
import logging
from threading import Lock
import torch
from torch import Tensor
from lerobot.policies.rtc.configuration_rtc import RTCConfig
logger = logging.getLogger(__name__)
class ActionQueue:
"""Thread-safe queue for managing action chunks in real-time control.
This queue handles two types of action sequences:
- Original actions: Used for RTC to compute leftovers from previous chunks
- Processed actions: Post-processed actions ready for robot execution
The queue operates in two modes:
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
Args:
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
Attributes:
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
last_index (int): Current consumption index in the queue.
"""
def __init__(self, cfg: RTCConfig):
"""Initialize the action queue.
Args:
cfg: RTC configuration controlling queue behavior.
"""
self.queue = None # Processed actions for robot rollout
self.original_queue = None # Original actions for RTC
self.lock = Lock()
self.last_index = 0
self.cfg = cfg
def get(self) -> Tensor | None:
"""Get the next action from the queue.
Returns:
Tensor | None: The next action (action_dim,) or None if queue is empty.
Returns a clone to prevent external modifications.
"""
with self.lock:
if self.queue is None or self.last_index >= len(self.queue):
return None
action = self.queue[self.last_index]
self.last_index += 1
return action.clone()
def qsize(self) -> int:
"""Get the number of remaining actions in the queue.
Returns:
int: Number of unconsumed actions.
"""
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
def empty(self) -> bool:
"""Check if the queue is empty.
Returns:
bool: True if no actions remain, False otherwise.
"""
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index <= 0
def get_action_index(self) -> int:
"""Get the current action consumption index.
Returns:
int: Index of the next action to be consumed.
"""
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over.
These are the unconsumed actions from the current chunk, which will be
used by RTC to compute corrections for the next chunk.
Returns:
Tensor | None: Remaining original actions (remaining_steps, action_dim),
or None if no original queue exists.
"""
with self.lock:
if self.original_queue is None:
return None
return self.original_queue[self.last_index :]
def merge(
self,
original_actions: Tensor,
processed_actions: Tensor,
real_delay: int,
action_index_before_inference: int | None = 0,
):
"""Merge new actions into the queue.
This method operates differently based on RTC mode:
- RTC enabled: Replaces the queue, accounting for inference delay
- RTC disabled: Appends to the queue, maintaining continuity
Args:
original_actions: Unprocessed actions from policy (time_steps, action_dim).
processed_actions: Post-processed actions for robot (time_steps, action_dim).
real_delay: Number of time steps of inference delay.
action_index_before_inference: Index before inference started, for validation.
"""
with self.lock:
self._check_delays(real_delay, action_index_before_inference)
if self.cfg.enabled:
self._replace_actions_queue(original_actions, processed_actions, real_delay)
return
self._append_actions_queue(original_actions, processed_actions)
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
"""Replace the queue with new actions (RTC mode).
Discards the first `real_delay` actions since they correspond to the time
spent during inference, when the robot was executing previous actions.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
real_delay: Number of time steps to skip due to inference delay.
"""
self.original_queue = original_actions[real_delay:].clone()
self.queue = processed_actions[real_delay:].clone()
logger.debug(f"original_actions shape: {self.original_queue.shape}")
logger.debug(f"processed_actions shape: {self.queue.shape}")
logger.debug(f"real_delay: {real_delay}")
self.last_index = 0
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
"""Append new actions to the queue (non-RTC mode).
Removes already-consumed actions and appends new ones, maintaining
queue continuity without replacement.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
"""
if self.queue is None:
self.original_queue = original_actions.clone()
self.queue = processed_actions.clone()
return
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
self.original_queue = self.original_queue[self.last_index :]
self.queue = torch.cat([self.queue, processed_actions.clone()])
self.queue = self.queue[self.last_index :]
self.last_index = 0
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
"""Validate that computed delays match expectations.
Compares the delay computed from inference latency with the actual
number of actions consumed during inference.
Args:
real_delay: Delay computed from inference latency.
action_index_before_inference: Action index when inference started.
"""
if action_index_before_inference is None:
return
indexes_diff = self.last_index - action_index_before_inference
# The action-index difference (delay measured from the queue) should match
# the delay computed from inference latency; warn if they diverge
if indexes_diff != real_delay:
logger.warning(
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
)
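The append-mode merge above can be illustrated with a minimal, standalone sketch. This is a simplified stand-in for illustration only (`MiniActionQueue` is a hypothetical name): plain Python lists replace tensors, and the lock and RTC branch are omitted.

```python
class MiniActionQueue:
    """Lock-free sketch of ActionQueue's non-RTC append mode."""

    def __init__(self):
        self.queue = None
        self.last_index = 0

    def get(self):
        if self.queue is None or self.last_index >= len(self.queue):
            return None
        action = self.queue[self.last_index]
        self.last_index += 1
        return action

    def append(self, actions):
        if self.queue is None:
            self.queue = list(actions)
            return
        # Drop the already-consumed prefix, then keep the new chunk contiguous.
        self.queue = (self.queue + list(actions))[self.last_index:]
        self.last_index = 0


q = MiniActionQueue()
q.append([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])  # chunk of 3 actions, action_dim=2
first = q.get()                                  # consumes action 0
q.append([[6.0, 7.0], [8.0, 9.0]])
remaining = len(q.queue) - q.last_index          # 2 leftover + 2 new actions remain
```

After the second `append`, the consumed prefix has been trimmed, so the queue holds the two leftover actions followed by the two new ones, preserving continuity.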
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
Based on:
- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
"""
from dataclasses import dataclass
from lerobot.configs.types import RTCAttentionSchedule
@dataclass
class RTCConfig:
"""Configuration for Real Time Chunking (RTC) inference.
RTC improves real-time inference by treating chunk generation as an inpainting problem,
strategically handling overlapping timesteps between action chunks using prefix attention.
"""
# Infrastructure
enabled: bool = False
# Core RTC settings
# TODO: change the default schedule to EXP
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
max_guidance_weight: float = 10.0
execution_horizon: int = 10
# Debug settings
debug: bool = False
debug_maxlen: int = 100
def __post_init__(self):
"""Validate RTC configuration parameters."""
if self.max_guidance_weight <= 0:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
@@ -0,0 +1,233 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Debug information handler for Real-Time Chunking (RTC)."""
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
@dataclass
class DebugStep:
"""Container for debug information from a single denoising step.
Attributes:
step_idx (int): Step index/counter.
x_t (Tensor | None): Current latent/state tensor.
v_t (Tensor | None): Velocity from denoiser.
x1_t (Tensor | None): Denoised prediction (x_t - time * v_t).
correction (Tensor | None): Correction gradient tensor.
err (Tensor | None): Weighted error term.
weights (Tensor | None): Prefix attention weights.
guidance_weight (float | Tensor | None): Applied guidance weight.
time (float | Tensor | None): Time parameter.
inference_delay (int | None): Inference delay parameter.
execution_horizon (int | None): Execution horizon parameter.
metadata (dict[str, Any]): Additional metadata.
"""
step_idx: int = 0
x_t: Tensor | None = None
v_t: Tensor | None = None
x1_t: Tensor | None = None
correction: Tensor | None = None
err: Tensor | None = None
weights: Tensor | None = None
guidance_weight: float | Tensor | None = None
time: float | Tensor | None = None
inference_delay: int | None = None
execution_horizon: int | None = None
metadata: dict[str, Any] = field(default_factory=dict)
def to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
"""Convert debug step to dictionary.
Args:
include_tensors (bool): If True, include tensor values. If False, only include
tensor statistics (shape, mean, std, min, max).
Returns:
Dictionary representation of the debug step.
"""
result = {
"step_idx": self.step_idx,
"guidance_weight": (
self.guidance_weight.item()
if isinstance(self.guidance_weight, Tensor)
else self.guidance_weight
),
"time": self.time.item() if isinstance(self.time, Tensor) else self.time,
"inference_delay": self.inference_delay,
"execution_horizon": self.execution_horizon,
"metadata": self.metadata.copy(),
}
# Add tensor information
tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"]
for field_name in tensor_fields:
tensor = getattr(self, field_name)
if tensor is not None:
if include_tensors:
result[field_name] = tensor.detach().cpu()
else:
result[f"{field_name}_stats"] = {
"shape": tuple(tensor.shape),
"mean": tensor.mean().item(),
"std": tensor.std().item(),
"min": tensor.min().item(),
"max": tensor.max().item(),
}
return result
class Tracker:
"""Collects and manages debug information for RTC processing.
This tracker stores debug information from recent denoising steps in a dictionary,
using time as the key for efficient lookups and updates.
Args:
enabled (bool): Whether debug collection is enabled.
maxlen (int | None): Optional sliding window size. If provided, only the
most recent ``maxlen`` debug steps are kept. If ``None``, keeps all.
"""
def __init__(self, enabled: bool = False, maxlen: int | None = 100):
self.enabled = enabled
self._steps = {} if enabled else None # Dictionary with time as key
self._maxlen = maxlen
self._step_counter = 0
def reset(self) -> None:
"""Clear all recorded debug information."""
if self.enabled and self._steps is not None:
self._steps.clear()
self._step_counter = 0
@torch._dynamo.disable
def track(
self,
time: float | Tensor,
x_t: Tensor | None = None,
v_t: Tensor | None = None,
x1_t: Tensor | None = None,
correction: Tensor | None = None,
err: Tensor | None = None,
weights: Tensor | None = None,
guidance_weight: float | Tensor | None = None,
inference_delay: int | None = None,
execution_horizon: int | None = None,
**metadata,
) -> None:
"""Track debug information for a denoising step at a given time.
If a step with the given time already exists, it will be updated with the new data.
Otherwise, a new step will be created. Only non-None fields are updated/set.
Note: This method is excluded from torch.compile to avoid graph breaks from
operations like .item() which are incompatible with compiled graphs.
Args:
time (float | Tensor): Time parameter - used as the key to identify the step.
x_t (Tensor | None): Current latent/state tensor.
v_t (Tensor | None): Velocity from denoiser.
x1_t (Tensor | None): Denoised prediction.
correction (Tensor | None): Correction gradient tensor.
err (Tensor | None): Weighted error term.
weights (Tensor | None): Prefix attention weights.
guidance_weight (float | Tensor | None): Applied guidance weight.
inference_delay (int | None): Inference delay parameter.
execution_horizon (int | None): Execution horizon parameter.
**metadata: Additional metadata to store.
"""
if not self.enabled:
return
# Convert time to float and round to avoid float precision issues
time_value = time.item() if isinstance(time, Tensor) else time
time_key = round(time_value, 6) # Use rounded time as dictionary key
# Check if step with this time already exists
if time_key in self._steps:
# Update existing step with non-None fields
existing_step = self._steps[time_key]
if x_t is not None:
existing_step.x_t = x_t.detach().clone()
if v_t is not None:
existing_step.v_t = v_t.detach().clone()
if x1_t is not None:
existing_step.x1_t = x1_t.detach().clone()
if correction is not None:
existing_step.correction = correction.detach().clone()
if err is not None:
existing_step.err = err.detach().clone()
if weights is not None:
existing_step.weights = weights.detach().clone()
if guidance_weight is not None:
existing_step.guidance_weight = guidance_weight
if inference_delay is not None:
existing_step.inference_delay = inference_delay
if execution_horizon is not None:
existing_step.execution_horizon = execution_horizon
if metadata:
existing_step.metadata.update(metadata)
else:
# Create new step
step = DebugStep(
step_idx=self._step_counter,
x_t=x_t.detach().clone() if x_t is not None else None,
v_t=v_t.detach().clone() if v_t is not None else None,
x1_t=x1_t.detach().clone() if x1_t is not None else None,
correction=correction.detach().clone() if correction is not None else None,
err=err.detach().clone() if err is not None else None,
weights=weights.detach().clone() if weights is not None else None,
guidance_weight=guidance_weight,
time=time_value,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
metadata=metadata,
)
# Add to dictionary
self._steps[time_key] = step
self._step_counter += 1
# Enforce maxlen if set
if self._maxlen is not None and len(self._steps) > self._maxlen:
# Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order)
oldest_key = next(iter(self._steps))
del self._steps[oldest_key]
def get_all_steps(self) -> list[DebugStep]:
"""Get all recorded debug steps.
Returns:
List of all DebugStep objects (may be empty if disabled).
"""
if not self.enabled or self._steps is None:
return []
return list(self._steps.values())
def __len__(self) -> int:
"""Return the number of recorded debug steps."""
if not self.enabled or self._steps is None:
return 0
return len(self._steps)
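The time-keyed update-or-insert logic and the sliding-window eviction above can be sketched without tensors. This is a simplified stand-in for `Tracker.track` (plain dicts instead of `DebugStep`, relying on dict insertion order for eviction):

```python
# Minimal sketch of the time-keyed sliding window used by Tracker.
steps: dict[float, dict] = {}
maxlen = 3


def track(time_value: float, **fields) -> None:
    key = round(time_value, 6)  # rounding avoids float-precision key misses
    if key in steps:
        steps[key].update(fields)      # same time: merge into the existing step
        return
    steps[key] = dict(fields)
    if len(steps) > maxlen:
        del steps[next(iter(steps))]   # evict the oldest insertion (Python 3.7+ order)


for t in (0.9, 0.7, 0.5, 0.3):
    track(t, guidance_weight=1.0 / t)
track(0.5, err=0.01)  # existing key: updated in place, not duplicated

keys = list(steps.keys())
```

Inserting the fourth step evicts the oldest (`0.9`), and re-tracking `0.5` merges the new field into the existing entry instead of creating a new one.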
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Visualization utilities for RTC debug information."""
import torch
class RTCDebugVisualizer:
"""Visualizer for RTC debug information.
This class provides methods to visualize debug information collected by the Tracker,
including corrections, errors, weights, and guidance weights over denoising steps.
"""
@staticmethod
def plot_waypoints(
axes,
tensor,
start_from: int = 0,
color: str = "blue",
label: str = "",
alpha: float = 0.7,
linewidth: float = 2,
marker: str | None = None,
markersize: int = 4,
):
"""Plot trajectories across multiple dimensions.
This function plots a tensor's values across time for multiple dimensions,
with each dimension plotted on a separate axis.
Args:
axes: Array of matplotlib axes (one for each dimension).
tensor: The tensor to plot (can be torch.Tensor or numpy array).
Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims).
start_from: Starting index for the x-axis.
color: Color for the plot lines.
label: Label for the plot legend.
alpha: Transparency level for the plot.
linewidth: Width of the plot lines.
marker: Marker style for data points (e.g., 'o', 's', '^').
markersize: Size of the markers.
"""
import numpy as np
# Handle None tensor
if tensor is None:
return
# Convert tensor to numpy if needed
tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
# Handle different tensor shapes
if tensor_np.ndim == 3:
# If batch dimension present, take first batch
tensor_np = tensor_np[0]
elif tensor_np.ndim == 1:
# If 1D, reshape to (time_steps, 1)
tensor_np = tensor_np.reshape(-1, 1)
# Get dimensions
time_steps, num_dims = tensor_np.shape
# Create x-axis indices
x_indices = np.arange(start_from, start_from + time_steps)
# Plot each dimension on its corresponding axis
num_axes = len(axes) if hasattr(axes, "__len__") else 1
for dim_idx in range(min(num_dims, num_axes)):
ax = axes[dim_idx] if hasattr(axes, "__len__") else axes
# Plot the trajectory
if marker:
ax.plot(
x_indices,
tensor_np[:, dim_idx],
color=color,
label=label if dim_idx == 0 else "", # Only show label once
alpha=alpha,
linewidth=linewidth,
marker=marker,
markersize=markersize,
)
else:
ax.plot(
x_indices,
tensor_np[:, dim_idx],
color=color,
label=label if dim_idx == 0 else "", # Only show label once
alpha=alpha,
linewidth=linewidth,
)
# Add grid and labels if not already present
if not ax.xaxis.get_label().get_text():
ax.set_xlabel("Step", fontsize=10)
if not ax.yaxis.get_label().get_text():
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
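The shape normalization in `plot_waypoints` (a batched input plots only the first batch element; a 1D series becomes a single dimension) can be sketched in isolation. The helper name `normalize_for_plotting` is hypothetical:

```python
import numpy as np


def normalize_for_plotting(arr: np.ndarray) -> np.ndarray:
    """Mirror plot_waypoints' shape handling: always return (time_steps, num_dims)."""
    if arr.ndim == 3:
        arr = arr[0]              # batched input: only the first batch element is plotted
    elif arr.ndim == 1:
        arr = arr.reshape(-1, 1)  # a 1D series is treated as a single dimension
    return arr


batched = normalize_for_plotting(np.zeros((4, 50, 7)))  # (batch, time, dims) -> (time, dims)
flat = normalize_for_plotting(np.zeros(50))             # (time,) -> (time, 1)
```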
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Latency tracking utilities for Real-Time Chunking (RTC)."""
from collections import deque
import numpy as np
class LatencyTracker:
"""Tracks recent latencies and provides max/percentile queries.
Args:
maxlen (int | None): Optional sliding window size. If provided, only the
most recent ``maxlen`` latencies are kept. If ``None``, keeps all.
"""
def __init__(self, maxlen: int | None = 100):
self._values = deque(maxlen=maxlen)
self.reset()
def reset(self) -> None:
"""Clear all recorded latencies."""
self._values.clear()
self.max_latency = 0.0
def add(self, latency: float) -> None:
"""Add a latency sample (seconds)."""
# Ensure numeric and non-negative
val = float(latency)
if val < 0:
return
self._values.append(val)
self.max_latency = max(self.max_latency, val)
def __len__(self) -> int:
return len(self._values)
def max(self) -> float:
"""Return the maximum latency recorded so far (0.0 if no samples were added)."""
return self.max_latency
def percentile(self, q: float) -> float:
"""Return the q-quantile (q in [0, 1]) of recorded latencies, or 0.0 if empty."""
if not self._values:
return 0.0
q = float(q)
if q <= 0.0:
return min(self._values)
if q >= 1.0:
return self.max_latency
vals = np.array(list(self._values), dtype=np.float32)
return float(np.quantile(vals, q))
def p95(self) -> float:
"""Return the 95th percentile latency (0.0 if no samples were recorded)."""
return self.percentile(0.95)
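The window-quantile behavior can be sketched directly with a deque and `np.quantile`. One simplification to note: the real class short-circuits `q >= 1.0` to the all-time maximum, which may exceed anything still in the window, whereas this sketch computes a pure window quantile.

```python
from collections import deque

import numpy as np

values = deque(maxlen=5)          # only the 5 most recent samples are kept
for latency in (0.10, 0.20, 0.90, 0.15, 0.25, 0.30):
    values.append(latency)        # appending the 6th sample drops the oldest (0.10)


def window_quantile(q: float) -> float:
    if not values:
        return 0.0                # mirrors the tracker's empty-case behavior
    return float(np.quantile(np.array(values, dtype=np.float32), q))


median = window_quantile(0.5)     # median of [0.20, 0.90, 0.15, 0.25, 0.30]
```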
@@ -0,0 +1,297 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Real-Time Chunking (RTC) implementation for LeRobot.
Based on Physical Intelligence's Kinetix implementation:
https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214
"""
import logging
import math
import torch
from torch import Tensor
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_tracker import Tracker
logger = logging.getLogger(__name__)
class RTCProcessor:
"""Real-Time Chunking processor for action chunking policies.
This class implements RTC techniques including velocity calculation,
prefix attention, and adaptive chunk processing.
"""
def __init__(self, rtc_config: RTCConfig):
self.rtc_config = rtc_config
self.tracker = None
if rtc_config.debug:
self.tracker = Tracker(
enabled=rtc_config.debug,
maxlen=rtc_config.debug_maxlen,
)
# ====================== Tracker Proxy Methods ======================
def track(
self,
time: float | Tensor,
x_t: Tensor | None = None,
v_t: Tensor | None = None,
x1_t: Tensor | None = None,
correction: Tensor | None = None,
err: Tensor | None = None,
weights: Tensor | None = None,
guidance_weight: float | Tensor | None = None,
inference_delay: int | None = None,
execution_horizon: int | None = None,
**metadata,
) -> None:
"""Proxy method to track debug information.
If tracker is None or disabled, this method does nothing.
Otherwise, it forwards the call to tracker.track().
"""
if self.tracker is not None:
self.tracker.track(
time=time,
x_t=x_t,
v_t=v_t,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
**metadata,
)
def get_all_debug_steps(self) -> list:
"""Get all debug steps from tracker.
Returns empty list if tracker is disabled or None.
"""
if self.tracker is not None:
return self.tracker.get_all_steps()
return []
def is_debug_enabled(self) -> bool:
"""Check if debug tracking is enabled.
Returns True if tracker exists and is enabled.
"""
return self.tracker is not None and self.tracker.enabled
def reset_tracker(self) -> None:
"""Reset the tracker, clearing all recorded steps.
Does nothing if tracker is None.
"""
if self.tracker is not None:
self.tracker.reset()
# ====================== End Tracker Proxy Methods ======================
def denoise_step(
self,
x_t,
prev_chunk_left_over,
inference_delay,
time,
original_denoise_step_partial,
execution_horizon=None,
) -> Tensor:
"""RTC guidance wrapper around an existing denoiser.
This method wraps an original denoising callable that only takes ``x_t`` and
returns a base denoised velocity ``v_t``. It then applies Real-Time Chunking
(RTC) prefix guidance using the leftover prefix from the previous chunk.
Args:
x_t (Tensor): Current latent/state to denoise. Shape ``(B, T, A)`` or ``(T, A)``.
prev_chunk_left_over (Tensor | None): Unexecuted prefix from the previous
chunk. Shape ``(B, T_prev, A)`` or ``(T_prev, A)``. If ``None``, no guidance
is applied and the method returns ``v_t`` from the original denoiser.
inference_delay (int): Number of timesteps from the prefix to use for guidance.
time (float | Tensor): Scalar in [0, 1] indicating normalized time. Must be
broadcastable with ``x_t``.
original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
computes the base denoised velocity given only ``x_t``.
execution_horizon (int | None): Horizon used to build prefix weights. If
``None``, defaults to ``self.rtc_config.execution_horizon``.
Returns:
Tensor: Guided velocity with the same shape as ``v_t``.
Notes:
- If inputs are 2D, a batch dimension is temporarily added and removed at the end.
- If ``prev_chunk_left_over`` is shorter than the current chunk length ``T``, it is
right-padded with zeros to match ``T``.
- Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)``
and broadcast to ``(B, T, A)``.
- Guidance correction is computed via autograd using ``x1_t = x_t - time * v_t`` and
``error = (prev_chunk_left_over - x1_t) * weights``.
- The final guidance weight is clamped by ``max_guidance_weight`` from the config.
Reference:
https://www.physicalintelligence.company/download/real_time_chunking.pdf
"""
# The original implementation runs time from 0 to 1, while ours runs it from 1 to 0,
# so we invert the time
tau = 1 - time
if prev_chunk_left_over is None:
# First step, no guidance - return v_t
v_t = original_denoise_step_partial(x_t)
return v_t
x_t = x_t.clone().detach()
squeezed = False
if len(x_t.shape) < 3:
# Add batch dimension
x_t = x_t.unsqueeze(0)
squeezed = True
if len(prev_chunk_left_over.shape) < 3:
# Add batch dimension
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
if execution_horizon is None:
execution_horizon = self.rtc_config.execution_horizon
# If the previous action chunk is too short, a long execution horizon is pointless
# because there is nothing left to merge
if execution_horizon > prev_chunk_left_over.shape[1]:
execution_horizon = prev_chunk_left_over.shape[1]
batch_size = x_t.shape[0]
action_chunk_size = x_t.shape[1]
action_dim = x_t.shape[2]
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
padded = torch.zeros(batch_size, action_chunk_size, action_dim, device=x_t.device, dtype=x_t.dtype)
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
prev_chunk_left_over = padded
assert prev_chunk_left_over.shape == x_t.shape, (
"The padded previous chunk must be the same size as the input tensor"
)
weights = (
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
.to(x_t.device)
.unsqueeze(0)
.unsqueeze(-1)
)
with torch.enable_grad():
v_t = original_denoise_step_partial(x_t)
x_t.requires_grad_(True)
x1_t = x_t - time * v_t # noqa: N806
err = (prev_chunk_left_over - x1_t) * weights
grad_outputs = err.clone().detach()
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
result = v_t - guidance_weight * correction
# Remove the batch dimension if it was added
if squeezed:
result = result.squeeze(0)
correction = correction.squeeze(0)
x1_t = x1_t.squeeze(0)
err = err.squeeze(0)
self.track(
time=time,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
return result
def get_prefix_weights(self, start, end, total):
start = min(start, end)
if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS:
weights = torch.zeros(total)
weights[:start] = 1.0
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES:
weights = torch.ones(total)
weights[end:] = 0.0
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR:
lin_weights = self._linweights(start, end, total)
weights = self._add_trailing_zeros(lin_weights, total, end)
weights = self._add_leading_ones(weights, start, total)
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP:
lin_weights = self._linweights(start, end, total)
lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1)
weights = self._add_trailing_zeros(lin_weights, total, end)
weights = self._add_leading_ones(weights, start, total)
return weights
def _linweights(self, start, end, total):
skip_steps_at_end = max(total - end, 0)
linspace_steps = total - skip_steps_at_end - start
if end <= start or linspace_steps <= 0:
return torch.tensor([])
return torch.linspace(1, 0, linspace_steps + 2)[1:-1]
def _add_trailing_zeros(self, weights, total, end):
zeros_len = total - end
if zeros_len <= 0:
return weights
zeros = torch.zeros(zeros_len)
return torch.cat([weights, zeros])
def _add_leading_ones(self, weights, start, total):
ones_len = min(start, total)
if ones_len <= 0:
return weights
ones = torch.ones(ones_len)
return torch.cat([ones, weights])
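The LINEAR schedule assembled by `get_prefix_weights` and its helpers can be sketched in pure Python (a hypothetical `linear_prefix_weights` helper; lists stand in for tensors): ones over the frozen prefix `[0, start)`, a ramp from 1 toward 0 over `[start, end)`, and zeros from `end` on.

```python
def linear_prefix_weights(start: int, end: int, total: int) -> list[float]:
    """Ones for the first `start` steps, a linear ramp down over [start, end),
    zeros afterwards -- mirroring the LINEAR branch of get_prefix_weights."""
    start = min(start, end)
    ramp_len = total - max(total - end, 0) - start
    if end <= start or ramp_len <= 0:
        ramp = []
    else:
        # interior points of linspace(1, 0, ramp_len + 2), endpoints excluded
        ramp = [1 - (i + 1) / (ramp_len + 1) for i in range(ramp_len)]
    ones = [1.0] * min(start, total)
    zeros = [0.0] * max(total - end, 0)
    return ones + ramp + zeros


w = linear_prefix_weights(start=2, end=6, total=8)
# approximately [1.0, 1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.0], up to float rounding
```

With `start = inference_delay` and `end = execution_horizon`, the frozen prefix is trusted fully, mid-chunk steps are blended with decreasing weight, and the tail is regenerated freely.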
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
def __post_init__(self):
super().__post_init__()
@@ -54,12 +54,15 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
import math
from collections import deque
from typing import TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -69,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
from lerobot.utils.utils import get_safe_dtype
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
@@ -232,8 +241,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
super().__init__(config)
config.validate_features()
self.config = config
self.model = VLAFlowMatching(config)
self.init_rtc_processor()
self.model = VLAFlowMatching(config, rtc_processor=self.rtc_processor)
self.reset()
def reset(self):
@@ -242,10 +251,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create the processor if a config is provided
# Even when RTC is not enabled, we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
# If init_rtc_processor is called after the model has been created, propagate
# the processor to the model (during normal initialization the model does not exist yet)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def _get_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
# TODO: Check whether this for loop is needed.
# Context: self._queues contains only the ACTION field; online inference batches carry
# no action, while offline inference batches do
@@ -260,7 +287,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise, **kwargs
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
@@ -278,30 +307,37 @@ class SmolVLAPolicy(PreTrainedPolicy):
return batch
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def predict_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
actions = self._get_action_chunk(batch, noise)
actions = self._get_action_chunk(batch, noise, **kwargs)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def select_action(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
assert not self._rtc_enabled(), (
"RTC is not supported in select_action; use predict_action_chunk instead"
)
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._queues[ACTION]) == 0:
if self._check_get_actions_condition():
actions = self._get_action_chunk(batch, noise)
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -310,6 +346,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
return self._queues[ACTION].popleft()
def _check_get_actions_condition(self) -> bool:
return len(self._queues[ACTION]) == 0
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
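The queue logic in `select_action` can be sketched in isolation; `plan_chunk` and `N_ACTION_STEPS` here are hypothetical stand-ins, not the policy's real API:

```python
from collections import deque

N_ACTION_STEPS = 4  # stand-in for config.n_action_steps

def plan_chunk(step: int) -> list[int]:
    """Hypothetical planner: returns one chunk of N_ACTION_STEPS actions."""
    return [step * 10 + i for i in range(N_ACTION_STEPS)]

queue: deque[int] = deque(maxlen=N_ACTION_STEPS)
executed = []
for step in range(2):
    if len(queue) == 0:  # mirrors _check_get_actions_condition()
        queue.extend(plan_chunk(step))
    for _ in range(N_ACTION_STEPS):
        # one action per select_action call; repopulate only when depleted
        executed.append(queue.popleft())

# executed -> [0, 1, 2, 3, 10, 11, 12, 13]
```

The policy is queried once per chunk rather than once per step, which is the point of the queue.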
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
@@ -471,7 +513,7 @@ class VLAFlowMatching(nn.Module):
"""
def __init__(self, config: SmolVLAConfig):
def __init__(self, config: SmolVLAConfig, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
@@ -485,7 +527,6 @@ class VLAFlowMatching(nn.Module):
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
device=self.config.device,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
@@ -510,6 +551,10 @@ class VLAFlowMatching(nn.Module):
self.add_image_special_tokens = self.config.add_image_special_tokens
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
@@ -706,7 +751,16 @@ class VLAFlowMatching(nn.Module):
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
def sample_actions(
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
@@ -734,17 +788,45 @@ class VLAFlowMatching(nn.Module):
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
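The default-argument capture in `denoise_step_partial_call` sidesteps Python's late-binding closures (the B023 issue the comment mentions); a minimal repro of the difference:

```python
# Closures look up free variables when called, not when defined (late binding),
# so every late-bound closure in a loop sees the loop variable's final value.
late_bound = [lambda: t for t in range(3)]       # B023-style bug
early_bound = [lambda t=t: t for t in range(3)]  # default arg captures the value now

assert [f() for f in late_bound] == [2, 2, 2]
assert [f() for f in early_bound] == [0, 1, 2]
```

Binding `current_timestep=expanded_time` as a default argument freezes the tensor for that iteration, even though the closure is invoked later inside the RTC processor.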
def denoise_step(
@@ -0,0 +1,154 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
# Flip H and W (dims 2 and 3 of a B, C, H, W tensor): equivalent to a 180-degree rotation
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"]  # (B, 3)
eef_quat = robot_state["eef"]["quat"]  # (B, 4)
gripper_qpos = robot_state["gripper"]["qpos"]  # (B, 2)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into a single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
# Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)
shape=(8,),
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def observation(self, observation):
return self._process_observation(observation)
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
Only accepts torch tensors of shape (B, 4).
Args:
quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format
Returns:
Tensor: (B, 3) axis-angle vectors
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 4)
"""
if not isinstance(quat, torch.Tensor):
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
if quat.ndim != 2 or quat.shape[1] != 4:
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
quat = quat.to(dtype=torch.float32)
device = quat.device
batch_size = quat.shape[0]
w = quat[:, 3].clamp(-1.0, 1.0)
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
result = torch.zeros((batch_size, 3), device=device)
mask = den > 1e-10
if mask.any():
angle = 2.0 * torch.acos(w[mask]) # (M,)
axis = quat[mask, :3] / den[mask].unsqueeze(1)
result[mask] = axis * angle.unsqueeze(1)
return result
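The same conversion on a single quaternion can be checked with only the stdlib; this scalar sketch mirrors the batched logic above ((x, y, z, w) order assumed):

```python
import math

def quat_to_axisangle(x: float, y: float, z: float, w: float) -> tuple[float, float, float]:
    """Convert one (x, y, z, w) quaternion to an axis-angle vector."""
    w = max(-1.0, min(1.0, w))
    den = math.sqrt(max(0.0, 1.0 - w * w))  # sin(angle / 2)
    if den < 1e-10:
        return (0.0, 0.0, 0.0)  # near-identity rotation: zero vector
    angle = 2.0 * math.acos(w)
    return (x / den * angle, y / den * angle, z / den * angle)

# 90-degree rotation about x: quat = (sin(45°), 0, 0, cos(45°))
ax, ay, az = quat_to_axisangle(math.sin(math.pi / 4), 0.0, 0.0, math.cos(math.pi / 4))
# ax ≈ π/2, ay = az = 0
```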
@@ -71,7 +71,7 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
@@ -94,6 +94,8 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
seeds: list[int] | None = None,
@@ -165,11 +167,19 @@ def rollout(
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
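The ordering the rollout loop relies on (env preprocessing before policy preprocessing, mirrored for actions) can be sketched with plain callables; the stages here are hypothetical identity stand-ins:

```python
trace = []

def make_stage(name):
    """Identity stage that records when it runs."""
    def stage(payload):
        trace.append(name)
        return payload
    return stage

env_preprocessor = make_stage("env_pre")      # e.g. LiberoProcessorStep
preprocessor = make_stage("policy_pre")       # normalization, tokenization, ...
postprocessor = make_stage("policy_post")     # unnormalize actions
env_postprocessor = make_stage("env_post")    # env-specific action fixup

def select_action(obs):
    trace.append("policy")
    return {"action": [0.0]}

obs = {"observation.state": [0.0]}
action_transition = postprocessor(select_action(preprocessor(env_preprocessor(obs))))
action_transition = env_postprocessor(action_transition)

# trace -> ["env_pre", "policy_pre", "policy", "policy_post", "env_post"]
```

Env-specific steps wrap the policy-specific ones on both sides, so the policy always sees observations in the LeRobot format.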
@@ -239,6 +249,8 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -319,6 +331,8 @@ def eval_policy(
rollout_data = rollout(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
seeds=list(seeds) if seeds else None,
@@ -517,10 +531,16 @@ def eval_main(cfg: EvalPipelineConfig):
pretrained_path=cfg.policy.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -561,6 +581,8 @@ def eval_one(
env: gym.vector.VectorEnv,
*,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -576,6 +598,8 @@ def eval_one(
task_result = eval_policy(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -600,6 +624,8 @@ def run_one(
env,
*,
policy,
env_preprocessor,
env_postprocessor,
preprocessor,
postprocessor,
n_episodes: int,
@@ -622,6 +648,8 @@ def run_one(
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -639,6 +667,8 @@ def run_one(
def eval_policy_all(
envs: dict[str, dict[int, gym.vector.VectorEnv]],
policy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -694,6 +724,8 @@ def eval_policy_all(
task_runner = partial(
run_one,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -29,7 +29,7 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
@@ -259,6 +259,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors")
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
@@ -274,6 +276,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
@@ -384,6 +387,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=accelerator.unwrap_model(policy),
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -70,3 +70,15 @@ LOOKAHEAD_BACKTRACKTABLE = 100
# openpi
OPENPI_ATTENTION_MASK_VALUE = -2.3819763e38 # TODO(pepijn): Modify this when extending support to fp8 models
# Constants for LIBERO observation keys
LIBERO_KEY_EEF_POS = "robot_state/eef/pos"
LIBERO_KEY_EEF_QUAT = "robot_state/eef/quat"
LIBERO_KEY_EEF_MAT = "robot_state/eef/mat"
LIBERO_KEY_EEF_AXISANGLE = "robot_state/eef/axisangle"
LIBERO_KEY_GRIPPER_QPOS = "robot_state/gripper/qpos"
LIBERO_KEY_GRIPPER_QVEL = "robot_state/gripper/qvel"
LIBERO_KEY_JOINTS_POS = "robot_state/joints/pos"
LIBERO_KEY_JOINTS_VEL = "robot_state/joints/vel"
LIBERO_KEY_PIXELS_AGENTVIEW = "pixels/agentview_image"
LIBERO_KEY_PIXELS_EYE_IN_HAND = "pixels/robot0_eye_in_hand_image"
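The '/'-separated key constants above imply a flattened view of the nested LIBERO observation dict; a sketch of that flattening (the helper name is an assumption, not part of the diff):

```python
def flatten_keys(tree: dict, prefix: str = "") -> dict:
    """Flatten a nested dict into 'a/b/c'-style keys, matching the constants above."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_keys(value, path))
        else:
            flat[path] = value
    return flat

obs = {"robot_state": {"eef": {"pos": [0, 0, 0], "quat": [0, 0, 0, 1]}}}
flat = flatten_keys(obs)
# flat["robot_state/eef/pos"] == [0, 0, 0], i.e. the LIBERO_KEY_EEF_POS entry
```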
@@ -0,0 +1,336 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test PI0.5 policy with Real-Time Chunking (RTC) enabled during inference."""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy, make_pi05_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_pi05_rtc_initialization():
"""Test PI0.5 policy can initialize RTC processor."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Instantiate policy
policy = PI05Policy(config)
# Verify RTC processor is initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is not None
assert policy.rtc_processor.rtc_config.enabled is True
print("✓ PI0.5 RTC initialization: Test passed")
@require_cuda
def test_pi05_rtc_initialization_without_rtc_config():
"""Test PI0.5 policy can initialize without RTC config."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Instantiate policy
policy = PI05Policy(config)
# Verify RTC processor is not initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is None
assert policy.model.rtc_processor is None
assert policy._rtc_enabled() is False
print("✓ PI0.5 RTC initialization without RTC config: Test passed")
@require_cuda
def test_pi05_rtc_inference_with_prev_chunk():
"""Test PI0.5 policy inference with RTC and previous chunk."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI05Policy(config)
policy.eval()
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC and previous chunk
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=4,
execution_horizon=10,
)
# Test without RTC for comparison
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Verify shapes
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
# With previous chunk, actions should be different (RTC guidance applied)
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
print("✓ PI0.5 RTC inference with prev_chunk: Test passed")
@require_cuda
def test_pi05_rtc_inference_without_prev_chunk():
"""Test PI0.5 policy inference with RTC but no previous chunk (RTC should have no effect)."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI05Policy(config)
policy.eval()
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC enabled but no previous chunk
actions_with_rtc_no_prev = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=None,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Without previous chunk, RTC should have no effect
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
print("✓ PI0.5 RTC inference without prev_chunk: Test passed")
@require_cuda
def test_pi05_rtc_validation_rules():
"""Test PI0.5 policy with RTC follows all three validation rules."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI05Policy(config)
policy.eval()
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
inference_delay = 4
execution_horizon = 10
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
@@ -0,0 +1,378 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test PI0 policy with Real-Time Chunking (RTC) enabled during inference."""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_pi0_rtc_initialization():
"""Test PI0 policy can initialize RTC processor."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Instantiate policy
policy = PI0Policy(config)
# Verify RTC processor is initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is not None
assert policy.rtc_processor.rtc_config.enabled is True
print("✓ PI0 RTC initialization: Test passed")
@require_cuda
def test_pi0_rtc_initialization_without_rtc_config():
"""Test PI0 policy can initialize without RTC config."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Instantiate policy
policy = PI0Policy(config)
# Verify RTC processor is not initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is None
assert policy.model.rtc_processor is None
assert policy._rtc_enabled() is False
print("✓ PI0 RTC initialization without RTC config: Test passed")
@require_cuda
def test_pi0_rtc_inference_with_prev_chunk():
"""Test PI0 policy inference with RTC and previous chunk."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI0Policy(config)
policy.eval()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC and previous chunk
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=4,
execution_horizon=10,
)
# Test without RTC for comparison
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Verify shapes
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
# With previous chunk, actions should be different (RTC guidance applied)
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
print("✓ PI0 RTC inference with prev_chunk: Test passed")
@require_cuda
def test_pi0_rtc_inference_without_prev_chunk():
"""Test PI0 policy inference with RTC but no previous chunk (RTC should have no effect)."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI0Policy(config)
policy.eval()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC enabled but no previous chunk
actions_with_rtc_no_prev = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=None,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Without previous chunk, RTC should have no effect
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
print("✓ PI0 RTC inference without prev_chunk: Test passed")
@require_cuda
def test_pi0_rtc_validation_rules():
"""Test PI0 policy with RTC follows all three validation rules."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI0Policy(config)
policy.eval()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
inference_delay = 4
execution_horizon = 10
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
"""Test PI0 with different RTC attention schedules."""
set_seed(42)
schedules = [
RTCAttentionSchedule.ZEROS,
RTCAttentionSchedule.ONES,
RTCAttentionSchedule.LINEAR,
RTCAttentionSchedule.EXP,
]
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
device = config.device
for schedule in schedules:
print(f"Testing schedule: {schedule}")
# Add RTC config with specific schedule
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=schedule,
debug=False,
)
# Instantiate policy
policy = PI0Policy(config)
policy.eval()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
with torch.no_grad():
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
actions = policy.predict_action_chunk(
batch,
noise=noise,
prev_chunk_left_over=prev_chunk,
inference_delay=4,
execution_horizon=10,
)
# Verify shape
assert actions.shape == (1, config.chunk_size, 7)
print(f" ✓ Schedule {schedule}: Test passed")
print("✓ PI0 RTC different schedules: All schedules tested")
@@ -0,0 +1,825 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC ActionQueue module."""
import threading
import time
import pytest
import torch
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
# ====================== Fixtures ======================
@pytest.fixture
def rtc_config_enabled():
"""Create an RTC config with RTC enabled."""
return RTCConfig(enabled=True, execution_horizon=10, max_guidance_weight=1.0)
@pytest.fixture
def rtc_config_disabled():
"""Create an RTC config with RTC disabled."""
return RTCConfig(enabled=False, execution_horizon=10, max_guidance_weight=1.0)
@pytest.fixture
def sample_actions():
"""Create sample action tensors for testing."""
return {
"original": torch.randn(50, 6), # (time_steps, action_dim)
"processed": torch.randn(50, 6),
"short": torch.randn(10, 6),
"longer": torch.randn(100, 6),
}
@pytest.fixture
def action_queue_rtc_enabled(rtc_config_enabled):
"""Create an ActionQueue with RTC enabled."""
return ActionQueue(rtc_config_enabled)
@pytest.fixture
def action_queue_rtc_disabled(rtc_config_disabled):
"""Create an ActionQueue with RTC disabled."""
return ActionQueue(rtc_config_disabled)
# ====================== Initialization Tests ======================
def test_action_queue_initialization_rtc_enabled(rtc_config_enabled):
"""Test ActionQueue initializes correctly with RTC enabled."""
queue = ActionQueue(rtc_config_enabled)
assert queue.queue is None
assert queue.original_queue is None
assert queue.last_index == 0
assert queue.cfg.enabled is True
def test_action_queue_initialization_rtc_disabled(rtc_config_disabled):
"""Test ActionQueue initializes correctly with RTC disabled."""
queue = ActionQueue(rtc_config_disabled)
assert queue.queue is None
assert queue.original_queue is None
assert queue.last_index == 0
assert queue.cfg.enabled is False
# ====================== get() Tests ======================
def test_get_returns_none_when_empty(action_queue_rtc_enabled):
"""Test get() returns None when queue is empty."""
action = action_queue_rtc_enabled.get()
assert action is None
def test_get_returns_actions_sequentially(action_queue_rtc_enabled, sample_actions):
"""Test get() returns actions in sequence."""
# Initialize queue with actions
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
# Get first action
action1 = action_queue_rtc_enabled.get()
assert action1 is not None
assert action1.shape == (6,)
assert torch.equal(action1, sample_actions["processed"][0])
# Get second action
action2 = action_queue_rtc_enabled.get()
assert action2 is not None
assert torch.equal(action2, sample_actions["processed"][1])
def test_get_returns_none_after_exhaustion(action_queue_rtc_enabled, sample_actions):
"""Test get() returns None after all actions are consumed."""
# Use short action sequence
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume all actions
for _ in range(10):
action = action_queue_rtc_enabled.get()
assert action is not None
# Next get should return None
action = action_queue_rtc_enabled.get()
assert action is None
def test_get_increments_last_index(action_queue_rtc_enabled, sample_actions):
"""Test get() increments last_index correctly."""
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
assert action_queue_rtc_enabled.last_index == 0
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.last_index == 1
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.last_index == 2
# ====================== qsize() Tests ======================
def test_qsize_returns_zero_when_empty(action_queue_rtc_enabled):
"""Test qsize() returns 0 when queue is empty."""
assert action_queue_rtc_enabled.qsize() == 0
def test_qsize_returns_correct_size(action_queue_rtc_enabled, sample_actions):
"""Test qsize() returns correct number of remaining actions."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_enabled.qsize() == 10
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.qsize() == 9
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.qsize() == 8
def test_qsize_after_exhaustion(action_queue_rtc_enabled, sample_actions):
"""Test qsize() returns 0 after queue is exhausted."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume all actions
for _ in range(10):
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.qsize() == 0
# ====================== empty() Tests ======================
def test_empty_returns_true_when_empty(action_queue_rtc_enabled):
"""Test empty() returns True when queue is empty."""
assert action_queue_rtc_enabled.empty() is True
def test_empty_returns_false_when_not_empty(action_queue_rtc_enabled, sample_actions):
"""Test empty() returns False when queue has actions."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_enabled.empty() is False
def test_empty_after_partial_consumption(action_queue_rtc_enabled, sample_actions):
"""Test empty() returns False after partial consumption."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.empty() is False
def test_empty_after_full_consumption(action_queue_rtc_enabled, sample_actions):
"""Test empty() returns True after all actions consumed."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume all
for _ in range(10):
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.empty() is True
# ====================== get_action_index() Tests ======================
def test_get_action_index_initial_value(action_queue_rtc_enabled):
"""Test get_action_index() returns 0 initially."""
assert action_queue_rtc_enabled.get_action_index() == 0
def test_get_action_index_after_consumption(action_queue_rtc_enabled, sample_actions):
"""Test get_action_index() tracks consumption correctly."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_enabled.get_action_index() == 0
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.get_action_index() == 1
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.get_action_index() == 3
# ====================== get_left_over() Tests ======================
def test_get_left_over_returns_none_when_empty(action_queue_rtc_enabled):
"""Test get_left_over() returns None when queue is empty."""
leftover = action_queue_rtc_enabled.get_left_over()
assert leftover is None
def test_get_left_over_returns_all_when_unconsumed(action_queue_rtc_enabled, sample_actions):
"""Test get_left_over() returns all original actions when none consumed."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
leftover = action_queue_rtc_enabled.get_left_over()
assert leftover is not None
assert leftover.shape == (10, 6)
assert torch.equal(leftover, sample_actions["short"])
def test_get_left_over_returns_remaining_after_consumption(action_queue_rtc_enabled, sample_actions):
"""Test get_left_over() returns only remaining original actions."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume 3 actions
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
leftover = action_queue_rtc_enabled.get_left_over()
assert leftover is not None
assert leftover.shape == (7, 6)
assert torch.equal(leftover, sample_actions["short"][3:])
def test_get_left_over_returns_empty_after_exhaustion(action_queue_rtc_enabled, sample_actions):
"""Test get_left_over() returns empty tensor after all consumed."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume all
for _ in range(10):
action_queue_rtc_enabled.get()
leftover = action_queue_rtc_enabled.get_left_over()
assert leftover is not None
assert leftover.shape == (0, 6)
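Taken together, these tests pin down `get_left_over()` as the unconsumed suffix of the *original* (unprocessed) actions, or `None` when nothing was ever queued. A minimal free-function sketch of that contract (hypothetical, not the lerobot method):

```python
import torch


def left_over(original_queue, last_index):
    """Unconsumed suffix of the original actions, per the tests above:
    None when nothing was queued, an empty tensor after exhaustion."""
    if original_queue is None:
        return None
    return original_queue[last_index:]
```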
# ====================== merge() with RTC Enabled Tests ======================
def test_merge_replaces_queue_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
"""Test merge() replaces queue when RTC is enabled."""
# Add initial actions
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_enabled.qsize() == 10
# Consume some actions
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.qsize() == 8
# Merge new actions - should replace, not append
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
# Queue should be replaced with new actions minus delay
# Original has 50 actions, delay is 5, so remaining is 45
assert action_queue_rtc_enabled.qsize() == 45
assert action_queue_rtc_enabled.get_action_index() == 0
def test_merge_respects_real_delay(action_queue_rtc_enabled, sample_actions):
"""Test merge() correctly applies real_delay when RTC is enabled."""
delay = 10
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=delay)
# Queue should have original length minus delay
expected_size = len(sample_actions["original"]) - delay
assert action_queue_rtc_enabled.qsize() == expected_size
# First action should be the one at index [delay]
first_action = action_queue_rtc_enabled.get()
assert torch.equal(first_action, sample_actions["processed"][delay])
def test_merge_resets_last_index_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
"""Test merge() resets last_index to 0 when RTC is enabled."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.last_index == 2
# Merge new actions
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
assert action_queue_rtc_enabled.last_index == 0
def test_merge_with_zero_delay(action_queue_rtc_enabled, sample_actions):
"""Test merge() with zero delay keeps all actions."""
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
assert action_queue_rtc_enabled.qsize() == len(sample_actions["original"])
def test_merge_with_large_delay(action_queue_rtc_enabled, sample_actions):
"""Test merge() with delay larger than action sequence."""
# Delay is larger than sequence length
delay = 100
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=delay)
# Queue should be empty (delay >= length)
assert action_queue_rtc_enabled.qsize() == 0
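The RTC-enabled merge tests above pin down a simple contract: the incoming chunk replaces the queue, the first `real_delay` steps are dropped, and the read index resets to 0. A minimal sketch of that contract (hypothetical helper, not the lerobot implementation):

```python
import torch


def merge_rtc_enabled(processed: torch.Tensor, real_delay: int):
    """Replace semantics when RTC is enabled: new chunk replaces the
    queue minus the first `real_delay` steps; read index resets."""
    # A delay >= len(processed) slices to an empty (0, action_dim) queue
    queue = processed[real_delay:].clone()
    return queue, 0
```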
# ====================== merge() with RTC Disabled Tests ======================
def test_merge_appends_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
"""Test merge() appends actions when RTC is disabled."""
# Add initial actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
initial_size = action_queue_rtc_disabled.qsize()
assert initial_size == 10
# Merge more actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Should have appended
assert action_queue_rtc_disabled.qsize() == initial_size + 10
def test_merge_removes_consumed_actions_when_appending(action_queue_rtc_disabled, sample_actions):
"""Test merge() removes consumed actions before appending when RTC is disabled."""
# Add initial actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_disabled.qsize() == 10
# Consume 3 actions
action_queue_rtc_disabled.get()
action_queue_rtc_disabled.get()
action_queue_rtc_disabled.get()
assert action_queue_rtc_disabled.qsize() == 7
# Merge more actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Should have 7 remaining + 10 new = 17
assert action_queue_rtc_disabled.qsize() == 17
def test_merge_resets_last_index_after_append(action_queue_rtc_disabled, sample_actions):
"""Test merge() resets last_index after appending when RTC is disabled."""
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
action_queue_rtc_disabled.get()
action_queue_rtc_disabled.get()
assert action_queue_rtc_disabled.last_index == 2
# Merge more actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# last_index should be reset to 0
assert action_queue_rtc_disabled.last_index == 0
def test_merge_ignores_delay_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
"""Test merge() ignores real_delay parameter when RTC is disabled."""
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=10)
# All actions should be in queue (delay ignored)
assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"])
def test_merge_first_call_with_rtc_disabled(action_queue_rtc_disabled, sample_actions):
"""Test merge() on first call with RTC disabled."""
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"])
assert action_queue_rtc_disabled.last_index == 0
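With RTC disabled, the tests above assert the opposite contract: the consumed prefix is dropped, the new chunk is appended, the read index resets, and `real_delay` is ignored. Sketched as a hypothetical free function (not the lerobot method):

```python
import torch


def merge_rtc_disabled(queue, last_index, processed):
    """Append semantics when RTC is disabled: keep the unconsumed
    remainder, append the new chunk, reset the read index."""
    if queue is None:
        return processed.clone(), 0
    return torch.cat([queue[last_index:], processed.clone()]), 0
```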
# ====================== merge() with Different Action Shapes Tests ======================
def test_merge_with_different_action_dims():
"""Test merge() handles actions with different dimensions."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
# Actions with 4 dimensions instead of 6
actions_4d = torch.randn(20, 4)
queue.merge(actions_4d, actions_4d, real_delay=5)
action = queue.get()
assert action.shape == (4,)
def test_merge_with_different_lengths():
"""Test merge() handles action sequences of varying lengths."""
cfg = RTCConfig(enabled=False, execution_horizon=10)
queue = ActionQueue(cfg)
# Add sequences of different lengths
queue.merge(torch.randn(10, 6), torch.randn(10, 6), real_delay=0)
assert queue.qsize() == 10
queue.merge(torch.randn(25, 6), torch.randn(25, 6), real_delay=0)
assert queue.qsize() == 35
# ====================== merge() Delay Validation Tests ======================
def test_merge_validates_delay_consistency(action_queue_rtc_enabled, sample_actions, caplog):
"""Test merge() validates that real_delay matches action index difference."""
import logging
caplog.set_level(logging.WARNING)
# Initialize queue
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume 5 actions
for _ in range(5):
action_queue_rtc_enabled.get()
# Merge with mismatched delay (should log warning)
# We consumed 5 actions, so index is 5. If we pass action_index_before_inference=0,
# then indexes_diff=5, but if real_delay=3, it will warn
action_queue_rtc_enabled.merge(
sample_actions["original"],
sample_actions["processed"],
real_delay=3,
action_index_before_inference=0,
)
# Check warning was logged
assert "Indexes diff is not equal to real delay" in caplog.text
def test_merge_no_warning_when_delays_match(action_queue_rtc_enabled, sample_actions, caplog):
"""Test merge() doesn't warn when delays are consistent."""
import logging
caplog.set_level(logging.WARNING)
# Initialize queue
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume 5 actions
for _ in range(5):
action_queue_rtc_enabled.get()
# Merge with matching delay
action_queue_rtc_enabled.merge(
sample_actions["original"],
sample_actions["processed"],
real_delay=5,
action_index_before_inference=0,
)
# Should not have warning
assert "Indexes diff is not equal to real delay" not in caplog.text
def test_merge_skips_validation_when_action_index_none(action_queue_rtc_enabled, sample_actions, caplog):
"""Test merge() skips delay validation when action_index_before_inference is None."""
import logging
caplog.set_level(logging.WARNING)
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
for _ in range(5):
action_queue_rtc_enabled.get()
# Pass None for action_index_before_inference
action_queue_rtc_enabled.merge(
sample_actions["original"],
sample_actions["processed"],
real_delay=999, # Doesn't matter
action_index_before_inference=None,
)
# Should not warn (validation skipped)
assert "Indexes diff is not equal to real delay" not in caplog.text
# ====================== Thread Safety Tests ======================
def test_get_is_thread_safe(action_queue_rtc_enabled, sample_actions):
"""Test get() is thread-safe with multiple consumers."""
action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)
results = []
errors = []
def consumer():
try:
for _ in range(25):
action = action_queue_rtc_enabled.get()
if action is not None:
results.append(action)
time.sleep(0.001)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=consumer) for _ in range(4)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should not have errors
assert len(errors) == 0
# Should have consumed all actions (100 total, 4 threads * 25 each)
assert len(results) == 100
# Each action should be consumed exactly once; since results don't carry
# their indices, we verify the total count and that the queue is drained
assert action_queue_rtc_enabled.qsize() == 0
def test_merge_is_thread_safe(action_queue_rtc_disabled, sample_actions):
"""Test merge() is thread-safe with multiple producers."""
errors = []
def producer():
try:
for _ in range(5):
action_queue_rtc_disabled.merge(
sample_actions["short"], sample_actions["short"], real_delay=0
)
time.sleep(0.001)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=producer) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should not have errors
assert len(errors) == 0
# Should have accumulated all actions (3 threads * 5 merges * 10 actions = 150)
assert action_queue_rtc_disabled.qsize() == 150
def test_concurrent_get_and_merge(action_queue_rtc_disabled, sample_actions):
"""Test concurrent get() and merge() operations."""
errors = []
consumed_count = [0]
def consumer():
try:
for _ in range(50):
action = action_queue_rtc_disabled.get()
if action is not None:
consumed_count[0] += 1
time.sleep(0.001)
except Exception as e:
errors.append(e)
def producer():
try:
for _ in range(10):
action_queue_rtc_disabled.merge(
sample_actions["short"], sample_actions["short"], real_delay=0
)
time.sleep(0.005)
except Exception as e:
errors.append(e)
consumer_threads = [threading.Thread(target=consumer) for _ in range(2)]
producer_threads = [threading.Thread(target=producer) for _ in range(2)]
for t in consumer_threads + producer_threads:
t.start()
for t in consumer_threads + producer_threads:
t.join()
# Should not have errors
assert len(errors) == 0
# Should have consumed some or all actions (non-deterministic due to timing)
# Total produced: 2 producers * 10 merges * 10 actions = 200
# Total consume attempts: 2 consumers * 50 = 100, so that bounds consumption
assert consumed_count[0] <= 100
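The guarantees these thread-safety tests rely on are most simply provided by one lock guarding every read and write. A minimal sketch of that pattern (not the lerobot ActionQueue; the class and its `append` method are illustrative only):

```python
import threading

import torch


class LockedQueueSketch:
    """One mutex around all queue state: no torn reads, no double
    consumption of the same index across threads."""

    def __init__(self):
        self._lock = threading.Lock()
        self._queue = None
        self._index = 0

    def get(self):
        with self._lock:
            if self._queue is None or self._index >= len(self._queue):
                return None
            action = self._queue[self._index]
            self._index += 1
            return action

    def append(self, actions: torch.Tensor):
        with self._lock:
            if self._queue is None:
                self._queue = actions.clone()
            else:
                # Drop the consumed prefix, then append the new chunk
                self._queue = torch.cat([self._queue[self._index:], actions.clone()])
            self._index = 0
```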
# ====================== get_left_over() Thread Safety Tests ======================
def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
"""Test get_left_over() is thread-safe with concurrent access."""
action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)
errors = []
leftovers = []
def reader():
try:
for _ in range(20):
leftover = action_queue_rtc_enabled.get_left_over()
if leftover is not None:
leftovers.append(leftover.shape[0])
time.sleep(0.001)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=reader) for _ in range(3)]
# Also consume some actions concurrently
def consumer():
try:
for _ in range(10):
action_queue_rtc_enabled.get()
time.sleep(0.002)
except Exception as e:
errors.append(e)
consumer_thread = threading.Thread(target=consumer)
all_threads = threads + [consumer_thread]
for t in all_threads:
t.start()
for t in all_threads:
t.join()
# Should not have errors
assert len(errors) == 0
# Queue size is non-increasing, but thread interleaving means the recorded
# sizes need not arrive in order; we only require that all reads succeeded
assert len(leftovers) > 0
# ====================== Edge Cases Tests ======================
def test_queue_with_single_action(action_queue_rtc_enabled):
"""Test queue behavior with a single action."""
single_action_original = torch.randn(1, 6)
single_action_processed = torch.randn(1, 6)
action_queue_rtc_enabled.merge(single_action_original, single_action_processed, real_delay=0)
assert action_queue_rtc_enabled.qsize() == 1
action = action_queue_rtc_enabled.get()
assert action is not None
assert action.shape == (6,)
assert action_queue_rtc_enabled.qsize() == 0
def test_queue_behavior_after_multiple_merge_cycles(action_queue_rtc_enabled, sample_actions):
"""Test queue maintains correct state through multiple merge cycles."""
for _ in range(5):
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume half
for _ in range(5):
action_queue_rtc_enabled.get()
# Merge again
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=3)
assert action_queue_rtc_enabled.qsize() > 0
def test_queue_with_all_zeros_actions(action_queue_rtc_enabled):
"""Test queue handles all-zero action tensors."""
zeros_actions = torch.zeros(20, 6)
action_queue_rtc_enabled.merge(zeros_actions, zeros_actions, real_delay=0)
action = action_queue_rtc_enabled.get()
assert torch.all(action == 0)
def test_queue_clones_input_tensors(action_queue_rtc_enabled, sample_actions):
"""Test that merge() clones input tensors, not storing references."""
original_copy = sample_actions["original"].clone()
processed_copy = sample_actions["processed"].clone()
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
# Modify original tensors
sample_actions["original"].fill_(999.0)
sample_actions["processed"].fill_(-999.0)
# Queue should have cloned values
action = action_queue_rtc_enabled.get()
assert not torch.equal(action, sample_actions["processed"][0])
assert torch.equal(action, processed_copy[0])
leftover = action_queue_rtc_enabled.get_left_over()
assert not torch.equal(leftover, sample_actions["original"][1:])
assert torch.equal(leftover, original_copy[1:])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_queue_handles_gpu_tensors():
"""Test queue correctly handles GPU tensors."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
actions_gpu = torch.randn(20, 6, device="cuda")
queue.merge(actions_gpu, actions_gpu, real_delay=0)
action = queue.get()
assert action.device.type == "cuda"
leftover = queue.get_left_over()
assert leftover.device.type == "cuda"
def test_queue_handles_different_dtypes():
"""Test queue handles actions with different dtypes."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
# Use float64 instead of default float32
actions_f64 = torch.randn(20, 6, dtype=torch.float64)
queue.merge(actions_f64, actions_f64, real_delay=0)
action = queue.get()
assert action.dtype == torch.float64
def test_empty_with_none_queue(action_queue_rtc_enabled):
"""Test empty() correctly handles None queue."""
assert action_queue_rtc_enabled.queue is None
assert action_queue_rtc_enabled.empty() is True
def test_qsize_with_none_queue(action_queue_rtc_enabled):
"""Test qsize() correctly handles None queue."""
assert action_queue_rtc_enabled.queue is None
assert action_queue_rtc_enabled.qsize() == 0
# ====================== Integration Tests ======================
def test_typical_rtc_workflow(action_queue_rtc_enabled, sample_actions):
"""Test a typical RTC workflow: merge, consume, merge with delay."""
# First inference
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
initial_size = action_queue_rtc_enabled.qsize()
assert initial_size == 50
# Consume 10 actions (execution_horizon)
for _ in range(10):
action = action_queue_rtc_enabled.get()
assert action is not None
assert action_queue_rtc_enabled.qsize() == 40
# Second inference with delay
action_index_before = action_queue_rtc_enabled.get_action_index()
action_queue_rtc_enabled.merge(
sample_actions["original"],
sample_actions["processed"],
real_delay=5,
action_index_before_inference=action_index_before,
)
# Queue should be replaced, minus delay
assert action_queue_rtc_enabled.qsize() == 45
assert action_queue_rtc_enabled.get_action_index() == 0
def test_typical_non_rtc_workflow(action_queue_rtc_disabled, sample_actions):
"""Test a typical non-RTC workflow: merge, consume, merge again."""
# First inference
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
assert action_queue_rtc_disabled.qsize() == 50
# Consume 40 actions
for _ in range(40):
action = action_queue_rtc_disabled.get()
assert action is not None
assert action_queue_rtc_disabled.qsize() == 10
# Second inference (should append)
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
# Should have 10 remaining + 50 new = 60
assert action_queue_rtc_disabled.qsize() == 60
@@ -0,0 +1,65 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC configuration module."""
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
# ====================== Initialization Tests ======================
def test_rtc_config_default_initialization():
"""Test RTCConfig initializes with default values."""
config = RTCConfig()
assert config.enabled is False
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
assert config.max_guidance_weight == 10.0
assert config.execution_horizon == 10
assert config.debug is False
assert config.debug_maxlen == 100
def test_rtc_config_custom_initialization():
"""Test RTCConfig initializes with custom values."""
config = RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
max_guidance_weight=5.0,
execution_horizon=20,
debug=True,
debug_maxlen=200,
)
assert config.enabled is True
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
assert config.max_guidance_weight == 5.0
assert config.execution_horizon == 20
assert config.debug is True
assert config.debug_maxlen == 200
def test_rtc_config_partial_initialization():
"""Test RTCConfig with partial custom values."""
config = RTCConfig(enabled=True, max_guidance_weight=15.0)
assert config.enabled is True
assert config.max_guidance_weight == 15.0
# Other values should be defaults
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
assert config.execution_horizon == 10
assert config.debug is False
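Taken together, these tests specify `RTCConfig` as a plain dataclass with the defaults asserted above. A minimal sketch of the implied shape (the enum values are placeholders; the real definitions live in `lerobot.configs.types` and `lerobot.policies.rtc.configuration_rtc`):

```python
from dataclasses import dataclass
from enum import Enum


class RTCAttentionSchedule(Enum):
    # Placeholder members mirroring the schedules referenced in the tests;
    # the actual enum values in lerobot may differ.
    ZEROS = "zeros"
    ONES = "ones"
    LINEAR = "linear"
    EXP = "exp"


@dataclass
class RTCConfig:
    """Defaults exactly as the initialization tests assert them."""
    enabled: bool = False
    prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
    max_guidance_weight: float = 10.0
    execution_horizon: int = 10
    debug: bool = False
    debug_maxlen: int = 100
```

Because it is a dataclass, partial initialization (`RTCConfig(enabled=True, max_guidance_weight=15.0)`) leaves every other field at its default, which is what the partial-initialization test checks.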
@@ -0,0 +1,488 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC debug tracker module."""
import pytest
import torch
from lerobot.policies.rtc.debug_tracker import DebugStep, Tracker
# ====================== Fixtures ======================
@pytest.fixture
def sample_tensors():
"""Create sample tensors for testing."""
return {
"x_t": torch.randn(1, 50, 6),
"v_t": torch.randn(1, 50, 6),
"x1_t": torch.randn(1, 50, 6),
"correction": torch.randn(1, 50, 6),
"err": torch.randn(1, 50, 6),
"weights": torch.randn(1, 50, 1),
}
@pytest.fixture
def enabled_tracker():
"""Create an enabled tracker with default settings."""
return Tracker(enabled=True, maxlen=100)
@pytest.fixture
def disabled_tracker():
"""Create a disabled tracker."""
return Tracker(enabled=False)
# ====================== DebugStep Tests ======================
def test_debug_step_initialization():
"""Test that DebugStep can be initialized with default values."""
step = DebugStep()
assert step.step_idx == 0
assert step.x_t is None
assert step.v_t is None
assert step.x1_t is None
assert step.correction is None
assert step.err is None
assert step.weights is None
assert step.guidance_weight is None
assert step.time is None
assert step.inference_delay is None
assert step.execution_horizon is None
assert step.metadata == {}
def test_debug_step_with_values(sample_tensors):
"""Test DebugStep initialization with actual values."""
step = DebugStep(
step_idx=5,
x_t=sample_tensors["x_t"],
v_t=sample_tensors["v_t"],
x1_t=sample_tensors["x1_t"],
correction=sample_tensors["correction"],
err=sample_tensors["err"],
weights=sample_tensors["weights"],
guidance_weight=2.5,
time=0.8,
inference_delay=4,
execution_horizon=8,
metadata={"custom_key": "custom_value"},
)
assert step.step_idx == 5
assert torch.equal(step.x_t, sample_tensors["x_t"])
assert torch.equal(step.v_t, sample_tensors["v_t"])
assert torch.equal(step.x1_t, sample_tensors["x1_t"])
assert torch.equal(step.correction, sample_tensors["correction"])
assert torch.equal(step.err, sample_tensors["err"])
assert torch.equal(step.weights, sample_tensors["weights"])
assert step.guidance_weight == 2.5
assert step.time == 0.8
assert step.inference_delay == 4
assert step.execution_horizon == 8
assert step.metadata == {"custom_key": "custom_value"}
def test_debug_step_to_dict_without_tensors(sample_tensors):
"""Test converting DebugStep to dictionary without tensor values."""
step = DebugStep(
step_idx=3,
x_t=sample_tensors["x_t"],
v_t=sample_tensors["v_t"],
guidance_weight=torch.tensor(3.0),
time=torch.tensor(0.5),
inference_delay=2,
execution_horizon=10,
)
result = step.to_dict(include_tensors=False)
assert result["step_idx"] == 3
assert result["guidance_weight"] == 3.0
assert result["time"] == 0.5
assert result["inference_delay"] == 2
assert result["execution_horizon"] == 10
# Check tensor statistics are included
assert "x_t_stats" in result
assert "v_t_stats" in result
assert "x1_t_stats" not in result # x1_t was None
# Verify statistics structure
assert "shape" in result["x_t_stats"]
assert "mean" in result["x_t_stats"]
assert "std" in result["x_t_stats"]
assert "min" in result["x_t_stats"]
assert "max" in result["x_t_stats"]
# Verify shape matches original tensor
assert result["x_t_stats"]["shape"] == tuple(sample_tensors["x_t"].shape)
def test_debug_step_to_dict_with_tensors(sample_tensors):
"""Test converting DebugStep to dictionary with tensor values."""
step = DebugStep(
step_idx=1,
x_t=sample_tensors["x_t"],
v_t=sample_tensors["v_t"],
guidance_weight=1.5,
time=0.9,
)
result = step.to_dict(include_tensors=True)
assert result["step_idx"] == 1
assert result["guidance_weight"] == 1.5
assert result["time"] == 0.9
# Check tensors are included (as CPU tensors)
assert "x_t" in result
assert "v_t" in result
assert isinstance(result["x_t"], torch.Tensor)
assert isinstance(result["v_t"], torch.Tensor)
assert result["x_t"].device.type == "cpu"
assert result["v_t"].device.type == "cpu"
def test_debug_step_to_dict_with_none_guidance_weight():
"""Test to_dict handles None guidance_weight correctly."""
step = DebugStep(step_idx=0, time=1.0, guidance_weight=None)
result = step.to_dict(include_tensors=False)
assert result["guidance_weight"] is None
def test_tracker_initialization_enabled():
"""Test tracker initialization when enabled."""
tracker = Tracker(enabled=True, maxlen=50)
assert tracker.enabled is True
assert tracker._steps == {}
assert tracker._maxlen == 50
assert tracker._step_counter == 0
assert len(tracker) == 0
def test_tracker_reset_when_enabled(enabled_tracker, sample_tensors):
"""Test reset clears all steps when tracker is enabled."""
# Add some steps
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 2
# Reset
enabled_tracker.reset()
assert len(enabled_tracker) == 0
assert enabled_tracker._step_counter == 0
assert enabled_tracker._steps == {}
def test_tracker_reset_when_disabled(disabled_tracker):
"""Test reset on disabled tracker doesn't cause errors."""
disabled_tracker.reset()
assert len(disabled_tracker) == 0
# ====================== Tracker.track() Tests ======================
def test_track_creates_new_step(enabled_tracker, sample_tensors):
"""Test that track creates a new step when time doesn't exist."""
enabled_tracker.track(
time=1.0,
x_t=sample_tensors["x_t"],
v_t=sample_tensors["v_t"],
guidance_weight=5.0,
inference_delay=4,
execution_horizon=8,
)
assert len(enabled_tracker) == 1
steps = enabled_tracker.get_all_steps()
assert len(steps) == 1
assert steps[0].step_idx == 0
assert steps[0].time == 1.0
assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
assert steps[0].guidance_weight == 5.0
assert steps[0].inference_delay == 4
assert steps[0].execution_horizon == 8
def test_track_updates_existing_step(enabled_tracker, sample_tensors):
"""Test that track updates an existing step at the same time."""
# Create initial step
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 1
steps = enabled_tracker.get_all_steps()
assert steps[0].v_t is None
# Update the same timestep with v_t
enabled_tracker.track(time=0.9, v_t=sample_tensors["v_t"])
assert len(enabled_tracker) == 1 # Still only one step
steps = enabled_tracker.get_all_steps()
assert torch.equal(steps[0].x_t, sample_tensors["x_t"]) # Original x_t preserved
assert torch.equal(steps[0].v_t, sample_tensors["v_t"]) # New v_t added
def test_track_with_tensor_time(enabled_tracker, sample_tensors):
"""Test track handles tensor time values correctly."""
time_tensor = torch.tensor(0.8)
enabled_tracker.track(time=time_tensor, x_t=sample_tensors["x_t"])
steps = enabled_tracker.get_all_steps()
assert len(steps) == 1
assert abs(steps[0].time - 0.8) < 1e-6 # Use approximate comparison for floating point
def test_track_time_rounding(enabled_tracker, sample_tensors):
"""Test that track rounds time to avoid floating point precision issues."""
# These times should be treated as the same after rounding to 6 decimals
enabled_tracker.track(time=0.9000001, x_t=sample_tensors["x_t"])
enabled_tracker.track(time=0.9000002, v_t=sample_tensors["v_t"])
# Should still be one step (times rounded to same value)
assert len(enabled_tracker) == 1
steps = enabled_tracker.get_all_steps()
assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
def test_track_does_nothing_when_disabled(disabled_tracker, sample_tensors):
"""Test that track does nothing when tracker is disabled."""
disabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
assert len(disabled_tracker) == 0
def test_track_with_metadata(enabled_tracker, sample_tensors):
"""Test track stores custom metadata."""
enabled_tracker.track(time=0.7, x_t=sample_tensors["x_t"], custom_field="custom_value", count=42)
steps = enabled_tracker.get_all_steps()
assert steps[0].metadata["custom_field"] == "custom_value"
assert steps[0].metadata["count"] == 42
def test_track_updates_metadata(enabled_tracker):
"""Test that track updates metadata for existing steps."""
enabled_tracker.track(time=0.6, meta1="value1")
enabled_tracker.track(time=0.6, meta2="value2")
steps = enabled_tracker.get_all_steps()
assert steps[0].metadata["meta1"] == "value1"
assert steps[0].metadata["meta2"] == "value2"
def test_track_clones_tensors(enabled_tracker, sample_tensors):
"""Test that track clones tensors instead of storing references."""
x_t_original = sample_tensors["x_t"].clone()
enabled_tracker.track(time=0.5, x_t=sample_tensors["x_t"])
# Modify original tensor
sample_tensors["x_t"].fill_(999.0)
# Tracked tensor should not be affected
steps = enabled_tracker.get_all_steps()
assert not torch.equal(steps[0].x_t, sample_tensors["x_t"])
assert torch.equal(steps[0].x_t, x_t_original)
def test_track_with_none_values(enabled_tracker):
"""Test track handles None values correctly."""
enabled_tracker.track(
time=0.4,
x_t=None,
v_t=None,
guidance_weight=None,
inference_delay=None,
)
steps = enabled_tracker.get_all_steps()
assert len(steps) == 1
assert steps[0].x_t is None
assert steps[0].v_t is None
assert steps[0].guidance_weight is None
assert steps[0].inference_delay is None
def test_track_updates_only_non_none_fields(enabled_tracker, sample_tensors):
"""Test that update preserves existing values when None is passed."""
# Create step with x_t
enabled_tracker.track(time=0.3, x_t=sample_tensors["x_t"], guidance_weight=2.0)
# Update with v_t only (pass None for other fields)
enabled_tracker.track(time=0.3, v_t=sample_tensors["v_t"], x_t=None, guidance_weight=None)
# Original values should be preserved
steps = enabled_tracker.get_all_steps()
assert torch.equal(steps[0].x_t, sample_tensors["x_t"]) # Still has x_t
assert torch.equal(steps[0].v_t, sample_tensors["v_t"]) # Now has v_t
assert steps[0].guidance_weight == 2.0 # Still has guidance_weight
# ====================== Tracker.maxlen Tests ======================
def test_tracker_enforces_maxlen():
"""Test that tracker enforces maxlen limit."""
tracker = Tracker(enabled=True, maxlen=3)
# Add 5 steps
for i in range(5):
time = 1.0 - i * 0.1 # 1.0, 0.9, 0.8, 0.7, 0.6
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
# Should only keep the last 3
assert len(tracker) == 3
# Verify oldest steps were removed (should have 0.6, 0.7, 0.8)
steps = tracker.get_all_steps()
times = sorted([step.time for step in steps])
assert times == [0.6, 0.7, 0.8]
def test_tracker_step_idx_increments_despite_maxlen():
"""Test that step_idx continues incrementing even when maxlen is enforced."""
tracker = Tracker(enabled=True, maxlen=2)
# Add 4 steps
for i in range(4):
time = 1.0 - i * 0.1
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
# Should have 2 steps with step_idx 2 and 3 (oldest removed)
steps = sorted(tracker.get_all_steps(), key=lambda s: s.step_idx)
assert len(steps) == 2
assert steps[0].step_idx == 2
assert steps[1].step_idx == 3
def test_tracker_without_maxlen_keeps_all():
"""Test that tracker without maxlen keeps all steps."""
tracker = Tracker(enabled=True, maxlen=None)
# Add 100 steps
for i in range(100):
time = 1.0 - i * 0.01
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
assert len(tracker) == 100
def test_get_all_steps_returns_empty_when_disabled(disabled_tracker):
"""Test get_all_steps returns empty list when disabled."""
steps = disabled_tracker.get_all_steps()
assert steps == []
assert isinstance(steps, list)
def test_get_all_steps_returns_empty_when_no_steps(enabled_tracker):
"""Test get_all_steps returns empty list when no steps tracked."""
steps = enabled_tracker.get_all_steps()
assert steps == []
def test_get_all_steps_returns_all_tracked_steps(enabled_tracker, sample_tensors):
"""Test get_all_steps returns all tracked steps."""
# Track 5 steps
for i in range(5):
time = 1.0 - i * 0.1
enabled_tracker.track(time=time, x_t=sample_tensors["x_t"])
steps = enabled_tracker.get_all_steps()
assert len(steps) == 5
# Verify all are DebugStep instances
for step in steps:
assert isinstance(step, DebugStep)
def test_get_all_steps_preserves_insertion_order(enabled_tracker):
"""Test that get_all_steps preserves insertion order (Python 3.7+)."""
times = [0.9, 0.8, 0.7, 0.6, 0.5]
for time in times:
enabled_tracker.track(time=time, x_t=torch.randn(1, 10, 6))
steps = enabled_tracker.get_all_steps()
retrieved_times = [step.time for step in steps]
# Should be in insertion order
assert retrieved_times == times
# ====================== Tracker.__len__() Tests ======================
def test_len_returns_zero_when_disabled(disabled_tracker):
"""Test __len__ returns 0 when tracker is disabled."""
assert len(disabled_tracker) == 0
def test_len_returns_zero_when_empty(enabled_tracker):
"""Test __len__ returns 0 when no steps are tracked."""
assert len(enabled_tracker) == 0
def test_len_returns_correct_count(enabled_tracker, sample_tensors):
"""Test __len__ returns correct number of tracked steps."""
assert len(enabled_tracker) == 0
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 1
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 2
enabled_tracker.track(time=0.8, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 3
def test_len_after_reset(enabled_tracker, sample_tensors):
"""Test __len__ returns 0 after reset."""
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 2
enabled_tracker.reset()
assert len(enabled_tracker) == 0
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tracker_handles_gpu_tensors():
"""Test tracker correctly handles GPU tensors."""
tracker = Tracker(enabled=True, maxlen=10)
x_t_gpu = torch.randn(1, 50, 6, device="cuda")
tracker.track(time=1.0, x_t=x_t_gpu)
steps = tracker.get_all_steps()
    # The tracker clones/detaches tensors; cloning preserves the device
assert steps[0].x_t.device.type == "cuda"
def test_tracker_with_varying_tensor_shapes(enabled_tracker):
"""Test tracker handles varying tensor shapes across steps."""
enabled_tracker.track(time=1.0, x_t=torch.randn(1, 50, 6))
enabled_tracker.track(time=0.9, x_t=torch.randn(1, 25, 6))
enabled_tracker.track(time=0.8, x_t=torch.randn(2, 50, 8))
steps = enabled_tracker.get_all_steps()
assert len(steps) == 3
assert steps[0].x_t.shape == (1, 50, 6)
assert steps[1].x_t.shape == (1, 25, 6)
assert steps[2].x_t.shape == (2, 50, 8)
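The tracker tests above describe a dict keyed by (rounded) time with insertion-order eviction: `track()` merges near-identical float times, `None` never overwrites a stored field, and `step_idx` keeps incrementing even after `maxlen` eviction. A pure-Python sketch of that bookkeeping, with the tensor fields collapsed into `metadata` for brevity (the real `DebugStep` has dedicated `x_t`/`v_t`/... attributes and clones tensors on store):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DebugStep:
    step_idx: int = 0
    time: Optional[float] = None
    # Tensor fields (x_t, v_t, x1_t, correction, err, weights, ...) are
    # folded into `metadata` in this sketch.
    metadata: dict = field(default_factory=dict)


class Tracker:
    def __init__(self, enabled: bool = False, maxlen: Optional[int] = None):
        self.enabled = enabled
        self._steps = {}          # time -> DebugStep, insertion-ordered
        self._maxlen = maxlen
        self._step_counter = 0

    def track(self, time, **fields):
        if not self.enabled:
            return
        key = round(float(time), 6)  # merge near-identical float times
        if key not in self._steps:
            if self._maxlen is not None and len(self._steps) >= self._maxlen:
                # Evict the oldest entry; dicts preserve insertion order.
                del self._steps[next(iter(self._steps))]
            self._steps[key] = DebugStep(step_idx=self._step_counter, time=key)
            self._step_counter += 1
        # None values never overwrite previously stored fields.
        self._steps[key].metadata.update({k: v for k, v in fields.items() if v is not None})

    def get_all_steps(self):
        return list(self._steps.values()) if self.enabled else []

    def reset(self):
        self._steps.clear()
        self._step_counter = 0

    def __len__(self):
        return len(self._steps) if self.enabled else 0
```

This reproduces the maxlen behavior asserted above: with `maxlen=2` and four tracked times, two steps survive with `step_idx` 2 and 3.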
@@ -0,0 +1,322 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC LatencyTracker module."""
import pytest
from lerobot.policies.rtc.latency_tracker import LatencyTracker
# ====================== Fixtures ======================
@pytest.fixture
def tracker():
"""Create a LatencyTracker with default maxlen."""
return LatencyTracker(maxlen=100)
@pytest.fixture
def small_tracker():
"""Create a LatencyTracker with small maxlen for overflow testing."""
return LatencyTracker(maxlen=5)
# ====================== Initialization Tests ======================
def test_latency_tracker_initialization():
"""Test LatencyTracker initializes correctly."""
tracker = LatencyTracker(maxlen=50)
assert len(tracker) == 0
assert tracker.max_latency == 0.0
assert tracker.max() == 0.0
def test_latency_tracker_default_maxlen():
"""Test LatencyTracker uses default maxlen."""
tracker = LatencyTracker()
# Should accept default maxlen=100
assert len(tracker) == 0
# ====================== add() Tests ======================
def test_add_single_latency(tracker):
"""Test adding a single latency value."""
tracker.add(0.5)
assert len(tracker) == 1
assert tracker.max() == 0.5
def test_add_multiple_latencies(tracker):
"""Test adding multiple latency values."""
latencies = [0.1, 0.5, 0.3, 0.8, 0.2]
for lat in latencies:
tracker.add(lat)
assert len(tracker) == 5
assert tracker.max() == 0.8
def test_add_negative_latency_ignored(tracker):
"""Test that negative latencies are ignored."""
tracker.add(0.5)
tracker.add(-0.1)
tracker.add(0.3)
# Should only have 2 valid latencies
assert len(tracker) == 2
assert tracker.max() == 0.5
def test_add_zero_latency(tracker):
"""Test adding zero latency."""
tracker.add(0.0)
assert len(tracker) == 1
assert tracker.max() == 0.0
def test_add_converts_to_float(tracker):
"""Test add() converts input to float."""
tracker.add(5) # Integer
tracker.add("3.5") # String
assert len(tracker) == 2
assert tracker.max() == 5.0
def test_add_updates_max_latency(tracker):
"""Test that max_latency is updated correctly."""
tracker.add(0.5)
assert tracker.max_latency == 0.5
tracker.add(0.3)
assert tracker.max_latency == 0.5 # Should not decrease
tracker.add(0.9)
assert tracker.max_latency == 0.9 # Should increase
# ====================== reset() Tests ======================
def test_reset_clears_values(tracker):
"""Test reset() clears all values."""
tracker.add(0.5)
tracker.add(0.8)
tracker.add(0.3)
assert len(tracker) == 3
tracker.reset()
assert len(tracker) == 0
assert tracker.max_latency == 0.0
def test_reset_clears_max_latency(tracker):
"""Test reset() resets max_latency."""
tracker.add(1.5)
assert tracker.max_latency == 1.5
tracker.reset()
assert tracker.max_latency == 0.0
def test_reset_allows_new_values(tracker):
"""Test that tracker works correctly after reset."""
tracker.add(0.5)
tracker.reset()
tracker.add(0.3)
assert len(tracker) == 1
assert tracker.max() == 0.3
# ====================== max() Tests ======================
def test_max_returns_zero_when_empty(tracker):
"""Test max() returns 0.0 when tracker is empty."""
assert tracker.max() == 0.0
def test_max_returns_maximum_value(tracker):
"""Test max() returns the maximum latency."""
latencies = [0.2, 0.8, 0.3, 0.5, 0.1]
for lat in latencies:
tracker.add(lat)
assert tracker.max() == 0.8
def test_max_persists_after_sliding_window(small_tracker):
"""Test max() persists even after values slide out of window."""
# Add values that will exceed maxlen=5
small_tracker.add(0.1)
small_tracker.add(0.9) # This is max
small_tracker.add(0.2)
small_tracker.add(0.3)
small_tracker.add(0.4)
small_tracker.add(0.5) # This pushes out 0.1
# Max should still be 0.9 even though only last 5 values kept
assert small_tracker.max() == 0.9
def test_max_after_reset(tracker):
"""Test max() returns 0.0 after reset."""
tracker.add(1.5)
tracker.reset()
assert tracker.max() == 0.0
# ====================== p95() Tests ======================
def test_p95_returns_zero_when_empty(tracker):
"""Test p95() returns 0.0 when tracker is empty."""
assert tracker.p95() == 0.0
def test_p95_returns_95th_percentile(tracker):
"""Test p95() returns the 95th percentile."""
# Add 100 values
for i in range(100):
tracker.add(i / 100.0)
p95 = tracker.p95()
assert 0.93 <= p95 <= 0.96
def test_p95_equals_percentile_95(tracker):
"""Test p95() equals percentile(0.95)."""
for i in range(50):
tracker.add(i / 50.0)
assert tracker.p95() == tracker.percentile(0.95)
# ====================== Edge Cases Tests ======================
def test_single_value(tracker):
"""Test tracker behavior with single value."""
tracker.add(0.75)
assert len(tracker) == 1
assert tracker.max() == 0.75
assert tracker.percentile(0.0) == 0.75
assert tracker.percentile(0.5) == 0.75
assert tracker.percentile(1.0) == 0.75
def test_all_same_values(tracker):
"""Test tracker with all identical values."""
for _ in range(10):
tracker.add(0.5)
assert len(tracker) == 10
assert tracker.max() == 0.5
assert tracker.percentile(0.0) == 0.5
assert tracker.percentile(0.5) == 0.5
assert tracker.percentile(1.0) == 0.5
def test_very_small_values(tracker):
"""Test tracker with very small float values."""
tracker.add(1e-10)
tracker.add(2e-10)
tracker.add(3e-10)
assert len(tracker) == 3
assert tracker.max() == pytest.approx(3e-10)
def test_very_large_values(tracker):
"""Test tracker with very large float values."""
tracker.add(1e10)
tracker.add(2e10)
tracker.add(3e10)
assert len(tracker) == 3
assert tracker.max() == pytest.approx(3e10)
# ====================== Integration Tests ======================
def test_typical_usage_pattern(tracker):
"""Test a typical usage pattern of the tracker."""
# Simulate adding latencies over time
latencies = [0.05, 0.08, 0.12, 0.07, 0.15, 0.09, 0.11, 0.06, 0.14, 0.10]
for lat in latencies:
tracker.add(lat)
# Check statistics
assert len(tracker) == 10
assert tracker.max() == 0.15
# p95 should be close to max since we have only 10 values
p95 = tracker.p95()
assert p95 >= tracker.percentile(0.5) # p95 should be >= median
assert p95 <= tracker.max() # p95 should be <= max
def test_reset_and_reuse(tracker):
"""Test resetting and reusing tracker."""
# First batch
tracker.add(1.0)
tracker.add(2.0)
assert tracker.max() == 2.0
# Reset
tracker.reset()
# Second batch
tracker.add(0.5)
tracker.add(0.8)
assert len(tracker) == 2
assert tracker.max() == 0.8
assert tracker.percentile(0.5) <= 0.8
# ====================== Type Conversion Tests ======================
def test_add_with_integer(tracker):
"""Test adding integer values."""
tracker.add(5)
assert len(tracker) == 1
assert tracker.max() == 5.0
def test_add_with_string_number(tracker):
"""Test adding string representation of number."""
tracker.add("3.14")
assert len(tracker) == 1
assert tracker.max() == pytest.approx(3.14)
def test_percentile_converts_q_to_float(tracker):
"""Test percentile converts q parameter to float."""
tracker.add(0.5)
tracker.add(0.8)
# Pass integer q
result = tracker.percentile(1)
assert result == 0.8
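The full suite above pins down `LatencyTracker`'s contract: a bounded sliding window of latencies, negative values silently dropped, inputs coerced with `float()`, and a `max()` that persists even after the maximizing value slides out of the window. A pure-Python sketch satisfying those assertions (the percentile method here uses a simple sorted-index scheme; the real implementation may interpolate differently):

```python
from collections import deque


class LatencyTracker:
    """Sketch of the sliding-window latency tracker the tests describe."""

    def __init__(self, maxlen: int = 100):
        self._values = deque(maxlen=maxlen)
        self.max_latency = 0.0  # persists across window eviction

    def add(self, latency):
        latency = float(latency)  # accepts ints and numeric strings
        if latency < 0:
            return  # negative latencies are ignored
        self._values.append(latency)
        self.max_latency = max(self.max_latency, latency)

    def reset(self):
        self._values.clear()
        self.max_latency = 0.0

    def max(self):
        return self.max_latency

    def percentile(self, q):
        if not self._values:
            return 0.0
        q = float(q)
        data = sorted(self._values)
        idx = min(int(q * len(data)), len(data) - 1)
        return data[idx]

    def p95(self):
        return self.percentile(0.95)

    def __len__(self):
        return len(self._values)
```

Note how `max_latency` is tracked separately from the deque, which is what makes `test_max_persists_after_sliding_window` pass: the deque forgets `0.9`, the scalar does not.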
@@ -0,0 +1,773 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC modeling module (RTCProcessor)."""
import pytest
import torch
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
# ====================== Fixtures ======================
@pytest.fixture
def rtc_config_debug_enabled():
"""Create RTC config with debug enabled."""
return RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=10.0,
execution_horizon=10,
debug=True,
debug_maxlen=100,
)
@pytest.fixture
def rtc_config_debug_disabled():
"""Create RTC config with debug disabled."""
return RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=10.0,
execution_horizon=10,
debug=False,
)
@pytest.fixture
def rtc_processor_debug_enabled(rtc_config_debug_enabled):
"""Create RTCProcessor with debug enabled."""
return RTCProcessor(rtc_config_debug_enabled)
@pytest.fixture
def rtc_processor_debug_disabled(rtc_config_debug_disabled):
"""Create RTCProcessor with debug disabled."""
return RTCProcessor(rtc_config_debug_disabled)
@pytest.fixture
def sample_x_t():
"""Create sample x_t tensor (batch, time, action_dim)."""
return torch.randn(1, 50, 6)
@pytest.fixture
def sample_prev_chunk():
"""Create sample previous chunk tensor."""
return torch.randn(1, 50, 6)
# ====================== Initialization Tests ======================
def test_rtc_processor_initialization_with_debug(rtc_config_debug_enabled):
"""Test RTCProcessor initializes with debug tracker."""
processor = RTCProcessor(rtc_config_debug_enabled)
assert processor.rtc_config == rtc_config_debug_enabled
assert processor.tracker is not None
assert processor.tracker.enabled is True
def test_rtc_processor_initialization_without_debug(rtc_config_debug_disabled):
"""Test RTCProcessor initializes without debug tracker."""
processor = RTCProcessor(rtc_config_debug_disabled)
assert processor.rtc_config == rtc_config_debug_disabled
assert processor.tracker is None
# ====================== Tracker Proxy Methods Tests ======================
def test_track_when_tracker_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test track() forwards to tracker when enabled."""
rtc_processor_debug_enabled.track(
time=torch.tensor(0.5),
x_t=sample_x_t,
v_t=sample_x_t,
guidance_weight=2.0,
)
# Should have tracked one step
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 1
assert steps[0].time == 0.5
def test_track_when_tracker_disabled(rtc_processor_debug_disabled, sample_x_t):
"""Test track() does nothing when tracker disabled."""
# Should not raise error
rtc_processor_debug_disabled.track(
time=torch.tensor(0.5),
x_t=sample_x_t,
v_t=sample_x_t,
)
# Should return empty list
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert len(steps) == 0
def test_get_all_debug_steps_when_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test get_all_debug_steps() returns tracked steps."""
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 2
def test_get_all_debug_steps_when_disabled(rtc_processor_debug_disabled):
"""Test get_all_debug_steps() returns empty list when disabled."""
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert steps == []
assert isinstance(steps, list)
def test_is_debug_enabled_when_tracker_exists(rtc_processor_debug_enabled):
"""Test is_debug_enabled() returns True when tracker enabled."""
assert rtc_processor_debug_enabled.is_debug_enabled() is True
def test_is_debug_enabled_when_tracker_disabled(rtc_processor_debug_disabled):
"""Test is_debug_enabled() returns False when tracker disabled."""
assert rtc_processor_debug_disabled.is_debug_enabled() is False
def test_reset_tracker_when_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test reset_tracker() clears tracked steps."""
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 2
rtc_processor_debug_enabled.reset_tracker()
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 0
def test_reset_tracker_when_disabled(rtc_processor_debug_disabled):
"""Test reset_tracker() doesn't error when tracker disabled."""
rtc_processor_debug_disabled.reset_tracker() # Should not raise
# ====================== get_prefix_weights Tests ======================
def test_get_prefix_weights_zeros_schedule():
"""Test get_prefix_weights with ZEROS schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=10, total=20)
# First 5 should be 1.0, rest should be 0.0
assert weights.shape == (20,)
assert torch.all(weights[:5] == 1.0)
assert torch.all(weights[5:] == 0.0)
def test_get_prefix_weights_ones_schedule():
"""Test get_prefix_weights with ONES schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
# First 15 should be 1.0, rest should be 0.0
assert weights.shape == (20,)
assert torch.all(weights[:15] == 1.0)
assert torch.all(weights[15:] == 0.0)
def test_get_prefix_weights_linear_schedule():
"""Test get_prefix_weights with LINEAR schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=14, total=25)
# Should have shape (25,)
assert weights.shape == (25,)
# First 5 should be 1.0 (leading ones)
assert torch.all(weights[:5] == 1.0)
# Middle section (5:14) should ramp linearly from 1 toward 0 (endpoints excluded)
middle_weights = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
assert torch.allclose(weights[5:14], middle_weights)
# Last 5 should be 0.0 (trailing zeros)
assert torch.all(weights[14:] == 0.0)
def test_get_prefix_weights_exp_schedule():
"""Test get_prefix_weights with EXP schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=14, total=25)
# Should have shape (25,)
assert weights.shape == (25,)
# First 5 should be 1.0 (leading ones)
assert torch.all(weights[:5] == 1.0)
# Middle section should be exponentially weighted
middle_weights = torch.tensor([0.7645, 0.5706, 0.4130, 0.2871, 0.1888, 0.1145, 0.0611, 0.0258, 0.0061])
assert torch.allclose(weights[5:14], middle_weights, atol=1e-4)
# Last 5 should be 0.0 (trailing zeros)
assert torch.all(weights[14:] == 0.0)
def test_get_prefix_weights_with_start_equals_end():
"""Test get_prefix_weights when start equals end."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=10, end=10, total=20)
# Should have ones up to start, then zeros
assert torch.all(weights[:10] == 1.0)
assert torch.all(weights[10:] == 0.0)
def test_get_prefix_weights_with_start_greater_than_end():
"""Test get_prefix_weights when start > end (gets clamped)."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
# start > end should use min(start, end) = end
weights = processor.get_prefix_weights(start=15, end=10, total=20)
# Should have ones up to end (10), then zeros
assert torch.all(weights[:10] == 1.0)
assert torch.all(weights[10:] == 0.0)
# ====================== Helper Method Tests ======================
def test_linweights_with_end_equals_start():
"""Test _linweights when end equals start."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = processor._linweights(start=10, end=10, total=20)
# Should return empty tensor
assert len(weights) == 0
def test_linweights_with_end_less_than_start():
"""Test _linweights when end < start."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = processor._linweights(start=15, end=10, total=20)
# Should return empty tensor
assert len(weights) == 0
def test_add_trailing_zeros_normal():
"""Test _add_trailing_zeros adds zeros correctly."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2])
result = processor._add_trailing_zeros(weights, total=10, end=5)
# Should add 5 zeros (total - end = 10 - 5 = 5)
assert len(result) == 10
assert torch.all(result[:5] == weights)
assert torch.all(result[5:] == 0.0)
def test_add_trailing_zeros_no_zeros_needed():
"""Test _add_trailing_zeros when no zeros needed."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([1.0, 0.8, 0.6])
result = processor._add_trailing_zeros(weights, total=3, end=5)
# zeros_len = 3 - 5 = -2 <= 0, so no zeros added
assert torch.equal(result, weights)
def test_add_leading_ones_normal():
"""Test _add_leading_ones adds ones correctly."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([0.8, 0.6, 0.4, 0.2, 0.0])
result = processor._add_leading_ones(weights, start=3, total=10)
# Should add 3 ones at the start
assert len(result) == 8
assert torch.all(result[:3] == 1.0)
assert torch.all(result[3:] == weights)
def test_add_leading_ones_no_ones_needed():
"""Test _add_leading_ones when no ones needed."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([0.8, 0.6, 0.4])
result = processor._add_leading_ones(weights, start=0, total=10)
# ones_len = 0, so no ones added
assert torch.equal(result, weights)
def test_get_prefix_weights_with_start_equals_total():
"""Test get_prefix_weights when start equals total."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=20, end=20, total=20)
# start == total: every position falls in the leading-ones region
assert len(weights) == 20
assert torch.all(weights == 1.0)
def test_get_prefix_weights_with_total_less_than_start():
"""Test get_prefix_weights when total less than start."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=10, end=10, total=5)
# Should have ones up to start, then zeros
assert len(weights) == 5
assert torch.all(weights == 1.0)
# ====================== denoise_step Tests ======================
def test_denoise_step_without_prev_chunk(rtc_processor_debug_disabled):
"""Test denoise_step without previous chunk (no guidance)."""
x_t = torch.randn(1, 50, 6)
# Mock denoiser that returns fixed velocity
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=None,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should return v_t unchanged (no guidance)
expected = mock_denoiser(x_t)
assert torch.allclose(result, expected)
def test_denoise_step_with_prev_chunk(rtc_processor_debug_disabled):
"""Test denoise_step with previous chunk applies guidance."""
x_t = torch.ones(1, 20, 1)
prev_chunk = torch.full((1, 20, 1), 0.1)
def mock_denoiser(x):
return x * 0.5
result = rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
expected_result = torch.tensor(
[
[
[1.8000],
[1.8000],
[1.8000],
[1.8000],
[1.8000],
[1.5833],
[1.3667],
[1.1500],
[0.9333],
[0.7167],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
]
]
)
assert torch.allclose(result, expected_result, atol=1e-4)
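The expected 1.8 values can be reproduced by hand under one plausible reading of the guidance rule (a reconstruction consistent with the asserted numbers, not necessarily the actual implementation): with tau = 1 - time, guidance weight c = (1 - tau)/tau, clean-chunk estimate x1_hat = x_t - tau * v, and corrected velocity v' = v + c * w * (x1_hat - prev) / tau. For the first five positions (w = 1): tau = 0.5, c = 1, x1_hat = 1 - 0.25 = 0.75, so v' = 0.5 + (0.75 - 0.1)/0.5 = 1.8.

```python
def guided_velocity(v, x_t, prev, w, time, max_c=10.0):
    """Hypothetical reconstruction of the guidance step, consistent with
    the values asserted above (assumes 0 < time < 1 so tau > 0)."""
    tau = 1.0 - time                   # tau = 1 - time, per the comments in these tests
    c = min((1.0 - tau) / tau, max_c)  # guidance weight, clamped
    x1_hat = x_t - tau * v             # one-step estimate of the clean chunk
    return v + c * w * (x1_hat - prev) / tau
```

With w = 5/6 (the first ramp position) this yields 0.5 + (5/6) * 1.3 = 1.5833, matching the sixth entry of the expected tensor.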
def test_denoise_step_adds_batch_dimension():
"""Test denoise_step handles 2D input by adding batch dimension."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
# 2D input (no batch dimension)
x_t = torch.randn(10, 6)
prev_chunk = torch.randn(5, 6)
def mock_denoiser(x):
return x * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Output should be 2D (batch dimension removed)
assert result.ndim == 2
assert result.shape == (10, 6)
def test_denoise_step_uses_custom_execution_horizon():
"""Test denoise_step uses custom execution_horizon parameter."""
config = RTCConfig(execution_horizon=10)
processor = RTCProcessor(config)
x_t = torch.ones(1, 20, 1)
prev_chunk = torch.full((1, 15, 1), 0.1)
def mock_denoiser(x):
return x * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
execution_horizon=15,
)
expected_result = torch.tensor(
[
[
[1.8000],
[1.8000],
[1.8000],
[1.8000],
[1.8000],
[1.6818],
[1.5636],
[1.4455],
[1.3273],
[1.2091],
[1.0909],
[0.9727],
[0.8545],
[0.7364],
[0.6182],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
]
]
)
assert torch.allclose(result, expected_result, atol=1e-4)
def test_denoise_step_guidance_weight_at_time_zero():
"""Test denoise_step handles time=0 (tau=1) without NaN/Inf."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
x_t = torch.ones(1, 20, 1)
prev_chunk = torch.full((1, 20, 1), 0.1)
def mock_denoiser(x):
return x * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.0),
original_denoise_step_partial=mock_denoiser,
)
expected_result = torch.tensor(
[
[
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
]
]
)
assert torch.allclose(result, expected_result, atol=1e-4)
def test_denoise_step_with_real_denoise_step_partial():
"""Test denoise_step with a real denoiser."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
batch_size = 10
action_dim = 6
chunk_size = 20
x_t = torch.ones(batch_size, chunk_size, action_dim)
prev_chunk = torch.full((batch_size, chunk_size, action_dim), 0.1)
velocity_function = torch.nn.Sequential(
torch.nn.Linear(action_dim, 1000),
torch.nn.ReLU(),
torch.nn.Linear(1000, 256),
torch.nn.ReLU(),
torch.nn.Linear(256, action_dim),
)
def mock_denoiser(x):
return velocity_function(x)
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
assert result.shape == (batch_size, chunk_size, action_dim)
def test_denoise_step_guidance_weight_at_time_one():
"""Test denoise_step handles time=1 (tau=0) with max_guidance_weight clamping."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
# Time = 1 => tau = 0, c = (1-tau)/tau = 1/0 = inf (clamped to max_guidance_weight)
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(1.0),
original_denoise_step_partial=mock_denoiser,
)
# Should clamp to max_guidance_weight (no Inf)
assert not torch.any(torch.isinf(result))
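The two time-limit tests bracket the guidance weight: at time=0 (tau=1) it vanishes, while at time=1 (tau=0) the raw value (1 - tau)/tau diverges and must be clamped to `max_guidance_weight`. A small sketch of such a clamp (hypothetical helper, assuming tau = 1 - time as in the comment above):

```python
def clamped_guidance_weight(time: float, max_guidance_weight: float) -> float:
    # tau = 1 - time; c = (1 - tau) / tau, with the tau = 0 pole clamped
    tau = 1.0 - time
    if tau <= 0.0:
        return max_guidance_weight
    return min((1.0 - tau) / tau, max_guidance_weight)
```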
def test_denoise_step_tracks_debug_info(rtc_processor_debug_enabled):
"""Test denoise_step tracks debug information when enabled."""
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
rtc_processor_debug_enabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should have tracked one step
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 1
# Check tracked values
step = steps[0]
assert step.time == 0.5
assert step.x1_t is not None
assert step.correction is not None
assert step.err is not None
assert step.weights is not None
assert step.guidance_weight is not None
assert step.inference_delay == 5
def test_denoise_step_doesnt_track_without_debug(rtc_processor_debug_disabled):
"""Test denoise_step doesn't track when debug disabled."""
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should not track
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert len(steps) == 0
# ====================== Integration Tests ======================
def test_denoise_step_full_workflow():
"""Test complete denoise_step workflow."""
config = RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=5.0,
execution_horizon=10,
debug=True,
)
processor = RTCProcessor(config)
# Simulate two denoising steps
x_t1 = torch.randn(1, 50, 6)
x_t2 = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.randn_like(x) * 0.1
# First step - no guidance
result1 = processor.denoise_step(
x_t=x_t1,
prev_chunk_left_over=None,
inference_delay=5,
time=torch.tensor(0.8),
original_denoise_step_partial=mock_denoiser,
)
# Second step - with guidance
result2 = processor.denoise_step(
x_t=x_t2,
prev_chunk_left_over=result1,
inference_delay=5,
time=torch.tensor(0.6),
original_denoise_step_partial=mock_denoiser,
)
# Both should complete successfully
assert result1.shape == (1, 50, 6)
assert result2.shape == (1, 50, 6)
# Should have tracked one step (second one, first had no prev_chunk)
steps = processor.get_all_debug_steps()
assert len(steps) == 1
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_denoise_step_with_cuda_tensors():
"""Test denoise_step works with CUDA tensors."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6, device="cuda")
prev_chunk = torch.randn(1, 50, 6, device="cuda")
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Result should be on CUDA
assert result.device.type == "cuda"
assert result.shape == x_t.shape
def test_denoise_step_deterministic_with_same_inputs():
"""Test denoise_step produces same output with same inputs."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
torch.manual_seed(42)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def deterministic_denoiser(x):
return torch.ones_like(x) * 0.5
result1 = processor.denoise_step(
x_t=x_t.clone(),
prev_chunk_left_over=prev_chunk.clone(),
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=deterministic_denoiser,
)
result2 = processor.denoise_step(
x_t=x_t.clone(),
prev_chunk_left_over=prev_chunk.clone(),
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=deterministic_denoiser,
)
# Should produce identical results
assert torch.allclose(result1, result2)
@@ -0,0 +1,323 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test SmolVLA policy with Real-Time Chunking (RTC) enabled during inference."""
import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.factory import make_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig  # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda, require_package # noqa: E402
@require_package("transformers")
@require_cuda
def test_smolvla_rtc_initialization():
"""Test SmolVLA policy can initialize RTC processor."""
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Instantiate policy
policy = SmolVLAPolicy(config)
# Verify RTC processor is initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is not None
assert policy.rtc_processor.rtc_config.enabled is True
print("✓ SmolVLA RTC initialization: Test passed")
@require_package("transformers")
@require_cuda
def test_smolvla_rtc_initialization_without_rtc_config():
"""Test SmolVLA policy can initialize without RTC config."""
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Instantiate policy
policy = SmolVLAPolicy(config)
# Verify RTC processor is not initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is None
assert policy.model.rtc_processor is None
assert policy._rtc_enabled() is False
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
@require_package("transformers")
@require_cuda
@pytest.mark.skip(reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_with_prev_chunk():
"""Test SmolVLA policy inference with RTC and previous chunk."""
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and create preprocessor
policy = SmolVLAPolicy(config)
policy.eval()
preprocessor, _ = make_pre_post_processors(
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC and previous chunk
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=4,
execution_horizon=10,
)
# Test without RTC for comparison
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Verify shapes
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
# With previous chunk, actions should be different (RTC guidance applied)
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
@require_package("transformers")
@require_cuda
@pytest.mark.skip(reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_without_prev_chunk():
"""Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect)."""
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and create preprocessor
policy = SmolVLAPolicy(config)
policy.eval()
preprocessor, _ = make_pre_post_processors(
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC enabled but no previous chunk
actions_with_rtc_no_prev = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=None,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Without previous chunk, RTC should have no effect
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
@require_package("transformers")
@require_cuda
@pytest.mark.skip(reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_validation_rules():
"""Test SmolVLA policy with RTC follows all three validation rules."""
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and create preprocessor
policy = SmolVLAPolicy(config)
policy.eval()
preprocessor, _ = make_pre_post_processors(
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
inference_delay = 4
execution_horizon = 10
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from lerobot.envs.utils import preprocess_observation
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
seed = 42
np.random.seed(seed)
B = 5
obs1 = {
"pixels": {
"image": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
"image2": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
},
"robot_state": {
"eef": {
"pos": np.random.randn(B, 3),
"quat": np.random.randn(B, 4),
"mat": np.random.randn(B, 3, 3),
},
"gripper": {
"qpos": np.random.randn(B, 2),
"qvel": np.random.randn(B, 2),
},
"joints": {
"pos": np.random.randn(B, 7),
"vel": np.random.randn(B, 7),
},
},
}
observation = preprocess_observation(obs1)
libero_preprocessor = PolicyProcessorPipeline(
steps=[
LiberoProcessorStep(),
]
)
processed_obs = libero_preprocessor(observation)
assert "observation.state" in processed_obs
state = processed_obs["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.dtype == torch.float32
assert state.shape[0] == B
assert state.shape[1] == 8
assert "observation.images.image" in processed_obs
assert "observation.images.image2" in processed_obs
assert isinstance(processed_obs["observation.images.image"], torch.Tensor)
assert isinstance(processed_obs["observation.images.image2"], torch.Tensor)
assert processed_obs["observation.images.image"].shape == (B, 3, 256, 256)
assert processed_obs["observation.images.image2"].shape == (B, 3, 256, 256)