Compare commits

..

65 Commits

Author SHA1 Message Date
Michel Aractingi c868777752 profile 2025-11-18 09:51:50 +01:00
Eugene Mironov 8847e75c55 Extract simulator logic from eval_with_real_robot and add proper headers to files 2025-11-16 19:04:24 +07:00
Eugene Mironov 8429d2ccfa fixup! fixup! Fixup eval with real robot 2025-11-16 18:35:08 +07:00
Eugene Mironov 6794ca2ba8 fixup! Fixup eval with real robot 2025-11-15 00:09:01 +07:00
Eugene Mironov 98c2152f08 Fixup eval with real robot 2025-11-15 00:09:01 +07:00
Eugene Mironov f92999aeb9 fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization 2025-11-15 00:09:01 +07:00
Eugene Mironov 5659c77988 Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov fd88a3acda Fix SmolVLA init_rtc_processor to use getattr instead of direct model access
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 6deabe4b71 Fix PI0.5 init_rtc_processor to use getattr instead of direct model access
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 2f3525c4a2 Add RTC initialization tests without config for PI0.5 and SmolVLA
Add test_pi05_rtc_initialization_without_rtc_config and
test_smolvla_rtc_initialization_without_rtc_config to verify that
policies can initialize without RTC config and that _rtc_enabled()
returns False in this case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov d04061def7 fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() 2025-11-15 00:09:01 +07:00
Eugene Mironov 07ee578c78 fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() 2025-11-15 00:09:01 +07:00
Eugene Mironov 636e2264c3 Fix test to use _rtc_enabled() instead of is_rtc_enabled()
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 5a4c168d92 fixup! Add one more test 2025-11-15 00:09:01 +07:00
Eugene Mironov 047f89cc2a Add one more test 2025-11-15 00:09:01 +07:00
Eugene Mironov 4d64733846 fixup! fixup! Add tests for flow matching models with RTC 2025-11-15 00:09:01 +07:00
Eugene Mironov 0c3ed6ca7a fixup! Add tests for flow matching models with RTC 2025-11-15 00:09:01 +07:00
Eugene Mironov 44322fa726 Add tests for flow matching models with RTC 2025-11-15 00:09:01 +07:00
Eugene Mironov e041634bee Add tests for modeling_rtc 2025-11-15 00:09:01 +07:00
Eugene Mironov 6b6c0623cc Fix tests 2025-11-15 00:09:01 +07:00
Eugene Mironov 6db3afca6f Silent validation 2025-11-15 00:09:01 +07:00
Eugene Mironov 433ccc9603 Update README 2025-11-15 00:09:01 +07:00
Eugene Mironov 9e92337f24 Add validation at the end 2025-11-15 00:09:01 +07:00
Eugene Mironov 99eea2ae03 Add more tests 2025-11-15 00:09:01 +07:00
Eugene Mironov ac33f20e51 Small fixes 2025-11-15 00:09:01 +07:00
Eugene Mironov ab0a9c3d7a Add workable flow 2025-11-15 00:09:01 +07:00
Eugene Mironov 9616c44024 fixup! fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 60b432b0f1 fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 513e6c0046 fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 60362b9c7c fixup! fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 5915649eac fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 675880392d Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov d0123c4178 fixup! Pi0 eval dataset 2025-11-15 00:09:01 +07:00
Eugene Mironov e86afc883e Pi0 eval dataset 2025-11-15 00:09:01 +07:00
Eugene Mironov d10b7787eb Pi0 2025-11-15 00:09:01 +07:00
Eugene Mironov ac1816ee9c Add RTC to PI0 2025-11-15 00:09:01 +07:00
Eugene Mironov 25fb16ea7a Fix compilation 2025-11-15 00:09:01 +07:00
Eugene Mironov 7baf909e32 Debug 2025-11-15 00:09:01 +07:00
Eugene Mironov 79ffe316e4 Experiment with late detach 2025-11-15 00:09:01 +07:00
Eugene Mironov 68b2142bd2 fixup! Add matplotlib to dev 2025-11-15 00:09:01 +07:00
Eugene Mironov a42fb4d0e2 Add matplotlib to dev 2025-11-15 00:09:01 +07:00
Eugene Mironov 83f1de035e delete policies 2025-11-15 00:09:01 +07:00
Eugene Mironov e09a6a90e1 Add torch compilation for eval_dataset 2025-11-15 00:09:01 +07:00
Eugene Mironov 10cc9dd961 Drop not required methods 2025-11-15 00:09:01 +07:00
Eugene Mironov 41b8d4b7c6 Fix tests 2025-11-15 00:09:01 +07:00
Eugene Mironov 7939fc3ddf Add tests for tracker 2025-11-15 00:09:01 +07:00
Eugene Mironov 11b35dfa11 Right kwargs for the policy 2025-11-15 00:09:01 +07:00
Eugene Mironov b27570039c Fix tracking 2025-11-15 00:09:01 +07:00
Eugene Mironov 55c4cc1b27 fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-15 00:09:01 +07:00
Eugene Mironov 3fb3edde3f fixup! fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-15 00:09:01 +07:00
Eugene Mironov 43bf1fb763 fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-15 00:09:01 +07:00
Eugene Mironov c7a26f5070 Improve visualization: separate correction plot and fix axis scaling
Changes:
- Create separate figure for correction data instead of overlaying on v_t
- Add _rescale_axes helper method to properly scale all axes
- Add 10% margin to y-axis for better visualization
- Fix v_t chart vertical compression issue

Benefits:
- Clearer v_t plot without correction overlay
- Better axis scaling with proper margins
- Separate correction figure for focused analysis
- Improved readability of all denoising visualizations

Output files:
- denoising_xt_comparison.png (x_t trajectories)
- denoising_vt_comparison.png (v_t velocity - now cleaner)
- denoising_correction_comparison.png (NEW - separate corrections)
- denoising_x1t_comparison.png (x1_t state with error)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov aaa308b158 fixup! Refactor plotting logging 2025-11-15 00:09:01 +07:00
Eugene Mironov 84df6cd13d Refactor plotting logging 2025-11-15 00:09:01 +07:00
Eugene Mironov 26db4b64d8 Move plotting logic from modeling_smolvla to eval_dataset script
Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 2204a45020 Refactor SmolVLA plotting to use tracker data instead of local variables
Remove local tracking variables (correction, x1_t, error) from the
denoising loop and instead retrieve plotting data from the RTC tracker
after each denoise step. This makes the code cleaner and uses the
tracker as the single source of truth for debug/visualization data.

Changes:
- Remove initialization of correction, x1_t, error before denoising loop
- After each Euler step, retrieve most recent debug step from tracker
- Extract correction, x1_t, err from debug step for plotting
- Update tracking condition to use is_debug_enabled() method

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov b6df884d08 Fix logging buffering and enable tracking when RTC config provided
- Add force=True to logging.basicConfig to override existing configuration
- Enable line buffering for stdout/stderr for real-time log output
- Modify init_rtc_processor to create processor when rtc_config exists
  even if RTC is disabled, allowing tracking of denoising data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov bb23dafad1 fixup! Use output_dir for saving all evaluation images 2025-11-15 00:09:01 +07:00
Eugene Mironov c409ed2d1d Use output_dir for saving all evaluation images
Update eval_dataset.py to save all comparison images to the
configured output_dir instead of the current directory. This provides
better organization and allows users to specify where outputs should be
saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov d20ef2e46e Rename track_debug method to track
Simplify the method name from track_debug to just track for better
readability and consistency. The method already has clear documentation
about its debug tracking purpose.

Changes:
- Rename RTCProcessor.track_debug() to track()
- Update all call sites in modeling_smolvla.py and modeling_rtc.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 05189361b6 Refactor RTC enabled checks to use _rtc_enabled helper
Add _rtc_enabled() helper method in VLAFlowMatching class to simplify
and clean up RTC enabled checks throughout the code. This reduces
code duplication and improves readability.

Changes:
- Add _rtc_enabled() method in VLAFlowMatching
- Replace verbose rtc_config checks with _rtc_enabled() calls
- Maintain exact same functionality with cleaner code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 896779003c Add RTCConfig field to SmolVLAConfig
Add rtc_config as an optional field in SmolVLAConfig to properly
support Real-Time Chunking configuration. This replaces the previous
getattr() workarounds with direct attribute access, making the code
cleaner and more maintainable.

Changes:
- Import RTCConfig in configuration_smolvla.py
- Add rtc_config: RTCConfig | None = None field
- Revert getattr() calls to direct attribute access in modeling_smolvla.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov b55bc62ef0 fixup! Fix rtc_config attribute access in SmolVLA 2025-11-15 00:09:01 +07:00
Eugene Mironov 08ff689a1e Fix rtc_config attribute access in SmolVLA
Use getattr() to safely check for rtc_config attribute existence
instead of direct attribute access. This fixes AttributeError when
loading policies without rtc_config in their config.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 0acdde4ae2 Add Real-Time Chunking (RTC) support for flow matching models
Implement Real-Time Chunking (RTC) for action chunking policies using flow
matching denoising. RTC enables smooth action transitions between consecutive
chunks by using prefix guidance during denoising.

Key features:
- RTCProcessor class with denoise_step method for RTC guidance
- Tracker system for debug tracking using time-based dictionary storage
- RTCDebugVisualizer with comprehensive visualization utilities
- Integration with SmolVLA policy for flow matching models
- Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP)
- Configurable execution horizon and max guidance weight
- Example scripts for dataset evaluation and real-time control

Technical details:
- Uses autograd-based gradient computation for RTC corrections
- Time-based tracking eliminates duplicate step issues
- Proxy methods in RTCProcessor for cleaner API
- Full integration with LeRobot's policy and dataset systems

Files added/modified:
- src/lerobot/configs/types.py: Add RTCAttentionSchedule enum
- src/lerobot/policies/rtc/: Core RTC implementation
  - configuration_rtc.py: RTC configuration
  - modeling_rtc.py: RTCProcessor with denoise_step
  - debug_handler.py: Tracker for debug information
  - debug_visualizer.py: Visualization utilities
- src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration
- examples/rtc/: Example scripts and evaluation tools

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
52 changed files with 3110 additions and 7104 deletions
+2 -8
@@ -15,6 +15,8 @@
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
- local: async
title: Use Async Inference
- local: multi_gpu_training
title: Multi GPU training
title: "Tutorials"
@@ -38,12 +40,6 @@
- local: groot
title: NVIDIA GR00T N1.5
title: "Policies"
- sections:
- local: async
title: Use Async Inference
- local: rtc
title: Real-Time Chunking (RTC)
title: "Inference"
- sections:
- local: envhub
title: Environments from the Hub
@@ -63,8 +59,6 @@
title: Implement your own processor
- local: processors_robots_teleop
title: Processors for Robots and Teleoperators
- local: env_processor
title: Environment Processors
title: "Robot Processors"
- sections:
- local: so101
-418
@@ -1,418 +0,0 @@
# Environment Processors
Environment processors are a critical layer in LeRobot's data processing architecture that handles **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
## Why Environment Processors?
When working with different robot environments (LIBERO, MetaWorld, Aloha, etc.), each environment often has unique data formats, coordinate systems, and conventions that need standardization **before** policy processing. Without environment processors, these transformations would be:
1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
## The Processing Pipeline
Here's how data flows through the complete processing pipeline during evaluation:
```python
# In lerobot_eval.py rollout() function:

# 1. Raw environment observation (numpy arrays, various formats)
raw_observation = env.step(action)

# 2. Convert numpy to torch, normalize images [0,1]
observation = preprocess_observation(raw_observation)

# 3. Add task metadata (for multi-task environments)
observation = add_envs_task(env, observation)

# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
#    - Flatten robot states
#    - Rotate images to match dataset conventions
#    - Handle environment-specific coordinate systems
observation = env_preprocessor(observation)

# 5. POLICY-SPECIFIC preprocessing
#    - Normalize with dataset statistics
#    - Add batch dimensions
#    - Move to GPU
#    - Tokenize language instructions
observation = preprocessor(observation)

# 6. Policy inference
action = policy.select_action(observation)

# 7. POLICY-SPECIFIC postprocessing
#    - Unnormalize actions
#    - Remove batch dimensions
action = postprocessor(action)

# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
#    - Convert action formats if needed
#    - Apply environment-specific constraints
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]

# 9. Execute in environment
env.step(action)
```
## The Benefits
### 1. **Separation of Concerns**
Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
```python
# ❌ Before: Mixed concerns
class LiberoVLAPolicy:
    def preprocess(self, obs):
        # Environment-specific: Flatten robot state (shouldn't be in policy!)
        state = self._flatten_robot_state(obs["robot_state"])
        # Policy-specific: Normalize with dataset stats
        state = self.normalizer(state)
        return state


# ✅ After: Clear separation
# Environment processor: Handles LIBERO's nested robot state
env_preprocessor = LiberoProcessorStep()  # Flattens robot_state

# Policy processor: Handles model requirements
policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
```
### 2. **Flexibility and Reusability**
The same policy can work with different environment processors, and the same environment processor can work with different policies:
```python
# Use SmolVLA policy with LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
# Or use ACT policy with the same LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
```
### 3. **Easier Experimentation**
Want to try different state representations for LIBERO? Just create a new processor:
```python
# Original: 8D state (pos + quat→axisangle + gripper)
@ProcessorStepRegistry.register("libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
    def _process_observation(self, obs):
        robot_state = obs["observation.robot_state"]
        eef_pos = robot_state["eef"]["pos"]  # 3D
        eef_axisangle = quat2axisangle(robot_state["eef"]["quat"])  # 3D
        gripper = robot_state["gripper"]["qpos"]  # 2D
        state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1)  # 8D
        return state


# Experiment: Add velocity for better control
@ProcessorStepRegistry.register("libero_velocity_processor")
class LiberoVelocityProcessorStep(ObservationProcessorStep):
    def _process_observation(self, obs):
        # Include velocities for 14D state
        robot_state = obs["observation.robot_state"]
        eef_pos = robot_state["eef"]["pos"]  # 3D
        eef_axisangle = quat2axisangle(robot_state["eef"]["quat"])  # 3D
        eef_vel = robot_state["eef"]["vel"]  # 3D (NEW)
        gripper_pos = robot_state["gripper"]["qpos"]  # 2D
        gripper_vel = robot_state["gripper"]["qvel"]  # 3D (NEW)
        state = torch.cat(
            [eef_pos, eef_axisangle, eef_vel, gripper_pos, gripper_vel], dim=-1
        )  # 14D
        return state
```
### 4. **Cleaner Environment Code**
Environments expose **all available data** without needing to know what downstream models will use:
```python
# LIBERO environment exposes full robot state
observation = {
    "pixels": {"image": img, "image2": img2},
    "robot_state": {
        "eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
        "gripper": {"qpos": ..., "qvel": ...},
        "joints": {"pos": ..., "vel": ...},
    },
}

# Environment processor decides what to use
# Policy processor handles model-specific transformations
```
## Using Environment Processors
### Factory Function
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
```python
from lerobot.envs.factory import make_env_pre_post_processors
from lerobot.envs.configs import LiberoEnv, PushtEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
```
### Implementation in `envs/factory.py`
```python
def make_env_pre_post_processors(
    env_cfg: EnvConfig,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
    """
    Create preprocessor and postprocessor pipelines for environment observations.

    Args:
        env_cfg: The configuration of the environment.

    Returns:
        A tuple containing:
        - preprocessor: Pipeline that processes environment observations
        - postprocessor: Pipeline that processes environment outputs
    """
    # For LIBERO environments, add the LiberoProcessorStep to preprocessor
    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
        preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
    else:
        # For all other environments, return an identity preprocessor
        preprocessor = PolicyProcessorPipeline(steps=[])

    # Postprocessor is currently identity for all environments
    # Future: Could add environment-specific action transformations
    postprocessor = PolicyProcessorPipeline(steps=[])

    return preprocessor, postprocessor
```
### Integration in Evaluation
In `lerobot_eval.py`, the environment processors are created once and used throughout:
```python
def eval_main(cfg: EvalPipelineConfig):
    # Create environment
    envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)

    # Create policy
    policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)

    # Create policy processors
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=cfg.policy,
        pretrained_path=cfg.policy.pretrained_path,
    )

    # Create environment processors (NEW!)
    env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)

    # Run evaluation with both processor types
    eval_policy_all(
        envs=envs,
        policy=policy,
        env_preprocessor=env_preprocessor,    # Environment-specific
        env_postprocessor=env_postprocessor,  # Environment-specific
        preprocessor=preprocessor,            # Policy-specific
        postprocessor=postprocessor,          # Policy-specific
        n_episodes=cfg.eval.n_episodes,
    )
```
## Example: LIBERO Environment Processor
The `LiberoProcessorStep` demonstrates a real-world environment processor:
```python
from lerobot.processor.pipeline import ObservationProcessorStep


@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
    """
    Processes LIBERO observations into the LeRobot format.

    **State Processing:**
    - Extracts end-effector position (3D)
    - Converts quaternion to axis-angle representation (3D)
    - Extracts gripper joint positions (2D)
    - Concatenates into 8D state vector

    **Image Processing:**
    - Rotates images 180° to match HuggingFaceVLA/libero convention
    """

    def _process_observation(self, observation):
        processed_obs = observation.copy()

        # Process images: Flip 180° for camera convention
        for key in list(processed_obs.keys()):
            if key.startswith("observation.images."):
                img = processed_obs[key]
                img = torch.flip(img, dims=[2, 3])  # Flip H and W
                processed_obs[key] = img

        # Process robot_state: Flatten to 8D vector
        if "observation.robot_state" in processed_obs:
            robot_state = processed_obs.pop("observation.robot_state")
            eef_pos = robot_state["eef"]["pos"]  # (B, 3)
            eef_quat = robot_state["eef"]["quat"]  # (B, 4)
            gripper_qpos = robot_state["gripper"]["qpos"]  # (B, 2)

            # Convert quaternion to axis-angle
            eef_axisangle = self._quat2axisangle(eef_quat)  # (B, 3)

            # Concatenate into single state vector
            state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
            state = state.float()
            processed_obs["observation.state"] = state

        return processed_obs
```
### Why These Transformations?
1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
- Selects the relevant components (pos, quat, gripper)
- Converts quaternion to axis-angle (more suitable for learning)
- Flattens to a single 8D vector that policies expect
3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
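For reference, the quaternion-to-axis-angle conversion in point 2 is a standard rotation-representation change. Below is a minimal batched sketch of what `_quat2axisangle` might look like, assuming an `(x, y, z, w)` component order (check the actual implementation for the exact convention):

```python
import torch


def quat2axisangle(quat: torch.Tensor) -> torch.Tensor:
    """Convert batched quaternions (B, 4) in (x, y, z, w) order to axis-angle (B, 3).

    The result points along the rotation axis and its norm is the rotation
    angle in radians.
    """
    xyz, w = quat[..., :3], quat[..., 3]
    # Clamp to avoid NaNs from acos when |w| drifts slightly above 1.0
    angle = 2.0 * torch.acos(torch.clamp(w, -1.0, 1.0))
    sin_half = torch.sqrt(torch.clamp(1.0 - w**2, min=0.0))
    # Near-zero rotations: return a zero vector instead of dividing by ~0
    scale = torch.where(
        sin_half > 1e-6, angle / sin_half.clamp(min=1e-6), torch.zeros_like(angle)
    )
    return xyz * scale.unsqueeze(-1)
```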
## Adding Environment Processors for New Environments
To add environment processors for a new environment:
### 1. Create the Processor Step
```python
# In src/lerobot/processor/env_processor.py
@dataclass
@ProcessorStepRegistry.register(name="myenv_processor")
class MyEnvProcessorStep(ObservationProcessorStep):
    """Process observations from MyEnv."""

    def _process_observation(self, observation):
        processed = observation.copy()

        # Your environment-specific transformations
        if "myenv.specific.state" in processed:
            state = processed.pop("myenv.specific.state")
            # Transform to standard format
            processed["observation.state"] = self._transform_state(state)

        return processed
```
### 2. Update the Factory
```python
# In src/lerobot/envs/factory.py
def make_env_pre_post_processors(env_cfg: EnvConfig):
    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
        preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
    elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
        preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
    else:
        preprocessor = PolicyProcessorPipeline(steps=[])

    postprocessor = PolicyProcessorPipeline(steps=[])
    return preprocessor, postprocessor
```
### 3. Use in Evaluation
No changes needed! The evaluation script automatically uses the appropriate processor:
```bash
# --env.type=myenv automatically selects MyEnvProcessorStep
lerobot-eval \
    --policy.path=lerobot/my_policy \
    --env.type=myenv \
    --eval.n_episodes=10
```
## Future: Environment Postprocessors
Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
### Action Space Transformations
```python
@dataclass
class MyEnvActionPostprocessor(ProcessorStep):
    """Convert policy actions to environment-specific format."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition["action"]

        # Example: Convert from Cartesian to joint space
        if self.action_space == "joint":
            action = self.ik_solver(action)

        # Example: Apply environment-specific safety limits
        action = torch.clamp(action, self.min_action, self.max_action)

        transition["action"] = action
        return transition
```
### Coordinate System Conversions
```python
@dataclass
class CoordinateTransformPostprocessor(ProcessorStep):
    """Transform actions between coordinate systems."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition["action"]

        # Example: Policy outputs in world frame, env expects base frame
        action = self.world_to_base_transform(action)

        transition["action"] = action
        return transition
```
## Best Practices
1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
5. **Test independently**: Environment processors should be testable without loading full policies or environments.
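As an illustration of point 5, a processor step can be exercised with purely synthetic tensors, with no environment or policy loaded. A hypothetical sketch for the LIBERO example above (the import path is assumed, and the assertions follow the 8D state layout described earlier):

```python
import torch

from lerobot.processor.env_processor import LiberoProcessorStep  # import path assumed


def test_libero_processor_flattens_state():
    """The processor should flatten LIBERO's nested robot_state into an 8D vector."""
    processor = LiberoProcessorStep()
    obs = {
        "observation.images.agentview": torch.rand(1, 3, 128, 128),  # (B, C, H, W)
        "observation.robot_state": {
            "eef": {
                "pos": torch.zeros(1, 3),
                "quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]]),  # identity rotation
            },
            "gripper": {"qpos": torch.zeros(1, 2)},
        },
    }
    out = processor._process_observation(obs)
    assert out["observation.state"].shape == (1, 8)
    assert "observation.robot_state" not in out
```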
## Summary
Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
- ✅ Enables easy experimentation with different state representations
- ✅ Allows policies to work seamlessly across different environments
- ✅ Keeps environment code focused on simulation/hardware interface
- ✅ Makes processor pipelines more maintainable and debuggable
- ✅ Follows the single responsibility principle
The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.
-188
@@ -1,188 +0,0 @@
# Real-Time Chunking (RTC)
Real-Time Chunking (RTC) is an inference-time method that allows large, flow-matching based robotic policies, such as [Pi0](./pi0), [Pi0.5](./pi05), and [SmolVLA](./smolvla), to produce smooth, continuous, and reactive motion despite having high inference latency.
These policies generate chunks of future actions (e.g., 50 steps at a time) instead of single actions.
Because the models are large, producing each chunk takes far longer than a single control step, so inference latency is significant relative to execution.
Naively executing chunks leads to problems such as pauses, jerky transitions, or sudden changes in strategy whenever the next chunk arrives late or disagrees with the previously executed actions.
RTC solves this by asynchronously generating the next chunk while the robot continues executing the current one, and by guiding the new chunk so it aligns smoothly with the portion of the previous chunk that has already been executed.
## How RTC Works (simplified)
RTC lets the robot think ahead while it's still moving. When the robot is carrying out one chunk of actions, RTC starts creating the next chunk early.
But since the robot has already moved a bit by the time the new chunk is ready, RTC has to make sure the new chunk still lines up smoothly with what the robot is currently doing.
To do this, RTC treats the beginning of the new chunk like an inpainting or “fill-in-the-gaps” problem:
it gently adjusts the first part of the new chunk so it blends naturally with the robot's ongoing motion. The result is motion with no pauses and no sudden jumps.
In technical terms, RTC adds a guidance term to the flow-matching denoising process that forces the overlapping timesteps of the new chunk to stay close to the executed portion of the previous chunk, typically using a soft transition mask.
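To make this concrete, here is a heavily simplified sketch of one guided Euler denoising step. The names (`rtc_denoise_step`, `prefix_weights`) are illustrative rather than LeRobot's actual API, and the real implementation computes its correction through the denoising process itself (using autograd-based gradients) rather than this direct blend:

```python
import torch


def rtc_denoise_step(x_t, v_t, prev_chunk, prefix_weights, guidance_weight, dt):
    """One guided Euler step of flow-matching denoising (illustrative only).

    x_t:            noisy action chunk being denoised, shape (horizon, action_dim)
    v_t:            velocity field predicted by the flow-matching model
    prev_chunk:     already-committed actions from the previous chunk, aligned with x_t
    prefix_weights: soft transition mask in [0, 1] per timestep; ~1 over the
                    executed prefix, decaying to 0 past the execution horizon
    """
    # Plain Euler update (what the policy would do without RTC)
    x_next = x_t + dt * v_t

    # Guidance term: pull overlapping timesteps toward the previous chunk
    correction = guidance_weight * (prev_chunk - x_next)
    return x_next + prefix_weights.unsqueeze(-1) * correction
```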
## Quick Start
### Installation
RTC is built into LeRobot. Just install the policy dependencies you need:
```bash
# For Pi0 or Pi0.5
pip install -e ".[pi]"
# For SmolVLA
pip install -e ".[smolvla]"
```
### Using RTC with Pi0
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python
from lerobot.policies.pi0 import PI0Policy, PI0Config
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.action_queue import ActionQueue

# Load Pi0 with RTC enabled
policy_cfg = PI0Config()

# Enable RTC
policy_cfg.rtc_config = RTCConfig(
    enabled=True,
    execution_horizon=10,  # How many steps to blend with previous chunk
    max_guidance_weight=10.0,  # How strongly to enforce consistency
    prefix_attention_schedule=RTCAttentionSchedule.EXP,  # Exponential blend
)

# Load the policy
policy = PI0Policy.from_pretrained("lerobot/pi0_base", policy_cfg=policy_cfg, device="cuda")

# Now use predict_action_chunk with RTC parameters
inference_delay = 4  # Steps of inference latency; this value should be measured for your policy

# Initialize the action queue
action_queue = ActionQueue(policy_cfg.rtc_config)


# Start in a separate thread with the following function
def get_actions():
    while True:
        if should_get_actions:
            prev_actions = action_queue.get_left_over()
            obs = get_robot_observations(robot)

            # Generate actions WITH RTC
            actions = policy.predict_action_chunk(
                obs,
                inference_delay=inference_delay,
                prev_chunk_left_over=prev_actions,
            )
            action_queue.merge(actions, actions, inference_delay)


for step in range(num_steps):
    # Execute the next action from the queue
    action = action_queue.get()
    execute_actions(action)
```
## Key Parameters
`RTCConfig` has the following parameters to tune:
**`execution_horizon`**: How many timesteps from the previous chunk to maintain consistency with. Higher values mean smoother transitions but potentially less reactivity.
Typical values: 8-12 steps
```python
RTCConfig(execution_horizon=10)
```
**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions against the reactivity of the policy. For 10-step flow matching (SmolVLA, Pi0, Pi0.5), a value of 10.0 works well.
**`prefix_attention_schedule`**: How to weight consistency across the overlap region.
- `LINEAR`: Linear decay from inference_delay to execution_horizon
- `EXP`: Exponential decay (recommended for getting started)
- `ONES`: Full weight across entire execution_horizon
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
**`inference_delay`**: How many timesteps of inference latency your system has. This is passed to `predict_action_chunk()` rather than the config, since it may vary at runtime.
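To make the `prefix_attention_schedule` options above concrete, the sketch below shows one plausible way the per-timestep weights could be computed. It is a hypothetical illustration; the exact shapes and decay constants in LeRobot's implementation may differ:

```python
import torch


def prefix_weights(schedule: str, inference_delay: int, execution_horizon: int) -> torch.Tensor:
    """Per-timestep consistency weights over the first `execution_horizon` steps."""
    t = torch.arange(execution_horizon, dtype=torch.float32)
    if schedule == "ONES":
        return torch.ones(execution_horizon)
    if schedule == "ZEROS":
        return (t < inference_delay).float()  # full weight up to the delay, then zero
    if schedule == "LINEAR":
        # 1.0 up to inference_delay, decaying linearly to 0 at execution_horizon
        w = (execution_horizon - t) / max(execution_horizon - inference_delay, 1)
        return w.clamp(0.0, 1.0)
    if schedule == "EXP":
        # Exponential decay past the delayed prefix (decay rate chosen arbitrarily here)
        return torch.exp(-(t - inference_delay).clamp(min=0.0))
    raise ValueError(f"unknown schedule: {schedule}")
```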
## Testing RTC Offline
Before running on a real robot, test RTC with dataset samples to visualize how it works:
```bash
python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--rtc.max_guidance_weight=10.0 \
--device=cuda
```
The script generates a visualization of the denoising process, comparing standard generation (left) with RTC (right). In the RTC plots, you can see how the first few steps (blue/purple lines) are guided to match the red ground truth trajectory (previous chunk's tail), ensuring a smooth transition between chunks.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/flow_matching.png"
alt="Denoising steps with and without RTC"
width="100%"
/>
</p>
## Testing RTC with a Real Robot
```bash
python examples/rtc/eval_with_real_robot.py \
--policy.path=${HF_USERNAME}/policy_repo_id \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120 \
--device=cuda
```
## How It Differs from Async Inference in LeRobot
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
| Aspect | Async Inference | RTC |
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
| **Best Used** | Large models with high inference latency | Flow-matching based policies |
**Use both together** for maximum smoothness and reactivity!
## Advanced: Debug Tracking
RTC includes built-in debug tracking to help you understand what's happening during inference:
```python
# Enable debug tracking
policy_cfg.rtc_config.debug = True
policy_cfg.rtc_config.debug_maxlen = 100
# After inference, access debug data
debug_data = policy.rtc_processor.get_debug_data()
# Visualize denoising steps, corrections, etc.
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
visualizer = RTCDebugVisualizer()
# ... create plots
```
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
## References
- [Smooth-As-Butter Robot Policies](https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html) - Excellent technical explanation with real robot results
- [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/research/real_time_chunking) - Original paper and research
- [Kinetix RTC Implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix) - Reference implementation from Physical Intelligence
File diff suppressed because it is too large
@@ -1,525 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visualize SARM Subtask Annotations

This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
For each episode, it shows:
- A timeline with dashed vertical lines at subtask boundaries
- Sample frames from the episode at key points (start, middle, end of each subtask)
- Color-coded subtask segments

Usage:
    python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
"""
import argparse
import random
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from rich.console import Console

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import load_episodes
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
def timestamp_to_seconds(timestamp: str) -> float:
    """Convert MM:SS or SS timestamp to seconds"""
    parts = timestamp.split(":")
    if len(parts) == 2:
        return int(parts[0]) * 60 + int(parts[1])
    else:
        return int(parts[0])
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
    """
    Load annotations from LeRobot dataset parquet files.

    Reads subtask annotations from the episodes metadata parquet files.
    """
    episodes_dataset = load_episodes(dataset_path)
    if episodes_dataset is None or len(episodes_dataset) == 0:
        return {}

    # Check if subtask columns exist
    if "subtask_names" not in episodes_dataset.column_names:
        return {}

    # Convert to pandas DataFrame for easier access
    episodes_df = episodes_dataset.to_pandas()

    annotations = {}
    for ep_idx in episodes_df.index:
        subtask_names = episodes_df.loc[ep_idx, "subtask_names"]

        # Skip episodes without annotations
        if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
            continue

        start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
        end_times = episodes_df.loc[ep_idx, "subtask_end_times"]

        # Reconstruct SubtaskAnnotation from stored data
        subtasks = []
        for i, name in enumerate(subtask_names):
            # Convert seconds back to MM:SS format
            start_sec = int(start_times[i])
            end_sec = int(end_times[i])
            start_str = f"{start_sec // 60:02d}:{start_sec % 60:02d}"
            end_str = f"{end_sec // 60:02d}:{end_sec % 60:02d}"
            subtasks.append(
                Subtask(
                    name=name,
                    timestamps=Timestamp(start=start_str, end=end_str),
                )
            )

        annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)

    return annotations
# Color palette for subtasks (colorblind-friendly)
SUBTASK_COLORS = [
    "#E69F00",  # Orange
    "#56B4E9",  # Sky blue
    "#009E73",  # Bluish green
    "#F0E442",  # Yellow
    "#0072B2",  # Blue
    "#D55E00",  # Vermillion
    "#CC79A7",  # Reddish purple
    "#999999",  # Gray
]
def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
    """Extract a single frame from video at given timestamp."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None

    # Set position to timestamp
    cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
    ret, frame = cap.read()
    cap.release()

    if ret:
        # Convert BGR to RGB
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return None
def visualize_episode(
    episode_idx: int,
    annotation,
    video_path: Path,
    video_start_timestamp: float,
    video_end_timestamp: float,
    fps: int,
    output_path: Path,
    video_key: str,
):
    """
    Create visualization for a single episode.

    Shows:
    - Top row: Sample frames from the episode (one per subtask)
    - Bottom: Timeline with subtask segments and boundary lines
    """
    subtasks = annotation.subtasks
    num_subtasks = len(subtasks)
    if num_subtasks == 0:
        print(f"No subtasks found for episode {episode_idx}")
        return

    # Calculate episode duration
    episode_duration = video_end_timestamp - video_start_timestamp

    # Extract sample frames - get frame from middle of each subtask
    sample_frames = []
    frame_timestamps = []
    for subtask in subtasks:
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        mid_sec = (start_sec + end_sec) / 2
        # Convert to video timestamp (add video_start_timestamp offset)
        video_timestamp = video_start_timestamp + mid_sec
        frame_timestamps.append(mid_sec)
        frame = extract_frame_from_video(video_path, video_timestamp)
        sample_frames.append(frame)

    # Create figure
    fig = plt.figure(figsize=(16, 10))
    # Use a dark background for better contrast
    fig.patch.set_facecolor('#1a1a2e')

    # Calculate grid layout
    # Top section: frames (variable number of columns based on subtasks)
    # Bottom section: timeline
    # Create gridspec
    gs = fig.add_gridspec(
        2, max(num_subtasks, 1),
        height_ratios=[2, 1],
        hspace=0.3,
        wspace=0.1,
        left=0.05, right=0.95,
        top=0.88, bottom=0.1
    )

    # Add title
    fig.suptitle(
        f"Episode {episode_idx} - Subtask Annotations",
        fontsize=18,
        fontweight='bold',
        color='white',
        y=0.96
    )

    # Add subtitle with video info
    fig.text(
        0.5, 0.91,
        f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
        ha='center',
        fontsize=11,
        color='#888888'
    )

    # Plot sample frames
    for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
        ax = fig.add_subplot(gs[0, i])
        ax.set_facecolor('#16213e')
        if frame is not None:
            ax.imshow(frame)
        else:
            ax.text(0.5, 0.5, "Frame\nN/A", ha='center', va='center',
                    fontsize=12, color='white', transform=ax.transAxes)
        ax.set_title(
            f"{subtask.name}",
            fontsize=10,
            fontweight='bold',
            color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
            pad=8
        )
        ax.axis('off')
        # Add frame timestamp below
        ax.text(
            0.5, -0.08,
            f"t={frame_timestamps[i]:.1f}s",
            ha='center',
            fontsize=9,
            color='#888888',
            transform=ax.transAxes
        )

    # Create timeline subplot spanning all columns
    ax_timeline = fig.add_subplot(gs[1, :])
    ax_timeline.set_facecolor('#16213e')

    # Get total duration from last subtask end time
    total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)

    # Draw subtask segments as colored bars
    bar_height = 0.6
    bar_y = 0.5
    for i, subtask in enumerate(subtasks):
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]

        # Draw segment bar
        rect = mpatches.FancyBboxPatch(
            (start_sec, bar_y - bar_height / 2),
            end_sec - start_sec,
            bar_height,
            boxstyle="round,pad=0.02,rounding_size=0.1",
            facecolor=color,
            edgecolor='white',
            linewidth=1.5,
            alpha=0.85
        )
        ax_timeline.add_patch(rect)

        # Add subtask label inside bar
        mid_x = (start_sec + end_sec) / 2
        duration = end_sec - start_sec
        # Only add text if segment is wide enough
        if duration > total_duration * 0.08:
            ax_timeline.text(
                mid_x, bar_y,
                subtask.name,
                ha='center', va='center',
                fontsize=9,
                fontweight='bold',
                color='black' if i in [3] else 'white',  # Yellow needs dark text
                rotation=0 if duration > total_duration * 0.15 else 45
            )

    # Draw boundary lines (dashed vertical lines between subtasks)
    boundary_times = []
    for i, subtask in enumerate(subtasks):
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        # Add start boundary (except for first subtask at t=0)
        if i == 0 and start_sec > 0:
            boundary_times.append(start_sec)
        elif i > 0:
            boundary_times.append(start_sec)
        # Add end boundary for last subtask
        if i == len(subtasks) - 1:
            boundary_times.append(end_sec)

    # Draw dashed lines at boundaries
    for t in boundary_times:
        ax_timeline.axvline(
            x=t,
            ymin=0.1, ymax=0.9,
            color='white',
            linestyle='--',
            linewidth=2,
            alpha=0.9
        )
        # Add time label below line
        ax_timeline.text(
            t, 0.0,
            f"{int(t//60):02d}:{int(t%60):02d}",
            ha='center', va='top',
            fontsize=8,
            color='#cccccc'
        )

    # Add start line at t=0
    ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color='#00ff00', linestyle='-', linewidth=2.5, alpha=0.9)
    ax_timeline.text(0, 0.0, "00:00", ha='center', va='top', fontsize=8, color='#00ff00', fontweight='bold')

    # Configure timeline axes
    ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
    ax_timeline.set_ylim(-0.3, 1.2)
    ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color='white', labelpad=10)
    ax_timeline.set_ylabel("")

    # Style the axes
    ax_timeline.spines['top'].set_visible(False)
    ax_timeline.spines['right'].set_visible(False)
    ax_timeline.spines['left'].set_visible(False)
    ax_timeline.spines['bottom'].set_color('#444444')
    ax_timeline.tick_params(axis='x', colors='#888888', labelsize=9)
    ax_timeline.tick_params(axis='y', left=False, labelleft=False)

    # Add x-axis ticks at regular intervals
    tick_interval = max(1, int(total_duration / 10))
    ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))

    # Add legend explaining line styles
    legend_elements = [
        Line2D([0], [0], color='#00ff00', linewidth=2.5, linestyle='-', label='Start'),
        Line2D([0], [0], color='white', linewidth=2, linestyle='--', label='Subtask boundary'),
    ]
    ax_timeline.legend(
        handles=legend_elements,
        loc='upper right',
        framealpha=0.3,
        facecolor='#16213e',
        edgecolor='#444444',
        fontsize=9,
        labelcolor='white'
    )

    # Save figure
    plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')
    plt.close()

    return output_path
def main():
    parser = argparse.ArgumentParser(
        description="Visualize SARM subtask annotations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID",
    )
    parser.add_argument(
        "--num-episodes",
        type=int,
        default=5,
        help="Number of random episodes to visualize (default: 5)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Specific episode indices to visualize (overrides --num-episodes)",
    )
    parser.add_argument(
        "--video-key",
        type=str,
        default=None,
        help="Camera/video key to use. If not specified, uses first available.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./subtask_viz",
        help="Output directory for visualizations (default: ./subtask_viz)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    args = parser.parse_args()

    console = Console()

    # Set random seed if specified
    if args.seed is not None:
        random.seed(args.seed)

    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    dataset = LeRobotDataset(args.repo_id, download_videos=True)
    fps = dataset.fps

    # Get video key
    if args.video_key:
        if args.video_key not in dataset.meta.video_keys:
            console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
            console.print(f"[yellow]Available: {', '.join(dataset.meta.video_keys)}[/yellow]")
            return
        video_key = args.video_key
    else:
        video_key = dataset.meta.video_keys[0]

    console.print(f"[cyan]Using camera: {video_key}[/cyan]")
    console.print(f"[cyan]FPS: {fps}[/cyan]")

    # Load annotations
    console.print("\n[cyan]Loading annotations...[/cyan]")
    annotations = load_annotations_from_dataset(dataset.root)
    if not annotations:
        console.print("[red]Error: No annotations found in dataset[/red]")
        console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
        return

    console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")

    # Determine which episodes to visualize
    if args.episodes:
        episode_indices = args.episodes
        # Validate episodes exist
        for ep in episode_indices:
            if ep not in annotations:
                console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
        episode_indices = [ep for ep in episode_indices if ep in annotations]
    else:
        # Random selection
        available_episodes = list(annotations.keys())
        num_to_select = min(args.num_episodes, len(available_episodes))
        episode_indices = random.sample(available_episodes, num_to_select)
        episode_indices.sort()

    if not episode_indices:
        console.print("[red]Error: No valid episodes to visualize[/red]")
        return

    console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate visualizations
    for ep_idx in episode_indices:
        console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")
        annotation = annotations[ep_idx]

        # Get video path and timestamps
        video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
        if not video_path.exists():
            console.print(f"[red]Video not found: {video_path}[/red]")
            continue

        # Get episode-specific timestamps within the video file
        video_path_key = f"videos/{video_key}/from_timestamp"
        video_path_key_to = f"videos/{video_key}/to_timestamp"
        video_start_timestamp = float(dataset.meta.episodes[video_path_key][ep_idx])
        video_end_timestamp = float(dataset.meta.episodes[video_path_key_to][ep_idx])

        # Create visualization
        output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"
        try:
            visualize_episode(
                episode_idx=ep_idx,
                annotation=annotation,
                video_path=video_path,
                video_start_timestamp=video_start_timestamp,
                video_end_timestamp=video_end_timestamp,
                fps=fps,
                output_path=output_path,
                video_key=video_key,
            )
            console.print(f"[green]✓ Saved: {output_path}[/green]")
        except Exception as e:
            console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")

    # Print summary
    console.print(f"\n[bold green]{'=' * 50}[/bold green]")
    console.print("[bold green]Visualization Complete![/bold green]")
    console.print(f"[bold green]{'=' * 50}[/bold green]")
    console.print(f"Output directory: {output_dir.absolute()}")
    console.print(f"Episodes visualized: {len(episode_indices)}")


if __name__ == "__main__":
    main()
@@ -15,12 +15,16 @@
# limitations under the License.
import argparse
import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_droid import DROID_SHARDS
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
class AggregateDatasets(PipelineStep):
@@ -34,11 +38,6 @@ class AggregateDatasets(PipelineStep):
self.aggr_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
init_logging()
# Since aggregate_datasets already handles parallel processing internally,
+2 -2
@@ -20,7 +20,7 @@ from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_droid import DROID_SHARDS
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
class PortDroidShards(PipelineStep):
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from port_droid import port_droid, validate_dataset
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging
+3 -9
@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_droid import DROID_SHARDS
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
def make_upload_executor(
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
UploadDataset(repo_id, private=private),
UploadDataset(repo_id),
],
"logging_dir": str(logs_dir / job_name),
}
@@ -267,12 +267,6 @@ def main():
default="1950M",
help="Memory per cpu that each worker will use.",
)
parser.add_argument(
"--private",
action="store_true",
default=False,
help="Whether to create a private repository.",
)
init_logging()
+263
@@ -0,0 +1,263 @@
# RTC Profiling Guide
This guide explains how to profile RTC (Real-Time Chunking) performance to identify bottlenecks and understand why RTC might be slower than expected.
## Quick Start
### 1. Profile with Real Robot (Profiled Version)
Use `eval_with_real_robot_profiled.py` to profile actual robot execution:
```bash
# With RTC enabled
uv run examples/rtc/eval_with_real_robot_profiled.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=30
# Without RTC for comparison
uv run examples/rtc/eval_with_real_robot_profiled.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=30
```
**Output**: At the end of execution, you'll see a detailed breakdown of timing for each component:
- `get_actions.policy_inference` - Time spent in policy inference
- `get_actions.preprocessing` - Time spent preprocessing observations
- `get_actions.postprocessing` - Time spent postprocessing actions
- `get_actions.action_queue_merge` - Time spent merging actions with RTC
- `robot.get_observation` - Time to get observations from robot
- `robot.send_action` - Time to send actions to robot
- And more...
### 2. Profile Without Robot (Comparison Script)
Use `profile_rtc_comparison.py` to profile just the policy inference without needing a robot:
```bash
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20
```
**Output**: Side-by-side comparison of performance with and without RTC, including:
- Mean/min/max inference times
- Throughput (iterations per second)
- Verdict on whether RTC is faster or slower
### 3. Enable Detailed Method-Level Profiling
For even more granular profiling, add the `--enable_detailed_profiling` flag:
```bash
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20 \
--enable_detailed_profiling
```
This will show timing for individual methods within the policy.
## Understanding the Output
### Key Metrics to Look At
1. **get_actions.policy_inference** - This should be the largest component
- If RTC is enabled, this includes the RTC guidance overhead
- Compare this with/without RTC to see the overhead
2. **get_actions.preprocessing** - Image preprocessing and normalization
- Should be relatively fast
- If slow, consider optimizing image processing
3. **get_actions.postprocessing** - Action denormalization
- Should be minimal
- If slow, check postprocessor implementation
4. **get_actions.action_queue_merge** - RTC-specific merging logic
- Only present when RTC is enabled
- If this is taking significant time, the RTC algorithm may need optimization
5. **robot.get_observation** - Robot communication overhead
- If slow, check camera/sensor latency
- Consider reducing image resolution
6. **robot.send_action** - Action execution overhead
- Should be very fast
- If slow, check robot communication
### Expected Performance
For a typical Pi0 policy on Apple Silicon (MPS):
- **Without RTC**: ~100-200ms per inference
- **With RTC**: Usually similar or slightly slower per inference (guidance adds roughly 10-30% overhead); the benefit comes from action reuse during execution
- **Preprocessing**: ~5-20ms depending on number of cameras
- **Postprocessing**: ~1-5ms
If RTC is significantly slower, likely causes:
1. **RTC overhead exceeds benefits** - The guidance computation is expensive
2. **Execution horizon too small** - Not reusing enough actions to amortize overhead (see the sketch after this list)
3. **No compilation** - Try with `--use_torch_compile`
4. **Large prev_actions buffer** - Copying/processing previous actions is slow
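To see why cause 2 matters, note that the RTC overhead is paid once per chunk but amortized over every action actually executed from that chunk. A back-of-the-envelope sketch (numbers are illustrative, not measurements):
```python
rtc_overhead_ms = 30.0  # extra guidance/autograd cost per chunk (illustrative)

for execution_horizon in (5, 10, 20, 30):
    # The overhead is paid once per chunk and amortized over the executed actions.
    per_action_ms = rtc_overhead_ms / execution_horizon
    print(f"horizon={execution_horizon:>2}: +{per_action_ms:.1f} ms overhead per executed action")
```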
## Profiling Your Own Code
### Using the Profiling Decorator
Add profiling to your own methods:
```python
from lerobot.utils.profiling import profile_method, enable_profiling, print_profiling_summary
# Enable profiling
enable_profiling()
# Decorate methods you want to profile
@profile_method
def my_slow_function(x):
    # ... your code ...
    result = x * x  # placeholder work
    return result
# At end of execution
print_profiling_summary()
```
### Using Profile Context Manager
For profiling specific code blocks:
```python
from lerobot.utils.profiling import profile_section, enable_profiling
enable_profiling()
with profile_section("data_loading"):
data = load_data()
with profile_section("model_inference"):
output = model(data)
```
### Adding Profiling to Policy Methods
To profile specific parts of the Pi0 policy, you can add decorators:
```python
# In src/lerobot/policies/pi0/modeling_pi0.py
from lerobot.utils.profiling import profile_method, profile_section
class Pi0Policy:
    @profile_method
    def predict_action_chunk(self, obs, inference_delay=0, prev_chunk_left_over=None):
        # ... existing code ...
        pass

    def _generate_actions_with_rtc(self, *args, **kwargs):
        with profile_section("rtc.guidance_computation"):
            # ... guidance code ...
            pass
        with profile_section("rtc.action_merging"):
            # ... merging code ...
            pass
```
## Analyzing Results
### Comparison Checklist
When comparing RTC vs non-RTC performance, check:
- [ ] Is `policy_inference` time higher with RTC?
- [ ] Is `action_queue_merge` taking significant time?
- [ ] Are you running enough iterations to amortize warmup?
- [ ] Is torch.compile enabled for fair comparison?
- [ ] Is the execution horizon large enough? (should be >= 10-20)
- [ ] Are you testing on the same hardware/device?
### Common Bottlenecks
1. **Image preprocessing dominates**
- Solution: Reduce image resolution, use fewer cameras, or optimize preprocessing
2. **Action queue operations are slow**
- Solution: Review the queue implementation; consider using a ring buffer
3. **RTC guidance is expensive**
- Solution: Reduce guidance weight, simplify guidance computation, use torch.compile
4. **Robot communication is slow**
- Solution: Increase baud rate, reduce action frequency, optimize protocol
5. **Memory allocation overhead**
- Solution: Pre-allocate buffers, reuse tensors, avoid unnecessary copies
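For bottleneck 5, the usual pattern is to allocate once and reuse. A minimal sketch of the idea (shapes and names are illustrative, not from the codebase):
```python
import torch

chunk_size, action_dim = 50, 6
# Allocate the buffer once, outside the control loop.
action_buffer = torch.empty(chunk_size, action_dim)

def store_chunk(new_actions: torch.Tensor) -> torch.Tensor:
    """Copy a fresh chunk into the reusable buffer instead of allocating a new tensor."""
    n = new_actions.shape[0]
    action_buffer[:n].copy_(new_actions)
    return action_buffer[:n]
```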
## Advanced: Adding Custom Metrics
You can add custom timing metrics to the profiled script:
```python
import time

from lerobot.utils.profiling import record_timing

start = time.perf_counter()
# ... your code ...
duration = time.perf_counter() - start
record_timing("my_custom_metric", duration)
```
## Troubleshooting
### Profiling shows RTC is slower by >50%
1. Check if torch.compile is enabled: `--use_torch_compile`
2. Increase execution horizon: `--rtc.execution_horizon=30`
3. Verify inference_delay is calculated correctly
4. Profile with `--enable_detailed_profiling` to find exact bottleneck
### Profiling output is empty
1. Make sure profiling is enabled with `enable_profiling()`
2. Verify you're running enough iterations (at least 10)
3. Check that code is actually executing (not short-circuited)
### Inconsistent results between runs
1. Run more iterations: `--num_iterations=100`
2. Increase warmup iterations
3. Check for thermal throttling on the device
4. Ensure no other processes are competing for resources
## Next Steps
1. Run both profiling scripts (with/without robot)
2. Compare timing breakdowns
3. Identify the largest bottleneck
4. Focus optimization efforts on that component
5. Re-run profiling to verify improvements
## Questions?
If profiling reveals unexpected bottlenecks or you need help interpreting results, please share:
- The full profiling output
- Your configuration (RTC enabled/disabled, execution horizon, etc.)
- Hardware specs (device type, memory, etc.)
- Policy type and size
+208
View File
@@ -0,0 +1,208 @@
# RTC Profiling - Quick Start
Quick reference for profiling Pi0 with RTC to identify performance bottlenecks.
## 🚀 Quick Commands
### 1. Profile with Real Robot
```bash
# With RTC enabled (profiled version)
uv run examples/rtc/eval_with_real_robot_profiled.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0}, front: {type: opencv, index_or_path: 1}}" \
--task="Pick up object" \
--duration=30
```
### 2. Compare RTC vs No-RTC (No Robot Needed)
```bash
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20
```
### 3. Detailed RTC Method Profiling
```bash
uv run examples/rtc/profile_pi0_rtc_detailed.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=20 \
--execution_horizon=20 \
--enable_rtc_profiling
```
## 📊 What Each Tool Does
| Tool | Purpose | Needs Robot? |
|------|---------|--------------|
| `eval_with_real_robot_profiled.py` | Profile actual robot execution with RTC | ✅ Yes |
| `profile_rtc_comparison.py` | Compare RTC vs no-RTC side-by-side | ❌ No |
| `profile_pi0_rtc_detailed.py` | Deep dive into RTC internals | ❌ No |
## 🔍 Key Metrics to Watch
### Overall Performance
- **iteration.policy_inference** - Total policy inference time
- **iteration.preprocessing** - Image preprocessing time
- **iteration.postprocessing** - Action denormalization time
### RTC-Specific (with `--enable_rtc_profiling`)
- **rtc.denoise_step.base_denoising** - Time without RTC overhead
- **rtc.denoise_step.autograd_correction** - Gradient computation time
- **rtc.denoise_step.guidance_computation** - Total RTC guidance overhead
### Robot Communication
- **robot.get_observation** - Time to get robot state
- **robot.send_action** - Time to send action command
## 🎯 Quick Diagnosis
### RTC is slower than expected?
1. **Check if torch.compile is enabled**
```bash
# Add this flag
--use_torch_compile
```
2. **Try larger execution horizon**
```bash
# Increase to amortize RTC overhead
--rtc.execution_horizon=30
```
3. **Profile to find bottleneck**
```bash
uv run examples/rtc/profile_pi0_rtc_detailed.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--enable_rtc_profiling
```
### Preprocessing is slow?
- Reduce image resolution in robot config
- Use fewer cameras
- Check camera FPS settings
### Policy inference is slow?
- Enable torch.compile
- Check device (MPS vs CUDA vs CPU)
- Try smaller model if available
## 📈 Expected Performance
### Typical timings on Apple Silicon (MPS):
| Component | Time (ms) | Notes |
|-----------|-----------|-------|
| Policy inference | 100-200 | Depends on model size |
| Preprocessing | 5-20 | Depends on #cameras |
| Postprocessing | 1-5 | Usually fast |
| RTC overhead | 10-50 | Should be < 50% of base |
### When RTC helps:
- ✅ Execution horizon ≥ 10
- ✅ Inference time > action execution rate
- ✅ Using torch.compile
- ✅ Proper inference_delay calculation
### When RTC might not help:
- ❌ Very fast inference already
- ❌ Small execution horizon (< 5)
- ❌ No compilation (interpreted mode)
- ❌ Inference delay not accounted for
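A rough way to check the latency conditions above: compare the measured inference latency with the time each action step covers (values are illustrative):
```python
import math

inference_latency_s = 0.15  # measured time for one policy inference (illustrative)
fps = 10.0                  # action execution frequency (Hz)
execution_horizon = 20      # steps kept consistent between chunks

# Action steps consumed while the next inference is running.
inference_delay_steps = math.ceil(inference_latency_s * fps)

# RTC tends to pay off when inference spans at least one action step
# and the horizon is large enough to amortize the guidance overhead.
likely_helps = inference_delay_steps >= 1 and execution_horizon >= 10
print(f"inference_delay ≈ {inference_delay_steps} steps; RTC likely helps: {likely_helps}")
```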
## 🛠️ Adding Profiling to Your Code
### Quick snippet:
```python
from lerobot.utils.profiling import enable_profiling, print_profiling_summary, profile_section
# Enable at start
enable_profiling()
# Profile sections
with profile_section("my_operation"):
# ... your code ...
pass
# Print at end
print_profiling_summary()
```
### Profile specific methods:
```python
from lerobot.utils.profiling import profile_method
@profile_method
def my_slow_function():
    # ... your code ...
    pass
```
## 📝 Example Output
```
PROFILING SUMMARY
================================================================================
Function Count Mean (ms)
--------------------------------------------------------------------------------
iteration.policy_inference 20 150.23
iteration.preprocessing 20 12.45
rtc.denoise_step.guidance_computation 200 15.67
rtc.denoise_step.autograd_correction 200 8.23
rtc.denoise_step.base_denoising 200 120.45
================================================================================
```
## 🚨 Common Issues
### "No profiling data available"
- Did you call `enable_profiling()`?
- Running enough iterations?
### Inconsistent results
- Increase `--num_iterations`
- Check for thermal throttling
- Close other applications
### Can't find bottleneck
- Enable `--enable_rtc_profiling` for detailed breakdown
- Check both preprocessing and inference
- Compare with and without RTC
## 📖 More Details
See `PROFILING_GUIDE.md` for comprehensive documentation.
## 🤔 Still Slow?
1. Run comparison: `profile_rtc_comparison.py`
2. Run detailed profiling: `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
3. Share output for help (include device, model, settings)
## ✅ Quick Checklist
Before asking for help, verify:
- [ ] Ran comparison script (with/without RTC)
- [ ] Tried torch.compile
- [ ] Tested different execution horizons (10, 20, 30)
- [ ] Profiled with detailed RTC profiling
- [ ] Checked preprocessing vs inference split
- [ ] Verified hardware (device type, thermal state)
+352
View File
@@ -0,0 +1,352 @@
# RTC Profiling Toolkit
Complete toolkit for profiling Pi0 with RTC to identify performance bottlenecks.
## 📦 What's Included
### Scripts
1. **`eval_with_real_robot_profiled.py`**
- Profiled version of the real robot eval script
- Adds timing measurements throughout execution
- Works with actual robot hardware
- Same usage as original but with profiling output
2. **`profile_rtc_comparison.py`**
- Side-by-side comparison of RTC vs no-RTC
- No robot needed (uses mock observations)
- Shows clear verdict on whether RTC is helping
- Great for quick performance checks
3. **`profile_pi0_rtc_detailed.py`**
- Most detailed profiling available
- Can enable RTC method-level profiling
- Provides insights and recommendations
- Perfect for deep-dive investigations
4. **`add_rtc_profiling.py`**
- Monkey-patching utility for RTC internals
- Profiles individual RTC operations
- Can be applied without modifying source
- Shows exactly where RTC spends time
### Utilities
5. **`src/lerobot/utils/profiling.py`**
- Core profiling utilities
- Decorators for method profiling
- Context managers for code blocks
- Statistics collection and reporting
### Documentation
6. **`PROFILING_GUIDE.md`** - Comprehensive guide
7. **`PROFILING_QUICK_START.md`** - Quick reference
## 🚀 Quick Start
### Step 1: Compare Performance
Run this first to see if RTC is actually slower:
```bash
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20
```
**Expected output:**
```
COMPARISON SUMMARY
================================================================================
Metric Without RTC With RTC Difference
--------------------------------------------------------------------------------
Mean time (ms) 150.23 165.45 +15.22
Throughput (iter/s) 6.66 6.05 -0.61
================================================================================
VERDICT
✗ RTC is SLOWER by 10.1%
Mean time increased by 15.22 ms
Possible reasons:
- RTC overhead exceeds benefits at current execution horizon
- No torch.compile enabled
```
### Step 2: Identify Bottleneck
If RTC is slower, find out why:
```bash
uv run examples/rtc/profile_pi0_rtc_detailed.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=20 \
--execution_horizon=20 \
--enable_rtc_profiling
```
**Expected output:**
```
PROFILING SUMMARY
================================================================================
Function Count Mean (ms) Total (s)
------------------------------------------------------------------------------------
iteration.policy_inference 20 150.23 3.00
rtc.denoise_step.guidance_computation 200 15.67 3.13
rtc.denoise_step.autograd_correction 200 8.23 1.65
iteration.preprocessing 20 12.45 0.25
================================================================================
KEY INSIGHTS
================================================================================
Time breakdown:
Policy inference: 150.23 ms (87.2%)
Preprocessing: 12.45 ms (7.2%)
Postprocessing: 2.10 ms (1.2%)
RTC breakdown:
Base denoising: 120.45 ms
Guidance compute: 15.67 ms
Autograd correct: 8.23 ms
RTC overhead: 23.90 ms (19.8% of base)
Recommendations:
⚠ RTC autograd overhead is significant
→ This is expected, but consider increasing execution_horizon
→ Try torch.compile if not already enabled
💡 torch.compile not enabled
→ Try --use_torch_compile for potential speedup
================================================================================
```
### Step 3: Try Optimizations
Based on recommendations:
```bash
# Try with torch.compile
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20 \
--use_torch_compile
# Try larger execution horizon
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=30
```
### Step 4: Profile Real Robot (Optional)
Test with actual hardware:
```bash
uv run examples/rtc/eval_with_real_robot_profiled.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{...}" \
--task="Pick up object" \
--duration=30
```
## 🎯 Common Scenarios
### "RTC is 2x slower!"
This usually means:
- RTC overhead is high but not getting benefits
- Need to enable torch.compile
- Execution horizon too small
- Inference delay not calculated correctly
**Try:**
1. `--use_torch_compile`
2. Increase `--execution_horizon` to 30+
3. Check inference_delay calculation
### "RTC is only slightly slower"
This is expected! RTC overhead is typically about 10-30%.
The benefit comes during **execution**, not from a single inference:
- Actions are reused across chunks
- Overall system latency is reduced
- Robot gets smoother actions
### "Want to optimize specific part"
Use the profiling utilities:
```python
from lerobot.utils.profiling import enable_profiling, profile_section, print_profiling_summary
enable_profiling()
with profile_section("my_custom_operation"):
# Your code here
pass
print_profiling_summary()
```
## 📊 Understanding Results
### Key Metrics
**Policy Inference Time**
- Time for forward pass through model
- Should be largest component (70-90%)
- Includes RTC guidance if enabled
**Preprocessing Time**
- Image normalization, resizing
- Should be < 20% of total
- If high: reduce image resolution
**RTC Guidance Overhead**
- Extra time for RTC guidance computation
- Typically 10-30% of base inference
- If > 50%: RTC may not be beneficial at current settings
**Autograd Correction**
- Time computing gradients for RTC
- Usually 5-15% of base inference
- Can be reduced with torch.compile
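Given a profiling summary keyed by the `rtc.denoise_step.*` metrics (as in the KEY INSIGHTS example above), the overhead percentage follows directly; a minimal sketch (the dict layout is illustrative, values are from the example output):
```python
# Mean times in ms, keyed by metric name (values from the example output above).
summary = {
    "rtc.denoise_step.base_denoising": 120.45,
    "rtc.denoise_step.guidance_computation": 15.67,
    "rtc.denoise_step.autograd_correction": 8.23,
}

base = summary["rtc.denoise_step.base_denoising"]
overhead = (
    summary["rtc.denoise_step.guidance_computation"]
    + summary["rtc.denoise_step.autograd_correction"]
)
print(f"RTC overhead: {overhead:.2f} ms ({100 * overhead / base:.1f}% of base)")  # 23.90 ms (19.8%)
```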
### Expected Ranges (Apple Silicon MPS)
| Metric | Good | Acceptable | Poor |
|--------|------|------------|------|
| Policy inference | 100-150ms | 150-250ms | >250ms |
| Preprocessing | <20ms | 20-50ms | >50ms |
| RTC overhead | 10-30% | 30-50% | >50% |
## 🔧 Optimization Guide
### If RTC overhead is too high:
1. **Enable compilation:**
```bash
--use_torch_compile
```
Expected improvement: 20-40% faster
2. **Increase execution horizon:**
```bash
--execution_horizon=30 # or higher
```
Amortizes RTC cost over more actions
3. **Check guidance weight:**
```python
# In config
rtc.max_guidance_weight=1.0 # try 0.5 for less overhead
```
### If preprocessing is slow:
1. **Reduce image resolution:**
```python
# In robot config
cameras={
"gripper": {"width": 320, "height": 240} # instead of 640x480
}
```
2. **Use fewer cameras:**
- Profile which cameras are essential
- Remove unnecessary views
### If inference is generally slow:
1. Use torch.compile (if not already)
2. Check device is correct (MPS vs CUDA)
3. Verify model is in eval mode
4. Check for unnecessary gradient tracking
## 🐛 Troubleshooting
### Empty profiling output
```python
# Make sure to enable profiling!
from lerobot.utils.profiling import enable_profiling
enable_profiling()
```
### Inconsistent timings
- Run more iterations (50-100)
- Check thermal throttling
- Close background apps
- Use `--warmup_iterations=10`
### Can't find bottleneck
1. Start with `profile_rtc_comparison.py`
2. Then run `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
3. Compare with/without RTC
4. Check each component separately
## 📖 Full Documentation
- **`PROFILING_GUIDE.md`** - Complete reference with examples
- **`PROFILING_QUICK_START.md`** - Quick commands and tips
## 🤝 Getting Help
If you're still experiencing issues:
1. Run comparison script and save output
2. Run detailed profiling and save output
3. Include:
- Policy path
- Device type
- RTC settings (execution_horizon, etc.)
- Hardware specs
- Full profiling output
## 🎓 Learning More
### Profiling your own code:
```python
from lerobot.utils.profiling import profile_method, enable_profiling
enable_profiling()
@profile_method
def my_function():
    # Automatically profiled
    pass
```
### RTC internals:
```python
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
from lerobot.utils.profiling import enable_profiling

enable_profiling()
monkey_patch_rtc_profiling()
# Now RTC methods are profiled
policy.predict_action_chunk(...)
```
## ✨ Next Steps
1. Run `profile_rtc_comparison.py` to establish baseline
2. Use `profile_pi0_rtc_detailed.py` to find bottlenecks
3. Apply optimizations (torch.compile, larger horizon)
4. Re-run comparison to verify improvements
5. Test with real robot using profiled version
Happy profiling! 🚀
+251
View File
@@ -0,0 +1,251 @@
# Real-Time Chunking (RTC) Examples
This directory contains examples and evaluation scripts for Real-Time Chunking (RTC), a technique for improving action chunking policies in real-time robot control.
## Overview
Real-Time Chunking addresses the challenge of maintaining consistency and reactivity when using action chunking policies with non-negligible inference latency. It applies a guidance term during denoising to blend new action predictions with previously planned actions.
**Key Benefits:**
- Maintains consistency between consecutive action chunks
- Reduces jitter and improves smoothness
- Adapts to inference delays dynamically
**Reference:** [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
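Concretely, the guidance enters each denoising step as a correction to the predicted velocity. A simplified sketch of the update (condensed from the RTC denoise step; the RTC Implementation linked under Related Documentation is authoritative, and batching, padding, and the guidance-weight schedule are omitted here):
```python
import torch

def guided_denoise_step(x_t, denoise_fn, prev_chunk, weights, time, guidance_weight):
    """One guided denoising step: nudge the predicted endpoint toward the previous chunk."""
    x_t = x_t.clone().detach()
    with torch.enable_grad():
        v_t = denoise_fn(x_t)                # base denoising velocity from the policy
        x_t.requires_grad_(True)
        x1_t = x_t - time * v_t              # predicted final actions at this step
        err = (prev_chunk - x1_t) * weights  # weighted disagreement with the previous plan
        correction = torch.autograd.grad(x1_t, x_t, err.detach())[0]
    return v_t - guidance_weight * correction  # guided velocity
```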
## Scripts
### 1. `eval_dataset.py`
Offline evaluation on dataset samples with detailed visualization and validation.
**Features:**
- Compare RTC vs non-RTC predictions on two random dataset samples
- Validate RTC behavior (delay region, blend region, post-horizon region)
- Generate debug visualizations:
- Denoising step comparisons (x_t, v_t, x1_t, corrections)
- Final action predictions comparison
- Support for torch.compile() optimization
- Memory-efficient sequential policy loading for large models
**Usage:**
```bash
# Basic usage with SmolVLA policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--seed=10
# With Pi0.5 policy on CUDA
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
# With Pi0 policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \
--torch_compile_disable_cudagraphs=false
```
**Key Parameters:**
- `--policy.path`: Path to pretrained policy
- `--dataset.repo_id`: Dataset to evaluate on
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 20)
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 10.0)
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
- `--inference_delay`: Inference delay for RTC (default: 4)
- `--seed`: Random seed for reproducibility (default: 42)
- `--output_dir`: Directory to save visualizations (default: rtc_debug_output)
- `--device`: Device to use (cuda, cpu, mps, auto)
- `--use_torch_compile`: Enable torch.compile() for faster inference
**Output:**
The script generates several visualization files in `rtc_debug_output/`:
- `denoising_xt_comparison.png` - Noisy state evolution during denoising
- `denoising_vt_comparison.png` - Velocity predictions during denoising
- `denoising_x1t_comparison.png` - Predicted final states during denoising
- `denoising_correction_comparison.png` - RTC guidance corrections applied
- `final_actions_comparison.png` - Final action predictions (prev_chunk, no_rtc, rtc)
The script also validates RTC behavior and reports:
- ✅ Delay region [0:inference_delay]: RTC = prev_chunk
- ✅ Blend region [inference_delay:execution_horizon]: prev_chunk ≤ RTC ≤ no_rtc
- ✅ Post-horizon [execution_horizon:]: RTC = no_rtc
### 2. `eval_with_real_robot.py`
Real-time evaluation on physical robots or simulation environments.
**Features:**
- Run policy with RTC on real robot or simulation
- Multi-threaded action execution and inference
- Action queue management with proper timing
- Latency tracking and adaptive inference delay
- Support for both robots and gym environments
- Support for torch.compile() optimization
**Usage:**
```bash
# With real robot
uv run python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot/smolvla_base \
--robot.type=so100 \
--task="pick up the cup" \
--duration=30.0
# With simulation environment
uv run python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot/smolvla_base \
--env.type=pusht \
--duration=60.0
# With policy compilation (CUDA only, not MPS)
uv run python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot/smolvla_base \
--robot.type=so100 \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
```
**Key Parameters:**
- `--policy.path`: Path to pretrained policy
- `--robot.type` or `--env.type`: Robot or environment to use
- `--task`: Task description (for VLA models)
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10)
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0)
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
- `--duration`: How long to run (seconds, default: 30.0)
- `--fps`: Action execution frequency (Hz, default: 10.0)
- `--action_queue_size_to_get_new_actions`: Queue size threshold to request new actions (default: 30)
- `--device`: Device to use (cuda, cpu, mps, auto)
- `--use_torch_compile`: Enable torch.compile() for faster inference
## Understanding RTC Parameters
### `execution_horizon`
Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity.
**Typical values:** 8-12 steps for dataset evaluation, 10 steps for real-time execution
### `max_guidance_weight`
Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions.
**Typical values:**
- Dataset evaluation: 10.0-100.0 (can be higher for analysis)
- Real-time execution: 1.0-10.0 (more conservative)
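For intuition, the weight applied at each denoising step is derived from the denoising progress and clamped by `max_guidance_weight`. A plain-Python restatement of the computation in `add_rtc_profiling.py` (with the endpoints handled the way the reference code's `nan_to_num` calls do):
```python
def compute_guidance_weight(tau: float, max_guidance_weight: float) -> float:
    """Guidance weight at denoising progress tau in [0, 1]."""
    if tau <= 0.0:
        return max_guidance_weight  # (1 - tau) / tau blows up; clamped to the maximum
    if tau >= 1.0:
        return 0.0                  # 0 * inf -> nan in the closed form; nan_to_num maps it to 0
    inv_r2 = ((1 - tau) ** 2 + tau**2) / (1 - tau) ** 2
    c = (1 - tau) / tau
    return min(c * inv_r2, max_guidance_weight)
```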
### `prefix_attention_schedule`
How to weight consistency across the overlap region:
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
- `ONES`: Full weight across entire execution_horizon
- `LINEAR`: Linear decay from inference_delay to execution_horizon
- `EXP`: Exponential decay (recommended)
**Recommended:** `EXP`
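A sketch of what these schedules look like over one chunk (shapes inferred from the descriptions above; `get_prefix_weights` in `modeling_rtc.py` is the authoritative implementation, and the EXP decay rate here is illustrative):
```python
import math

def prefix_weights(schedule: str, inference_delay: int, execution_horizon: int, chunk_size: int) -> list[float]:
    """Per-step consistency weights: full before the delay, zero after the horizon."""
    weights = []
    for i in range(chunk_size):
        if i < inference_delay:
            w = 1.0  # keep the steps already being executed
        elif i >= execution_horizon:
            w = 0.0  # fully adopt the new prediction
        else:
            frac = (i - inference_delay) / max(execution_horizon - inference_delay, 1)
            w = {
                "ZEROS": 0.0,                  # binary: nothing beyond the delay
                "ONES": 1.0,                   # full weight across the whole horizon
                "LINEAR": 1.0 - frac,          # linear decay toward the horizon
                "EXP": math.exp(-5.0 * frac),  # exponential decay (rate chosen for illustration)
            }[schedule]
        weights.append(w)
    return weights
```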
### `inference_delay`
Number of timesteps from the prefix to use for guidance. Typically calculated dynamically based on inference latency in real-time execution, but fixed for dataset evaluation.
**Typical values:** 3-5 steps for dataset evaluation
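In real-time execution the delay is derived from measured latency, mirroring the calculation in `eval_with_real_robot.py`:
```python
import math

fps = 10.0                  # action execution frequency (Hz)
time_per_chunk = 1.0 / fps  # wall-clock time covered by one action step
inference_latency = 0.35    # measured latency in seconds (illustrative)

# Number of action steps consumed while inference runs.
inference_delay = math.ceil(inference_latency / time_per_chunk)
print(inference_delay)  # -> 4
```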
### `action_queue_size_to_get_new_actions` (real-time only)
Threshold for requesting new action chunks. Should be higher than `inference_delay + execution_horizon` to ensure smooth operation.
**Typical values:** 20-30 steps
## Validation Rules (Dataset Evaluation)
The dataset evaluation script validates that RTC behavior matches expectations:
1. **Delay Region [0:inference_delay]**: RTC actions should equal previous chunk
- Ensures consistency during the inference delay period
2. **Blend Region [inference_delay:execution_horizon]**: RTC should be between prev_chunk and no_rtc
- Smooth transition from previous plan to new predictions
3. **Post-Horizon [execution_horizon:]**: RTC should equal no_rtc
- Full adoption of new predictions after execution horizon
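Condensed, the three rules amount to the following checks (a sketch mirroring `validate_rtc_behavior` in `eval_dataset.py`, with logging and shape handling stripped):
```python
import torch

def check_rtc_rules(rtc, no_rtc, prev_chunk, inference_delay, execution_horizon, rtol=1e-2):
    """rtc, no_rtc, prev_chunk: (time, action_dim) tensors covering the same chunk."""
    # Rule 1: the delay region equals the previous chunk.
    ok_delay = torch.allclose(rtc[:inference_delay], prev_chunk[:inference_delay], rtol=rtol)
    # Rule 2: the blend region stays between prev_chunk and no_rtc (element-wise).
    lo = torch.minimum(prev_chunk, no_rtc)[inference_delay:execution_horizon]
    hi = torch.maximum(prev_chunk, no_rtc)[inference_delay:execution_horizon]
    blend = rtc[inference_delay:execution_horizon]
    ok_blend = bool(torch.all((blend >= lo) & (blend <= hi)))
    # Rule 3: the post-horizon region equals the no-RTC prediction.
    ok_post = torch.allclose(rtc[execution_horizon:], no_rtc[execution_horizon:], rtol=rtol)
    return ok_delay, ok_blend, ok_post
```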
## Tips
1. **Start with dataset evaluation** (`eval_dataset.py`) to understand RTC behavior and tune parameters before running on robot
2. **Use visualizations** to debug unexpected behavior - check denoising steps and final actions
3. **Tune execution_horizon** based on your inference latency and action frequency
4. **Monitor validation output** - failures indicate potential implementation issues or misconfigured parameters
5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
## Troubleshooting
### Validation fails in delay region
- Check that `prev_chunk_left_over` is properly passed to the policy
- Verify RTC guidance is being applied during denoising
- Look at denoising visualizations to see where guidance diverges
### Validation fails in post-horizon region
- If the RTC and no-RTC runs use different noise, their outputs will differ; verify the same noise is used for the comparison
- Check that weights are correctly zeroed out after execution horizon
- Review prefix_attention_schedule visualization
### Poor performance on real robot
- Increase `action_queue_size_to_get_new_actions` if you see warnings
- Reduce `max_guidance_weight` if robot is too conservative
- Try different `prefix_attention_schedule` values
- Enable torch.compile() for faster inference (CUDA only)
### Memory issues with large models
- The dataset evaluation script loads policies sequentially to minimize memory
- For real-time execution, only one policy is loaded
- Use smaller batch sizes if needed
## Related Documentation
- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
- [Action Queue](../../src/lerobot/policies/rtc/action_queue.py)
- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
+202
View File
@@ -0,0 +1,202 @@
#!/usr/bin/env python
"""
Script to add profiling instrumentation to RTCProcessor.
This script shows which methods to profile in the RTC code to identify bottlenecks.
You can either:
1. Apply these changes directly to modeling_rtc.py
2. Use monkey patching to add profiling without modifying source
3. Use as reference for manual instrumentation
Usage:
# Option 1: Monkey patch (no source changes)
python examples/rtc/add_rtc_profiling.py
# Option 2: Apply changes to source
# Copy the profiled methods below into src/lerobot/policies/rtc/modeling_rtc.py
"""
import logging
import torch
from torch import Tensor
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.profiling import ProfileContext, enable_profiling, is_profiling_enabled
logger = logging.getLogger(__name__)
def profile_denoise_step(self, x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon=None) -> Tensor:
"""Profiled version of denoise_step."""
if not is_profiling_enabled():
# Call original implementation if profiling disabled
return self._original_denoise_step(x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon)
with ProfileContext("rtc.denoise_step.total"):
# In the original implementation, time goes from 0 to 1; in our implementation
# it goes from 1 to 0, so we need to invert it
tau = 1 - time
if prev_chunk_left_over is None:
# First step, no guidance - return v_t
with ProfileContext("rtc.denoise_step.base_denoising"):
v_t = original_denoise_step_partial(x_t)
return v_t
with ProfileContext("rtc.denoise_step.setup"):
x_t = x_t.clone().detach()
squeezed = False
if len(x_t.shape) < 3:
x_t = x_t.unsqueeze(0)
squeezed = True
if len(prev_chunk_left_over.shape) < 3:
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
if execution_horizon is None:
execution_horizon = self.rtc_config.execution_horizon
if execution_horizon > prev_chunk_left_over.shape[1]:
execution_horizon = prev_chunk_left_over.shape[1]
batch_size = x_t.shape[0]
action_chunk_size = x_t.shape[1]
action_dim = x_t.shape[2]
# Padding
with ProfileContext("rtc.denoise_step.padding"):
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
prev_chunk_left_over = padded
# Get prefix weights
with ProfileContext("rtc.denoise_step.get_prefix_weights"):
weights = (
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
.to(x_t.device)
.unsqueeze(0)
.unsqueeze(-1)
)
# Main RTC guidance computation
with ProfileContext("rtc.denoise_step.guidance_computation"):
with torch.enable_grad():
# Base denoising
with ProfileContext("rtc.denoise_step.base_denoising"):
v_t = original_denoise_step_partial(x_t)
x_t.requires_grad_(True)
# Compute x1_t
with ProfileContext("rtc.denoise_step.compute_x1_t"):
x1_t = x_t - time * v_t
# Compute error
with ProfileContext("rtc.denoise_step.compute_error"):
err = (prev_chunk_left_over - x1_t) * weights
grad_outputs = err.clone().detach()
# Compute correction via autograd
with ProfileContext("rtc.denoise_step.autograd_correction"):
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
# Compute guidance weight
with ProfileContext("rtc.denoise_step.compute_guidance_weight"):
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
# Apply guidance
with ProfileContext("rtc.denoise_step.apply_guidance"):
result = v_t - guidance_weight * correction
# Cleanup
with ProfileContext("rtc.denoise_step.cleanup"):
if squeezed:
result = result.squeeze(0)
correction = correction.squeeze(0)
x1_t = x1_t.squeeze(0)
err = err.squeeze(0)
self.track(
time=time,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
return result
def monkey_patch_rtc_profiling():
"""Apply profiling to RTCProcessor via monkey patching.
This modifies the RTCProcessor class at runtime to add profiling
without changing source files.
"""
logger.info("Applying RTC profiling monkey patch...")
# Save original method
RTCProcessor._original_denoise_step = RTCProcessor.denoise_step
# Replace with profiled version
RTCProcessor.denoise_step = profile_denoise_step
logger.info("✓ RTC profiling enabled")
def print_usage():
"""Print usage instructions."""
print("\n" + "="*80)
print("RTC PROFILING INSTRUMENTATION")
print("="*80)
print("\nThis script provides profiling for RTCProcessor methods.")
print("\nOption 1: Monkey Patch (Recommended)")
print("-" * 40)
print("Add to your script:")
print("""
from lerobot.utils.profiling import enable_profiling, print_profiling_summary
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
# Enable profiling
enable_profiling()
monkey_patch_rtc_profiling()
# ... run your code ...
# Print results
print_profiling_summary()
""")
print("\nOption 2: Manual Source Modification")
print("-" * 40)
print("1. Copy profile_denoise_step() from this file")
print("2. Replace denoise_step() in src/lerobot/policies/rtc/modeling_rtc.py")
print("3. Add profiling imports at top of file")
print("\nKey Metrics to Watch:")
print("-" * 40)
print("- rtc.denoise_step.base_denoising - Time for base policy inference")
print("- rtc.denoise_step.autograd_correction - Time computing gradients")
print("- rtc.denoise_step.guidance_computation - Total guidance overhead")
print("- rtc.denoise_step.get_prefix_weights - Time computing weights")
print("="*80 + "\n")
if __name__ == "__main__":
print_usage()
+177 -84
View File
@@ -39,9 +39,8 @@ Usage:
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--rtc.execution_horizon=8 \
--device=mps
--seed=10
# Basic usage with Pi0.5 policy on a CUDA device
uv run python examples/rtc/eval_dataset.py \
@@ -142,7 +141,7 @@ def _check_matplotlib_available():
raise ImportError(
"matplotlib is required for RTC debug visualizations. "
"Please install it by running:\n"
" uv pip install matplotlib"
" uv pip install -e '.[matplotlib-dep]'"
)
@@ -544,6 +543,11 @@ class RTCEvaluator:
logging.info("Plotting results...")
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
# Validate RTC behavior
# logging.info("=" * 80)
# logging.info("Validating RTC behavior...")
# self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
# Plot final actions comparison
logging.info("=" * 80)
logging.info("Plotting final actions comparison...")
@@ -552,6 +556,159 @@ class RTCEvaluator:
logging.info("=" * 80)
logging.info("Evaluation completed successfully")
def validate_rtc_behavior(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
"""Validate RTC behavior by comparing final action predictions with expected values.
Validation rules:
1. During delay [0:inference_delay]: RTC should equal prev_chunk
2. After delay, within execution horizon [inference_delay:execution_horizon]:
RTC should be between prev_chunk and no_rtc
3. After execution horizon [execution_horizon:]: RTC should equal no_rtc
Args:
rtc_actions: Final actions from RTC policy (batch, time, action_dim)
no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim)
prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim)
"""
# Remove batch dimension if present and move to CPU
rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
no_rtc_actions_t = (
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
)
prev_chunk = prev_chunk_left_over.cpu()
logging.info(f" rtc_actions shape: {rtc_actions_t.shape}")
logging.info(f" no_rtc_actions shape: {no_rtc_actions_t.shape}")
logging.info(f" prev_chunk shape: {prev_chunk.shape}")
# Determine chunk length for comparison
chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk.shape[0])
inference_delay = self.cfg.inference_delay
execution_horizon = self.cfg.rtc.execution_horizon
# Tolerance for floating point comparison
rtol = 1e-2 # Relative tolerance
validation_passed = True
warnings = []
logging.info(" Validating RTC behavior:")
logging.info(f" Chunk length: {chunk_len}")
logging.info(f" Inference delay: {inference_delay}")
logging.info(f" Execution horizon: {execution_horizon}")
logging.info(f" Tolerance: rtol={rtol}")
# ============================================================================
# Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk
# ============================================================================
if inference_delay > 0:
delay_end = min(inference_delay, chunk_len)
rtc_delay = rtc_actions_t[:delay_end]
prev_delay = prev_chunk[:delay_end]
logging.info(f" rtc_delay: {rtc_delay.shape}")
logging.info(f" prev_delay: {prev_delay.shape}")
if not torch.allclose(rtc_delay, prev_delay, rtol=rtol):
max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item()
mean_diff = torch.mean(torch.abs(rtc_delay - prev_delay)).item()
logging.info(f" rtc_delay: {rtc_delay}")
logging.info(f" prev_delay: {prev_delay}")
logging.info(f" max_diff: {max_diff}")
logging.info(f" mean_diff: {mean_diff}")
warnings.append(
f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], "
f"RTC does NOT equal prev_chunk!\n"
f" Max difference: {max_diff:.6f}\n"
f" Mean difference: {mean_diff:.6f}"
)
validation_passed = False
else:
logging.info(f" ✓ During delay [0:{delay_end}]: RTC equals prev_chunk")
# ============================================================================
# Rule 2: After delay, within execution horizon [inference_delay:execution_horizon]
# RTC should be between prev_chunk and no_rtc
# ============================================================================
blend_start = inference_delay
blend_end = min(execution_horizon, chunk_len)
if blend_end > blend_start:
rtc_blend = rtc_actions_t[blend_start:blend_end]
prev_blend = prev_chunk[blend_start:blend_end]
no_rtc_blend = no_rtc_actions_t[blend_start:blend_end]
# Check if RTC is between prev_chunk and no_rtc (element-wise)
# For each element, check if it's between the min and max of prev_chunk and no_rtc
min_bound = torch.minimum(prev_blend, no_rtc_blend)
max_bound = torch.maximum(prev_blend, no_rtc_blend)
within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
if not torch.all(within_bounds):
violations = torch.sum(~within_bounds).item()
total_elements = within_bounds.numel()
violation_pct = 100.0 * violations / total_elements
# Find max violation
lower_violations = torch.maximum(torch.tensor(0.0), min_bound - rtc_blend)
upper_violations = torch.maximum(torch.tensor(0.0), rtc_blend - max_bound)
max_violation = torch.max(torch.maximum(lower_violations, upper_violations)).item()
warnings.append(
f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], "
f"RTC is NOT always between prev_chunk and no_rtc!\n"
f" Violations: {violations}/{total_elements} elements ({violation_pct:.1f}%)\n"
f" Max violation distance: {max_violation:.6f}"
)
validation_passed = False
else:
logging.info(
f" ✓ Blend region [{blend_start}:{blend_end}]: RTC is between prev_chunk and no_rtc"
)
# ============================================================================
# Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc
# ============================================================================
if execution_horizon < chunk_len:
rtc_after = rtc_actions_t[execution_horizon:chunk_len]
no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len]
logging.info(f" rtc_after: {rtc_after}")
logging.info(f" no_rtc_after: {no_rtc_after}")
if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol):
max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item()
mean_diff = torch.mean(torch.abs(rtc_after - no_rtc_after)).item()
warnings.append(
f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], "
f"RTC does NOT equal no_rtc!\n"
f" Max difference: {max_diff:.6f}\n"
f" Mean difference: {mean_diff:.6f}"
)
validation_passed = False
else:
logging.info(
f" ✓ After execution horizon [{execution_horizon}:{chunk_len}]: RTC equals no_rtc"
)
# ============================================================================
# Report results
# ============================================================================
logging.info("=" * 80)
if validation_passed:
logging.info(" ✅ VALIDATION PASSED: All RTC behavior checks passed!")
logging.info(" • During delay: RTC = prev_chunk ✓")
logging.info(" • Blend region: prev_chunk ≤ RTC ≤ no_rtc ✓")
logging.info(" • After execution horizon: RTC = no_rtc ✓")
else:
logging.error(" ❌ VALIDATION FAILED: RTC behavior does not match expected!")
logging.error("")
for warning in warnings:
logging.error(warning)
logging.error("")
logging.error(" Please check the implementation of RTC guidance.")
def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
"""Plot final action predictions comparison on a single chart.
@@ -638,34 +795,16 @@ class RTCEvaluator:
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
# Add legend only to first subplot
if dim_idx == 0:
ax.legend(loc="best", fontsize=9)
axes[-1].set_xlabel("Step", fontsize=10)
# Collect legend handles and labels from first subplot
handles, labels = axes[0].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right)
fig.legend(
unique_handles,
unique_labels,
loc="center right",
fontsize=9,
bbox_to_anchor=(1.0, 0.5),
framealpha=0.9,
)
# Save figure
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right
fig.savefig(output_path, dpi=150, bbox_inches="tight")
fig.tight_layout()
fig.savefig(output_path, dpi=150)
logging.info(f"Saved final actions comparison to {output_path}")
plt.close(fig)
@@ -686,7 +825,6 @@ class RTCEvaluator:
axs_corr[:, 1], # Right column for correction
axs_x1t[:, 1], # Right column for x1_t
num_steps,
add_labels=True, # Add labels for RTC (right column)
)
self._plot_denoising_steps_from_tracker(
@@ -696,7 +834,6 @@ class RTCEvaluator:
axs_corr[:, 0], # Left column for correction
axs_x1t[:, 0], # Left column for x1_t
num_steps,
add_labels=False, # No labels for No RTC (left column)
)
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
@@ -712,21 +849,15 @@ class RTCEvaluator:
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x_t axes (no labels for left column)
# Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Add legends outside the plot area for each figure
self._add_figure_legend(fig_xt, axs_xt)
self._add_figure_legend(fig_vt, axs_vt)
self._add_figure_legend(fig_corr, axs_corr)
self._add_figure_legend(fig_x1t, axs_x1t)
# Save denoising plots
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
@@ -744,47 +875,13 @@ class RTCEvaluator:
return fig, axs
def _add_figure_legend(self, fig, axs):
"""Add a legend outside the plot area on the right side.
Args:
fig: Matplotlib figure to add legend to
axs: Array of axes to collect legend handles from
"""
# Collect all handles and labels from the first row of axes (right column)
handles, labels = axs[0, 1].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right, close to charts)
if unique_handles:
fig.legend(
unique_handles,
unique_labels,
loc="center left",
fontsize=8,
bbox_to_anchor=(0.87, 0.5),
framealpha=0.9,
ncol=1,
)
def _save_figure(self, fig, path):
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right
fig.savefig(path, dpi=150, bbox_inches="tight")
fig.tight_layout()
fig.savefig(path, dpi=150)
logging.info(f"Saved figure to {path}")
plt.close(fig)
def _plot_denoising_steps_from_tracker(
self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True
):
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps):
"""Plot denoising steps from tracker data.
Args:
@@ -794,7 +891,6 @@ class RTCEvaluator:
corr_axs: Matplotlib axes for correction plots (array of 6 axes)
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
num_steps: Total number of denoising steps for colormap
add_labels: Whether to add legend labels for the plots
"""
logging.info("=" * 80)
@@ -809,18 +905,17 @@ class RTCEvaluator:
for step_idx, debug_step in enumerate(debug_steps):
color = colors[step_idx % len(colors)]
label = f"Step {step_idx}" if add_labels else None
# Plot x_t
if debug_step.x_t is not None:
RTCDebugVisualizer.plot_waypoints(
xt_axs, debug_step.x_t, start_from=0, color=color, label=label
xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
)
# Plot v_t
if debug_step.v_t is not None:
RTCDebugVisualizer.plot_waypoints(
vt_axs, debug_step.v_t, start_from=0, color=color, label=label
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
)
# Plot correction on separate axes
@@ -830,18 +925,17 @@ class RTCEvaluator:
debug_step.correction,
start_from=0,
color=color,
label=label,
label=f"Step {step_idx}",
)
# Plot x1_t (predicted state)
if x1t_axs is not None and debug_step.x1_t is not None:
x1t_label = f"x1_t Step {step_idx}" if add_labels else None
RTCDebugVisualizer.plot_waypoints(
x1t_axs,
debug_step.x1_t,
start_from=0,
color=color,
label=x1t_label,
label=f"x1_t Step {step_idx}",
)
# Plot error in orange dashed
@@ -853,7 +947,6 @@ class RTCEvaluator:
)
num_dims = min(error_chunk.shape[-1], 6)
error_label = f"error Step {step_idx}" if add_labels else None
for j in range(num_dims):
x1t_axs[j].plot(
np.arange(0, error_chunk.shape[0]),
@@ -861,7 +954,7 @@ class RTCEvaluator:
color="orange",
linestyle="--",
alpha=0.7,
label=error_label,
label=f"error Step {step_idx}",
)
# Recalculate axis limits after plotting to ensure proper scaling
@@ -0,0 +1,631 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Profiled version of eval_with_real_robot.py for performance analysis.
This version adds detailed timing measurements for:
- Policy inference
- Preprocessing
- Postprocessing
- Action queue operations
- Robot communication
- Thread execution times
Usage: Same as eval_with_real_robot.py but with profiling output.
"""
import logging
import math
import sys
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
so100_follower,
so101_follower,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ProfileTimer:
"""Context manager and utility class for timing code sections."""
def __init__(self, name: str, stats_dict: dict):
self.name = name
self.stats_dict = stats_dict
self.start_time = None
def __enter__(self):
self.start_time = time.perf_counter()
return self
def __exit__(self, *args):
elapsed = time.perf_counter() - self.start_time
if self.name not in self.stats_dict:
self.stats_dict[self.name] = []
self.stats_dict[self.name].append(elapsed)
class ProfilingStats:
"""Global profiling statistics collector."""
def __init__(self):
self.stats = defaultdict(list)
self.lock = Lock()
def record(self, name: str, duration: float):
with self.lock:
self.stats[name].append(duration)
def timer(self, name: str):
"""Return a context manager for timing."""
return ProfileTimer(name, self.stats)
def get_summary(self) -> dict[str, dict[str, float]]:
"""Get summary statistics for all timings."""
with self.lock:
summary = {}
for name, times in self.stats.items():
if times:
summary[name] = {
"count": len(times),
"mean": sum(times) / len(times),
"min": min(times),
"max": max(times),
"total": sum(times),
}
return summary
def print_summary(self):
"""Print formatted summary of all timings."""
summary = self.get_summary()
logger.info("\n" + "=" * 80)
logger.info("PROFILING SUMMARY")
logger.info("=" * 80)
# Sort by total time (descending)
sorted_items = sorted(summary.items(), key=lambda x: x[1]["total"], reverse=True)
for name, stats in sorted_items:
logger.info(f"\n{name}:")
logger.info(f" Count: {stats['count']}")
logger.info(f" Mean: {stats['mean']*1000:.2f} ms")
logger.info(f" Min: {stats['min']*1000:.2f} ms")
logger.info(f" Max: {stats['max']*1000:.2f} ms")
logger.info(f" Total: {stats['total']:.2f} s")
logger.info(f" Hz: {stats['count']/stats['total']:.2f}")
logger.info("\n" + "=" * 80)
# Global profiling stats
profiling_stats = ProfilingStats()
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with profiling_stats.timer("robot.get_observation"):
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with profiling_stats.timer("robot.send_action"):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Queue-size threshold for requesting new actions: when the number of queued actions
# drops to this value, a new chunk is requested. It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse the CLI args again here to get the pretrained path, if one was provided.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy with profiling.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
inference_count = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
with profiling_stats.timer("get_actions.total_iteration"):
inference_count += 1
logger.info(f"[GET_ACTIONS] Starting inference #{inference_count}")
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
with profiling_stats.timer("get_actions.get_prev_actions"):
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
# Get observation
obs = robot.get_observation()
# Apply robot observation processor
with profiling_stats.timer("get_actions.robot_obs_processing"):
obs_processed = robot_observation_processor(obs)
# Build dataset frame
with profiling_stats.timer("get_actions.build_dataset_frame"):
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
# Convert to tensors and normalize
with profiling_stats.timer("get_actions.tensor_conversion"):
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task]
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
# Preprocessing
with profiling_stats.timer("get_actions.preprocessing"):
                        preprocessed_obs = preprocessor(obs_with_policy_features)
# Policy inference
with profiling_stats.timer("get_actions.policy_inference"):
actions = policy.predict_action_chunk(
                            preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Clone for RTC
with profiling_stats.timer("get_actions.clone_actions"):
original_actions = actions.squeeze(0).clone()
# Postprocessing
with profiling_stats.timer("get_actions.postprocessing"):
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
# Update latency tracker
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
logger.info(
f"[GET_ACTIONS] Inference #{inference_count} completed in {new_latency*1000:.2f}ms "
f"(delay={new_delay} chunks)"
)
                    if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
                        logger.warning(
                            "[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions is too small; "
                            "it should be higher than inference delay + execution horizon."
                        )
# Merge into action queue
with profiling_stats.timer("get_actions.action_queue_merge"):
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot with profiling.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_count = 0
action_interval = 1.0 / cfg.fps
while not shutdown_event.is_set():
start_time = time.perf_counter()
with profiling_stats.timer("actor.total_iteration"):
# Get action from queue
with profiling_stats.timer("actor.queue_get"):
action = action_queue.get()
if action is not None:
# Process action
with profiling_stats.timer("actor.action_processing"):
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
# Send to robot (includes robot.send_action timing)
robot.send_action(action_processed)
action_count += 1
# Sleep to maintain target FPS
dt_s = time.perf_counter() - start_time
sleep_time = max(0, (action_interval - dt_s) - 0.001)
if sleep_time > 0:
time.sleep(sleep_time)
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
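# Minimal sketch (not the real ActionQueue implementation) of the thread-safe contract
# the two threads above rely on: the actor pops one action per tick while the getter
# thread merges in freshly predicted chunks; the real merge() additionally accounts
# for the inference delay and RTC left-over actions.
from collections import deque
from threading import Lock
class _SketchActionQueue:
    def __init__(self):
        self._lock = Lock()
        self._actions = deque()
    def qsize(self) -> int:
        with self._lock:
            return len(self._actions)
    def get(self):
        # Returns None when empty so the actor can simply skip a tick.
        with self._lock:
            return self._actions.popleft() if self._actions else None
    def merge(self, new_actions):
        with self._lock:
            self._actions.extend(new_actions)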
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with profiling."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
logger.info("=" * 80)
logger.info("PROFILING MODE ENABLED")
logger.info("=" * 80)
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
    if cfg.policy.type in ("pi05", "pi0"):
config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processor
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
    last_status_log = 0.0
    while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
        time.sleep(1)
        # Log queue status roughly every 5 seconds
        elapsed = time.time() - start_time
        if elapsed - last_status_log >= 5:
            logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
            last_status_log = elapsed
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
# Print profiling summary
profiling_stats.print_summary()
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")
+358
@@ -0,0 +1,358 @@
#!/usr/bin/env python
"""
Comprehensive profiling script for Pi0 with RTC.
This script demonstrates how to use all the profiling tools to identify
bottlenecks in Pi0 policy inference with RTC enabled.
It profiles:
1. Overall inference time
2. RTC-specific operations (guidance, weights, etc.)
3. Preprocessing/postprocessing
4. Individual method timings
Usage:
uv run examples/rtc/profile_pi0_rtc_detailed.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=20 \
--execution_horizon=20 \
--enable_rtc_profiling
"""
import argparse
import logging
import sys
import time
import numpy as np
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.profiling import (
ProfileContext,
clear_profiling_stats,
enable_profiling,
get_profiling_stats,
print_profiling_summary,
)
# Import monkey patching for RTC profiling
try:
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
except ImportError:
logging.warning("Could not import add_rtc_profiling, detailed RTC profiling disabled")
monkey_patch_rtc_profiling = None
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_mock_observation(policy_config, device: str) -> dict:
"""Create a mock observation matching policy requirements.
Args:
policy_config: Policy configuration
device: Device to create tensors on
Returns:
Mock observation dictionary
"""
obs = {}
# Create mock state observation
state_dim = 10 # Typical robot state dimension
obs["observation.state"] = torch.randn(1, state_dim, device=device)
# Create mock images if needed
# For Pi0, we typically need at least one image
image_height = 224
image_width = 224
# Common image keys for Pi0
image_keys = ["observation.images.gripper", "observation.images.front"]
for key in image_keys:
# Images should be [B, C, H, W] and normalized to [0, 1]
obs[key] = torch.rand(1, 3, image_height, image_width, device=device)
# Add task
obs["task"] = ["Pick up the object"]
# Add language tokens and attention mask (required for Pi0)
# These are mock values - in real usage they come from tokenizer
max_seq_len = 32
obs["observation.language_tokens"] = torch.randint(0, 1000, (1, max_seq_len), device=device)
obs["observation.language_attention_mask"] = torch.ones(1, max_seq_len, device=device)
return obs
def profile_single_iteration(
policy,
preprocessor,
postprocessor,
observation: dict,
prev_actions: torch.Tensor | None,
use_rtc: bool,
inference_delay: int = 0,
) -> tuple[torch.Tensor, torch.Tensor | None, dict]:
"""Profile a single inference iteration.
Args:
policy: Policy instance
preprocessor: Observation preprocessor
postprocessor: Action postprocessor
observation: Input observation
prev_actions: Previous action chunk (for RTC)
use_rtc: Whether RTC is enabled
inference_delay: Inference delay in timesteps
Returns:
Tuple of (actions, new_prev_actions, timings)
"""
timings = {}
with ProfileContext("iteration.total"):
# Preprocessing
with ProfileContext("iteration.preprocessing"):
preprocessed_obs = preprocessor(observation)
# Policy inference
with ProfileContext("iteration.policy_inference"):
if use_rtc:
actions = policy.predict_action_chunk(
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
else:
actions = policy.predict_action_chunk(preprocessed_obs)
# Clone for next iteration (if RTC)
new_prev_actions = None
if use_rtc:
with ProfileContext("iteration.prepare_prev_actions"):
execution_horizon = policy.config.rtc_config.execution_horizon
if actions.shape[1] > execution_horizon:
new_prev_actions = actions[:, execution_horizon:].clone()
# Postprocessing
with ProfileContext("iteration.postprocessing"):
processed_actions = postprocessor(actions)
return processed_actions, new_prev_actions, timings
def main():
parser = argparse.ArgumentParser(description="Detailed profiling for Pi0 with RTC")
parser.add_argument("--policy_path", type=str, required=True, help="Path to pretrained policy")
parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu/mps)")
parser.add_argument("--num_iterations", type=int, default=20, help="Number of iterations")
parser.add_argument("--execution_horizon", type=int, default=10, help="RTC execution horizon")
parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
parser.add_argument("--enable_rtc_profiling", action="store_true", help="Enable detailed RTC profiling")
parser.add_argument("--use_torch_compile", action="store_true", help="Use torch.compile")
args = parser.parse_args()
logger.info("="*80)
logger.info("DETAILED PI0 RTC PROFILING")
logger.info("="*80)
logger.info(f"Policy: {args.policy_path}")
logger.info(f"Device: {args.device}")
logger.info(f"Iterations: {args.num_iterations}")
logger.info(f"Execution Horizon: {args.execution_horizon}")
logger.info(f"RTC Profiling: {args.enable_rtc_profiling}")
logger.info("="*80 + "\n")
# Enable profiling
enable_profiling()
# Apply RTC profiling if requested
if args.enable_rtc_profiling:
if monkey_patch_rtc_profiling is not None:
monkey_patch_rtc_profiling()
logger.info("✓ Detailed RTC profiling enabled\n")
else:
logger.warning("⚠ Could not enable detailed RTC profiling\n")
# Load policy
logger.info("Loading policy...")
config = PreTrainedConfig.from_pretrained(args.policy_path)
if hasattr(config, "compile_model"):
config.compile_model = args.use_torch_compile
policy_class = get_policy_class(config.type)
policy = policy_class.from_pretrained(args.policy_path, config=config)
# Configure RTC
policy.config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=args.execution_horizon,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
policy.init_rtc_processor()
policy = policy.to(args.device)
policy.eval()
logger.info(f"✓ Policy loaded: {config.type}\n")
# Create preprocessor and postprocessor
logger.info("Loading preprocessor/postprocessor...")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=config,
pretrained_path=args.policy_path,
dataset_stats=None,
preprocessor_overrides={
"device_processor": {"device": args.device},
},
)
logger.info("✓ Preprocessor/postprocessor loaded\n")
# Create mock observation
logger.info("Creating mock observation...")
observation = create_mock_observation(config, args.device)
logger.info("✓ Mock observation created\n")
# Warmup
logger.info(f"Warming up ({args.warmup_iterations} iterations)...")
prev_actions = None
for i in range(args.warmup_iterations):
with torch.no_grad():
_, prev_actions, _ = profile_single_iteration(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
observation=observation,
prev_actions=prev_actions,
use_rtc=True,
inference_delay=0,
)
# Clear warmup stats
clear_profiling_stats()
logger.info("✓ Warmup complete\n")
# Profiled run WITH RTC
logger.info(f"Running profiled iterations WITH RTC ({args.num_iterations} iterations)...")
prev_actions = None
iteration_times = []
for i in range(args.num_iterations):
start = time.perf_counter()
with torch.no_grad():
_, prev_actions, _ = profile_single_iteration(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
observation=observation,
prev_actions=prev_actions,
use_rtc=True,
inference_delay=0,
)
# Sync CUDA if needed
if args.device.startswith("cuda"):
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
iteration_times.append(elapsed)
if (i + 1) % 5 == 0:
logger.info(f" Completed {i+1}/{args.num_iterations}")
logger.info("✓ Profiling complete\n")
# Print summary statistics
logger.info("\n" + "="*80)
logger.info("ITERATION TIMING SUMMARY")
logger.info("="*80)
times_arr = np.array(iteration_times)
logger.info(f"Mean time: {np.mean(times_arr)*1000:.2f} ms")
logger.info(f"Median time: {np.median(times_arr)*1000:.2f} ms")
logger.info(f"Std dev: {np.std(times_arr)*1000:.2f} ms")
logger.info(f"Min time: {np.min(times_arr)*1000:.2f} ms")
logger.info(f"Max time: {np.max(times_arr)*1000:.2f} ms")
logger.info(f"Total time: {np.sum(times_arr):.2f} s")
logger.info(f"Throughput: {len(times_arr)/np.sum(times_arr):.2f} iter/s")
logger.info("="*80 + "\n")
# Print detailed profiling breakdown
print_profiling_summary(sort_by="total")
# Print key insights
stats = get_profiling_stats()
logger.info("\n" + "="*80)
logger.info("KEY INSIGHTS")
logger.info("="*80)
# Find bottlenecks
if stats:
policy_inference_time = stats.get("iteration.policy_inference", {}).get("mean", 0)
preprocessing_time = stats.get("iteration.preprocessing", {}).get("mean", 0)
postprocessing_time = stats.get("iteration.postprocessing", {}).get("mean", 0)
total_time = policy_inference_time + preprocessing_time + postprocessing_time
if total_time > 0:
logger.info(f"\nTime breakdown:")
logger.info(f" Policy inference: {policy_inference_time*1000:.2f} ms ({policy_inference_time/total_time*100:.1f}%)")
logger.info(f" Preprocessing: {preprocessing_time*1000:.2f} ms ({preprocessing_time/total_time*100:.1f}%)")
logger.info(f" Postprocessing: {postprocessing_time*1000:.2f} ms ({postprocessing_time/total_time*100:.1f}%)")
# RTC-specific insights
if args.enable_rtc_profiling:
rtc_guidance = stats.get("rtc.denoise_step.guidance_computation", {}).get("mean", 0)
rtc_autograd = stats.get("rtc.denoise_step.autograd_correction", {}).get("mean", 0)
rtc_base = stats.get("rtc.denoise_step.base_denoising", {}).get("mean", 0)
if rtc_guidance > 0:
logger.info(f"\nRTC breakdown:")
logger.info(f" Base denoising: {rtc_base*1000:.2f} ms")
logger.info(f" Guidance compute: {rtc_guidance*1000:.2f} ms")
logger.info(f" Autograd correct: {rtc_autograd*1000:.2f} ms")
logger.info(f" RTC overhead: {(rtc_guidance - rtc_base)*1000:.2f} ms")
# Recommendations
logger.info("\nRecommendations:")
if preprocessing_time > policy_inference_time * 0.3:
logger.info(" ⚠ Preprocessing is taking >30% of time")
logger.info(" → Consider reducing image resolution")
logger.info(" → Consider using fewer cameras")
if args.enable_rtc_profiling and rtc_autograd > rtc_base * 0.5:
logger.info(" ⚠ RTC autograd overhead is significant")
logger.info(" → This is expected, but consider increasing execution_horizon")
logger.info(" → Try torch.compile if not already enabled")
if not args.use_torch_compile:
logger.info(" 💡 torch.compile not enabled")
logger.info(" → Try --use_torch_compile for potential speedup")
logger.info("="*80 + "\n")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
logger.info("\n\nProfiling interrupted by user")
sys.exit(0)
except Exception as e:
logger.error(f"\n\nError during profiling: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
+347
@@ -0,0 +1,347 @@
#!/usr/bin/env python
"""
Script to compare performance with and without RTC enabled.
This script helps identify whether RTC is actually improving or degrading performance
by running multiple inference passes and collecting detailed timing statistics.
Usage:
# Profile with mock data (no robot needed)
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50
# Profile with specific RTC config
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20
"""
import argparse
import logging
import time
from dataclasses import dataclass
import numpy as np
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.profiling import (
clear_profiling_stats,
enable_profiling,
get_profiling_stats,
print_profiling_summary,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class ProfileResults:
"""Results from profiling run."""
mode: str # "with_rtc" or "without_rtc"
mean_time: float
std_time: float
min_time: float
max_time: float
times: list[float]
throughput: float # iterations per second
def create_mock_observation(policy, device: str) -> dict:
"""Create a mock observation for testing.
Args:
policy: Policy instance
device: Device to create tensors on
Returns:
Mock observation dictionary
"""
# Get expected input shapes from policy config
# This is a simplified version - adjust based on actual policy requirements
obs = {}
# Mock image observations (if needed)
if hasattr(policy.config, "input_shapes"):
for key, shape in policy.config.input_shapes.items():
if "image" in key:
# Typical image shape: (batch, channels, height, width)
obs[key] = torch.randn(1, *shape, device=device)
else:
obs[key] = torch.randn(1, *shape, device=device)
# Add task if needed
if "task" in policy.config.__dict__ or hasattr(policy, "accepts_task"):
obs["task"] = ["Pick up the object"]
# Mock state observation
obs["observation.state"] = torch.randn(1, 10, device=device) # Adjust size as needed
return obs
def profile_inference(
policy, observation: dict, num_iterations: int, use_rtc: bool, execution_horizon: int = 10
) -> ProfileResults:
"""Profile policy inference with or without RTC.
Args:
policy: Policy instance
observation: Observation dictionary
num_iterations: Number of inference iterations to run
use_rtc: Whether to enable RTC
execution_horizon: Execution horizon for RTC
Returns:
ProfileResults with timing statistics
"""
mode = "with_rtc" if use_rtc else "without_rtc"
logger.info(f"\n{'='*80}")
logger.info(f"Profiling: {mode.upper()}")
logger.info(f"{'='*80}")
# Configure RTC
if use_rtc:
policy.config.rtc_config.enabled = True
policy.config.rtc_config.execution_horizon = execution_horizon
policy.init_rtc_processor()
else:
policy.config.rtc_config.enabled = False
times = []
prev_actions = None
# Warmup
logger.info("Warming up (5 iterations)...")
for _ in range(5):
with torch.no_grad():
if use_rtc:
_ = policy.predict_action_chunk(
observation, inference_delay=0, prev_chunk_left_over=prev_actions
)
else:
_ = policy.predict_action_chunk(observation)
# Actual profiling
logger.info(f"Running {num_iterations} profiled iterations...")
for i in range(num_iterations):
start = time.perf_counter()
with torch.no_grad():
if use_rtc:
actions = policy.predict_action_chunk(
observation, inference_delay=0, prev_chunk_left_over=prev_actions
)
# Simulate consuming some actions for next iteration
if actions.shape[1] > execution_horizon:
prev_actions = actions[:, execution_horizon:].clone()
else:
prev_actions = None
else:
actions = policy.predict_action_chunk(observation)
# Synchronize if using CUDA
if observation["observation.state"].device.type == "cuda":
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
times.append(elapsed)
if (i + 1) % 10 == 0:
logger.info(f" Completed {i+1}/{num_iterations} iterations")
# Calculate statistics
times_arr = np.array(times)
results = ProfileResults(
mode=mode,
mean_time=float(np.mean(times_arr)),
std_time=float(np.std(times_arr)),
min_time=float(np.min(times_arr)),
max_time=float(np.max(times_arr)),
times=times,
throughput=num_iterations / sum(times),
)
logger.info(f"\nResults for {mode}:")
logger.info(f" Mean time: {results.mean_time*1000:.2f} ms")
logger.info(f" Std dev: {results.std_time*1000:.2f} ms")
logger.info(f" Min time: {results.min_time*1000:.2f} ms")
logger.info(f" Max time: {results.max_time*1000:.2f} ms")
logger.info(f" Throughput: {results.throughput:.2f} iter/s")
return results
def compare_results(results_without_rtc: ProfileResults, results_with_rtc: ProfileResults):
"""Compare and print results from both runs.
Args:
results_without_rtc: Results from run without RTC
results_with_rtc: Results from run with RTC
"""
logger.info(f"\n{'='*80}")
logger.info("COMPARISON SUMMARY")
logger.info(f"{'='*80}")
mean_diff = results_with_rtc.mean_time - results_without_rtc.mean_time
mean_diff_pct = (mean_diff / results_without_rtc.mean_time) * 100
throughput_diff = results_with_rtc.throughput - results_without_rtc.throughput
throughput_diff_pct = (throughput_diff / results_without_rtc.throughput) * 100
logger.info(f"\n{'Metric':<30} {'Without RTC':>15} {'With RTC':>15} {'Difference':>15}")
logger.info("-" * 80)
logger.info(
f"{'Mean time (ms)':<30} "
f"{results_without_rtc.mean_time*1000:>15.2f} "
f"{results_with_rtc.mean_time*1000:>15.2f} "
f"{mean_diff*1000:>+15.2f}"
)
logger.info(
f"{'Std dev (ms)':<30} "
f"{results_without_rtc.std_time*1000:>15.2f} "
f"{results_with_rtc.std_time*1000:>15.2f} "
f"{(results_with_rtc.std_time - results_without_rtc.std_time)*1000:>+15.2f}"
)
logger.info(
f"{'Min time (ms)':<30} "
f"{results_without_rtc.min_time*1000:>15.2f} "
f"{results_with_rtc.min_time*1000:>15.2f} "
f"{(results_with_rtc.min_time - results_without_rtc.min_time)*1000:>+15.2f}"
)
logger.info(
f"{'Max time (ms)':<30} "
f"{results_without_rtc.max_time*1000:>15.2f} "
f"{results_with_rtc.max_time*1000:>15.2f} "
f"{(results_with_rtc.max_time - results_without_rtc.max_time)*1000:>+15.2f}"
)
logger.info(
f"{'Throughput (iter/s)':<30} "
f"{results_without_rtc.throughput:>15.2f} "
f"{results_with_rtc.throughput:>15.2f} "
f"{throughput_diff:>+15.2f}"
)
logger.info(f"\n{'='*80}")
logger.info("VERDICT")
logger.info(f"{'='*80}")
if mean_diff_pct < -5:
logger.info(f"✓ RTC is FASTER by {abs(mean_diff_pct):.1f}%")
logger.info(f" Mean time reduced by {abs(mean_diff)*1000:.2f} ms")
elif mean_diff_pct > 5:
logger.info(f"✗ RTC is SLOWER by {mean_diff_pct:.1f}%")
logger.info(f" Mean time increased by {mean_diff*1000:.2f} ms")
logger.info("\n Possible reasons:")
logger.info(" - RTC overhead exceeds benefits at current execution horizon")
logger.info(" - Inference delay calculation not accounting for RTC processing")
logger.info(" - Additional tensor operations in RTC guidance")
else:
logger.info(f"≈ Performance is SIMILAR (difference: {mean_diff_pct:+.1f}%)")
logger.info(f"{'='*80}\n")
def main():
parser = argparse.ArgumentParser(description="Profile RTC performance")
parser.add_argument(
"--policy_path", type=str, required=True, help="Path to pretrained policy"
)
parser.add_argument(
"--device", type=str, default="cuda", help="Device to run on (cuda/cpu/mps)"
)
parser.add_argument(
"--num_iterations", type=int, default=50, help="Number of inference iterations"
)
parser.add_argument(
"--execution_horizon", type=int, default=10, help="RTC execution horizon"
)
parser.add_argument(
"--enable_detailed_profiling",
action="store_true",
help="Enable detailed method-level profiling",
)
parser.add_argument(
"--use_torch_compile", action="store_true", help="Use torch.compile for faster inference"
)
args = parser.parse_args()
# Load policy
logger.info(f"Loading policy from {args.policy_path}")
config = PreTrainedConfig.from_pretrained(args.policy_path)
policy_class = get_policy_class(config.type)
# Set compile flag if needed
if hasattr(config, "compile_model"):
config.compile_model = args.use_torch_compile
policy = policy_class.from_pretrained(args.policy_path, config=config)
# Initialize RTC config
policy.config.rtc_config = RTCConfig(
execution_horizon=args.execution_horizon,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
policy = policy.to(args.device)
policy.eval()
logger.info(f"Policy loaded: {config.type}")
logger.info(f"Device: {args.device}")
logger.info(f"Execution horizon: {args.execution_horizon}")
# Create mock observation
logger.info("Creating mock observation...")
observation = create_mock_observation(policy, args.device)
# Enable detailed profiling if requested
if args.enable_detailed_profiling:
enable_profiling()
logger.info("Detailed profiling enabled")
# Profile without RTC
results_without_rtc = profile_inference(
policy=policy,
observation=observation,
num_iterations=args.num_iterations,
use_rtc=False,
execution_horizon=args.execution_horizon,
)
if args.enable_detailed_profiling:
logger.info("\nDetailed profiling stats (WITHOUT RTC):")
print_profiling_summary()
clear_profiling_stats()
# Profile with RTC
results_with_rtc = profile_inference(
policy=policy,
observation=observation,
num_iterations=args.num_iterations,
use_rtc=True,
execution_horizon=args.execution_horizon,
)
if args.enable_detailed_profiling:
logger.info("\nDetailed profiling stats (WITH RTC):")
print_profiling_summary()
# Compare results
compare_results(results_without_rtc, results_with_rtc)
if __name__ == "__main__":
main()
+2 -1
@@ -98,6 +98,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bump dependency (compatible with wandb)
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -132,7 +133,7 @@ groot = [
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
-761
@@ -1,761 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inference script for SARM (Stage-Aware Reward Model).
This script loads a trained SARM model and runs inference on a dataset episode,
generating visualizations of the predicted task stages and progress over time.
Example usage:
python scripts/visualize_sarm_predictions.py \
--model-id username/sarm-model \
--dataset-repo lerobot/aloha_sim_insertion_human \
--episode-index 0 \
--output-dir outputs/sarm_viz \
--task-description "insert the peg into the socket"
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.sarm_utils import (
pad_state_to_max_dim,
compute_tau,
compute_cumulative_progress_batch,
)
from lerobot.datasets.utils import load_stats
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions")
# Model arguments
parser.add_argument(
"--model-id",
type=str,
required=True,
help="HuggingFace model ID or local path to trained SARM model"
)
# Dataset arguments
parser.add_argument(
"--dataset-repo",
type=str,
required=True,
help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)"
)
parser.add_argument(
"--episode-index",
type=int,
default=0,
help="Index of the episode to visualize (default: 0)"
)
parser.add_argument(
"--task-description",
type=str,
default="perform the task",
help="Task description for the reward model (default: 'perform the task')"
)
# Output arguments
parser.add_argument(
"--output-dir",
type=str,
default="outputs/sarm_inference",
help="Directory to save visualization outputs (default: outputs/sarm_inference)"
)
parser.add_argument(
"--image-key",
type=str,
default=None,
help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key"
)
parser.add_argument(
"--state-key",
type=str,
default=None,
help="Key for joint states in dataset. If None, auto-detects from dataset"
)
# Visualization options
parser.add_argument(
"--show-frames",
action="store_true",
help="Include sample frames in the visualization"
)
parser.add_argument(
"--num-sample-frames",
type=int,
default=8,
help="Number of sample frames to show (default: 8)"
)
parser.add_argument(
"--figsize",
type=int,
nargs=2,
default=[14, 8],
help="Figure size as width height (default: 14 8)"
)
# Device
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to run inference on (cuda/cpu, default: auto-detect)"
)
return parser.parse_args()
def load_episode_data(
dataset: LeRobotDataset,
episode_index: int,
image_key: str,
state_key: str | None = None
) -> tuple[np.ndarray, np.ndarray, int, int, str | None]:
"""
Load all frames and states from a specific episode.
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode to load
image_key: Key for accessing images in the dataset
state_key: Key for accessing joint states (auto-detected if None)
Returns:
Tuple of (frames, states, start_index, end_index, task_description)
"""
# Get episode boundaries
episode_data = dataset.meta.episodes
start_idx = episode_data["dataset_from_index"][episode_index]
end_idx = episode_data["dataset_to_index"][episode_index]
logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")
# Auto-detect state key if not provided
if state_key is None:
first_item = dataset[start_idx]
state_keys = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()]
if state_keys:
state_key = state_keys[0]
logger.info(f"Auto-detected state key: {state_key}")
# Get task description from the dataset if available
task_description = None
first_item = dataset[start_idx]
if "task" in first_item:
task_description = first_item["task"]
logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")
# Load all frames and states from the episode
frames = []
states = []
for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
item = dataset[idx]
# Get image
img = item[image_key]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Handle different image formats (C, H, W) or (H, W, C)
if img.shape[0] in [1, 3]: # Channel first
img = np.transpose(img, (1, 2, 0))
# Convert to uint8 if needed
if img.dtype != np.uint8:
if img.max() <= 1.0:
img = (img * 255).astype(np.uint8)
else:
img = img.astype(np.uint8)
frames.append(img)
# Get state if available
if state_key and state_key in item:
state = item[state_key]
if isinstance(state, torch.Tensor):
state = state.cpu().numpy()
states.append(state)
frames = np.array(frames)
states = np.array(states) if states else None
logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
if states is not None:
logger.info(f"Loaded states with shape {states.shape}")
return frames, states, start_idx, end_idx, task_description
@torch.no_grad()
def run_inference(
model: SARMRewardModel,
frames: np.ndarray,
states: Optional[np.ndarray],
task_description: str,
dataset_stats: dict | None = None,
state_key: str = "observation.state",
batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
"""
Run SARM inference on video frames and joint states.
    Each frame's input slice follows the symmetric bidirectional pattern used by the
    model (adapted from SARM paper Section A.4):
    - Frame 0: Initial frame of the episode (frame 0)
    - Frames 1-8: frames spaced frame_gap apart around the current frame t
    Pattern: [frame_0, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Args:
model: SARM model
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
states: Joint states (num_frames, state_dim)
task_description: Task description text
dataset_stats: Dataset statistics for state normalization (same as training)
state_key: Key for state in dataset_stats
batch_size: Batch size for processing slices
Returns:
Tuple of (progress_predictions, stage_predictions)
- progress_predictions: (num_frames,)
- stage_predictions: (num_frames, num_stages)
"""
logger.info("Encoding video frames with CLIP...")
video_embeddings = model.encode_images(frames)
logger.info("Encoding task description with CLIP...")
text_embedding = model.encode_text(task_description)
# Get config values
num_frames_model = model.config.num_frames # 9
frame_gap = model.config.frame_gap # 30
logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
# Convert to tensors
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
if states is not None:
state_embeddings = torch.tensor(states, dtype=torch.float32)
# Normalize states using dataset stats (same as training processor)
if dataset_stats is not None and state_key in dataset_stats:
mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32)
std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32)
state_embeddings = (state_embeddings - mean) / (std + 1e-8)
logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}")
else:
logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)")
else:
state_embeddings = None
video_slices = []
state_slices = []
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices using symmetric bidirectional pattern:
# [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
# Boundary handling: clamp to [0, last_valid]
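        # Worked example (with the assumed frame_gap=30): for current_frame=10 the raw
        # indices [0, -110, -80, -50, -20, 10, 40, 70, 100] clamp to
        # [0, 0, 0, 0, 0, 10, 40, 70, 100].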
deltas = model.config.observation_delta_indices
last_valid = len(video_embeddings) - 1
frame_indices = []
for delta in deltas:
idx = current_frame + delta
idx = max(0, min(idx, last_valid)) # Clamp to valid range
frame_indices.append(idx)
video_slice = video_embeddings[frame_indices]
video_slices.append(video_slice)
if state_embeddings is not None:
state_slice = state_embeddings[frame_indices]
state_slices.append(state_slice)
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
if state_embeddings is not None:
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
# Pad states to max_state_dim (same as training processor)
state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim)
else:
state_slices = None
logger.info("Running SARM inference on all slices...")
# Process in batches
all_progress = []
all_stages = []
for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
batch_video = video_slices[i:i + batch_size].to(model.device)
batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None
batch_size_actual = batch_video.shape[0]
# Replicate text embedding for batch
batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)
# Get predictions
stage_logits, stage_probs, progress_preds = model.sarm_transformer(
batch_video, batch_text, batch_states
)
# Extract predictions at the "current frame" position
# With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g],
# the current frame is at position 5 (0-indexed)
current_frame_idx = 5
batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()
all_progress.extend(batch_progress)
all_stages.extend(batch_stages)
return np.array(all_progress), np.array(all_stages)
def compute_ground_truth_progress(
dataset: LeRobotDataset,
episode_index: int,
temporal_proportions: dict[str, float],
subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
"""
Compute ground truth progress and stage labels for an episode using annotations.
Uses SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode
temporal_proportions: Dict mapping subtask name to proportion
subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)
Returns:
Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
"""
# Load episode metadata
episodes_df = dataset.meta.episodes.to_pandas()
# Check if annotations exist
if "subtask_names" not in episodes_df.columns:
logger.warning("No subtask_names column found in episodes metadata")
return None, None
ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
logger.warning(f"No annotations found for episode {episode_index}")
return None, None
subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]
# Get episode boundaries
ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
num_frames = ep_end - ep_start
# Get temporal proportions as ordered list
temporal_proportions_list = [
temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
]
logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
logger.info(f"Subtask names in episode: {ep_subtask_names}")
logger.info(f"Subtask start frames: {subtask_start_frames}")
logger.info(f"Subtask end frames: {subtask_end_frames}")
logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")
# Compute ground truth for each frame
gt_progress = np.zeros(num_frames)
gt_stages = np.zeros(num_frames, dtype=np.int32)
for frame_rel in range(num_frames):
# Find which subtask this frame belongs to
found = False
for j, (name, start_frame, end_frame) in enumerate(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames)):
if frame_rel >= start_frame and frame_rel <= end_frame:
# Found the subtask - get its global index
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0
# Compute τ_t using utility function
tau = compute_tau(frame_rel, start_frame, end_frame)
# Compute cumulative progress using utility function
progress = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
found = True
break
if not found:
# Handle frames outside annotated subtasks
if frame_rel < subtask_start_frames[0]:
gt_progress[frame_rel] = 0.0
gt_stages[frame_rel] = 0
elif frame_rel > subtask_end_frames[-1]:
gt_progress[frame_rel] = 1.0
gt_stages[frame_rel] = len(subtask_names_ordered) - 1
else:
# Between subtasks - find previous subtask
for j in range(len(ep_subtask_names) - 1):
if frame_rel > subtask_end_frames[j] and frame_rel < subtask_start_frames[j + 1]:
name = ep_subtask_names[j]
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
progress = compute_cumulative_progress_batch(1.0, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
break
logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
return gt_progress, gt_stages
def visualize_predictions(
frames: np.ndarray,
progress_predictions: np.ndarray,
stage_predictions: np.ndarray,
task_description: str,
output_path: Path,
num_sample_frames: int = 8,
figsize: tuple = (14, 8),
subtask_names: list[str] | None = None,
temporal_proportions: dict[str, float] | None = None,
ground_truth_progress: np.ndarray | None = None,
ground_truth_stages: np.ndarray | None = None,
):
"""
Create visualization of SARM predictions with optional ground truth comparison.
Args:
frames: Video frames (num_frames, H, W, C)
progress_predictions: Progress predictions (num_frames,)
stage_predictions: Stage probabilities (num_frames, num_stages)
task_description: Task description
output_path: Path to save the figure
num_sample_frames: Number of frames to show
figsize: Figure size (width, height)
subtask_names: Optional list of subtask names for labeling
temporal_proportions: Optional dict of temporal proportions for each subtask
ground_truth_progress: Optional ground truth progress array (num_frames,)
ground_truth_stages: Optional ground truth stage indices array (num_frames,)
"""
num_stages = stage_predictions.shape[1]
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
# Use subtask names if available, otherwise use generic labels
if subtask_names is not None and len(subtask_names) == num_stages:
stage_labels = subtask_names
else:
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
# Create figure with progress plot, stage plot, and sample frames
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
ax_progress = fig.add_subplot(gs[0])
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
ax_frames = fig.add_subplot(gs[2])
frame_indices = np.arange(len(progress_predictions))
# Plot 1: Progress over time
ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')
# Plot ground truth if available
if ground_truth_progress is not None:
ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
linestyle='--', label='Ground Truth Progress')
ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')
ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax_progress.set_ylabel('Task Progress', fontsize=12)
ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
ax_progress.grid(True, alpha=0.3)
ax_progress.set_ylim(-0.05, 1.1)
ax_progress.legend(loc='upper left')
# Add statistics box
stats_text = (
f'Frames: {len(progress_predictions)}\n'
f'Final Progress: {progress_predictions[-1]:.3f}\n'
f'Max Progress: {progress_predictions.max():.3f}\n'
f'Mean Progress: {progress_predictions.mean():.3f}'
)
if ground_truth_progress is not None:
mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
stats_text += f'\nMSE vs GT: {mse:.4f}'
stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'
ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
fontsize=10, verticalalignment='bottom', horizontalalignment='right',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
# Plot 2: Stage predictions (stacked area plot)
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
colors=stage_colors, alpha=0.8, labels=stage_labels)
# Plot ground truth stage as vertical bands or markers
if ground_truth_stages is not None:
# Find stage transition points in ground truth
stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
for change_idx in stage_changes:
ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)
# Add small markers at bottom showing GT stage
ax_stages.scatter(frame_indices[::30], np.zeros(len(frame_indices[::30])) + 0.02,
c=[stage_colors[s] for s in ground_truth_stages[::30]],
s=20, marker='|', alpha=0.8, label='GT Stage Markers')
ax_stages.set_xlabel('Frame Index', fontsize=12)
ax_stages.set_ylabel('Stage Probability', fontsize=12)
ax_stages.set_ylim(0, 1)
ax_stages.grid(True, alpha=0.3)
# Adjust legend based on number of stages and label lengths
if num_stages <= 5:
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
else:
ax_stages.legend(loc='upper left', ncol=3, fontsize=7)
# Add vertical lines and labels for expected stage transitions (if temporal proportions available)
if temporal_proportions is not None and subtask_names is not None:
cumulative_progress = 0.0
for i, name in enumerate(stage_labels):
if name in temporal_proportions:
# Find approximate frame where this stage should end
stage_end_progress = cumulative_progress + temporal_proportions[name]
# Find frame index closest to this progress
progress_diffs = np.abs(progress_predictions - stage_end_progress)
stage_end_frame = np.argmin(progress_diffs)
# Draw vertical line
ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
cumulative_progress = stage_end_progress
# Plot 3: Sample frames (if requested)
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
ax_frames.axis('off')
# Create grid for frames
frame_height = frames[0].shape[0]
frame_width = frames[0].shape[1]
combined_width = frame_width * num_sample_frames
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
for i, frame_idx in enumerate(frame_indices_to_show):
frame = frames[frame_idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
# Add frame to combined image
x_start = i * frame_width
x_end = (i + 1) * frame_width
combined_image[:, x_start:x_end] = frame
# Add frame number, progress, and stage
progress_val = progress_predictions[frame_idx]
stage_idx = np.argmax(stage_predictions[frame_idx])
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
# Truncate long stage names for display
if len(stage_name) > 15:
stage_name = stage_name[:12] + '...'
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
ha='center', va='top', fontsize=7,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
ax_frames.imshow(combined_image)
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
plt.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Saved visualization to {output_path}")
plt.close()
def main():
args = parse_args()
# Setup device
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
logger.info(f"Using device: {device}")
# Load model
logger.info(f"Loading SARM model from {args.model_id}...")
model = SARMRewardModel.from_pretrained(args.model_id)
model.to(device)
model.eval()
logger.info("Model loaded successfully")
# Load dataset
logger.info(f"Loading dataset {args.dataset_repo}...")
dataset = LeRobotDataset(args.dataset_repo)
logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")
# Validate episode index
if args.episode_index >= len(dataset.meta.episodes):
raise ValueError(
f"Episode index {args.episode_index} out of range. "
f"Dataset has {len(dataset.meta.episodes)} episodes."
)
image_key = args.image_key if args.image_key is not None else model.config.image_key
state_key = args.state_key if args.state_key is not None else model.config.state_key
logger.info(f"Using image key: {image_key}")
logger.info(f"Using state key: {state_key}")
# Load dataset stats for state normalization (same as training)
dataset_stats = load_stats(dataset.root)
if dataset_stats:
logger.info(f"✓ Loaded dataset stats from {dataset.root}")
else:
logger.warning("⚠ Could not load dataset stats - states will not be normalized")
# Load episode data
frames, states, start_idx, end_idx, dataset_task = load_episode_data(
dataset, args.episode_index, image_key, state_key
)
# Use task description from dataset if available, otherwise use command-line argument
task_description = dataset_task if dataset_task is not None else args.task_description
logger.info(f"Using task description: '{task_description}'")
# Run inference
progress_predictions, stage_predictions = run_inference(
model, frames, states, task_description,
dataset_stats=dataset_stats, state_key=state_key
)
# Extract subtask names and temporal proportions from model config if available
subtask_names = None
temporal_proportions = None
if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
subtask_names = model.config.subtask_names
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
# Try to load temporal proportions from model config
if hasattr(model.config, 'temporal_proportions') and model.config.temporal_proportions is not None:
temporal_proportions = {
name: prop for name, prop in zip(model.config.subtask_names, model.config.temporal_proportions)
}
logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")
# Fallback: try to load from dataset meta
if temporal_proportions is None:
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
if proportions_path.exists():
with open(proportions_path, 'r') as f:
temporal_proportions = json.load(f)
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
# Also extract subtask names from proportions if not already set
if subtask_names is None:
subtask_names = sorted(temporal_proportions.keys())
logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")
# Compute ground truth progress if annotations are available
ground_truth_progress = None
ground_truth_stages = None
if temporal_proportions is not None and subtask_names is not None:
logger.info("Attempting to compute ground truth progress from annotations...")
ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
dataset,
args.episode_index,
temporal_proportions,
subtask_names
)
if ground_truth_progress is None:
logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
else:
logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")
output_dir = Path(args.output_dir)
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
visualize_predictions(
frames,
progress_predictions,
stage_predictions,
task_description,
output_path,
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize),
subtask_names=subtask_names,
temporal_proportions=temporal_proportions,
ground_truth_progress=ground_truth_progress,
ground_truth_stages=ground_truth_stages,
)
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
save_dict = {
'progress': progress_predictions,
'stages': stage_predictions
}
if ground_truth_progress is not None:
save_dict['gt_progress'] = ground_truth_progress
save_dict['gt_stages'] = ground_truth_stages
np.savez(predictions_path, **save_dict)
logger.info(f"Saved predictions to {predictions_path}")
logger.info(f"\nVisualization: {output_path}")
if __name__ == "__main__":
main()
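For downstream analysis, the saved `.npz` file can be reloaded directly. A minimal sketch, assuming `--output-dir outputs` and episode 0 (the keys match the `np.savez` call above):

```python
import numpy as np

# NpzFile behaves like a read-only dict keyed by the names passed to np.savez.
data = np.load("outputs/predictions_ep0.npz")
progress = data["progress"]    # per-frame progress predictions
stages = data["stages"]        # per-frame stage predictions
if "gt_progress" in data:      # present only when annotations were available
    gt_progress = data["gt_progress"]
    gt_stages = data["gt_stages"]
```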
+2 -19
View File
@@ -64,26 +64,9 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training
reward_model_path: str | None = None # Path to pre-trained reward model (e.g., SARM)
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
rabc_update_freq: int = 1 # Compute rewards every N batches (1 = every batch)
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
checkpoint_path: Path | None = field(init=False, default=None)
def validate(self):
# Validate RA-BC configuration
if self.use_rabc and not self.reward_model_path:
raise ValueError(
"RA-BC is enabled (use_rabc=True) but no reward_model_path provided. "
"Please specify a pre-trained reward model (e.g., SARM) path."
)
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def validate(self) -> None:
# HACK: We parse the CLI args again here to pick up any pretrained paths that were given.
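To illustrate how the RA-BC parameters above could enter the training loop, here is a minimal sketch of a reward-weighted behavior-cloning loss; the helper name and shape conventions are assumptions, not the actual LeRobot implementation:

```python
import torch

def rabc_weighted_loss(
    per_sample_loss: torch.Tensor,  # (B,) BC loss per sample
    rewards: torch.Tensor,          # (B,) scores from a reward model such as SARM
    kappa: float = 0.01,            # rabc_kappa: hard threshold for high-quality samples
    epsilon: float = 1e-6,          # rabc_epsilon: numerical stability constant
) -> torch.Tensor:
    # Samples above the threshold keep full weight; the rest are down-weighted
    # in proportion to their reward score.
    weights = torch.where(
        rewards >= kappa,
        torch.ones_like(rewards),
        rewards / (kappa + epsilon),
    )
    # Renormalize so the weighted loss stays on the same scale as plain BC.
    weights = weights / (weights.mean() + epsilon)
    return (weights * per_sample_loss).mean()
```

Under this reading, `rabc_update_freq` controls how often `rewards` are recomputed (every N batches).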
+2 -10
View File
@@ -999,18 +999,10 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_values
else:
feature_slice = values[frame_idx:end_idx]
if len(feature_slice.shape) == 1:
# 1D array - can assign directly
df[feature_name] = feature_slice
elif len(feature_slice.shape) == 2 and feature_slice.shape[1] == 1:
# 2D array with single column - flatten it
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
df[feature_name] = feature_slice.flatten()
elif len(feature_slice.shape) == 2:
# 2D array with multiple columns (e.g., embeddings) - convert to list of lists
df[feature_name] = feature_slice.tolist()
else:
# Higher dimensional - convert to list
df[feature_name] = [row.tolist() for row in feature_slice]
df[feature_name] = feature_slice
frame_idx = end_idx
# Write using the same chunk/file structure as source
@@ -1,146 +0,0 @@
# LeRobot Embedding Generation Script
Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training.
## Overview
This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for:
- **Task embeddings**: Language command embeddings using MiniLM
- **Image embeddings**: Frame embeddings using DinoV2
The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly.
## Supported Encoders
### Image Encoders (DinoV2)
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings:
- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smaller model
- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance
- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower
### Language Encoders (MiniLM)
MiniLM is a lightweight sentence transformer model:
- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster
- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality
## Usage
### Basic Command
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--push-to-hub
```
### Lightweight Version (No Videos)
Removes video files to significantly reduce storage:
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
```
## Output
The script adds new features to your dataset:
### New Features
1. **`task_embedding`**: Language embedding for each frame
- Shape: `[384]` (MiniLM)
- One embedding per frame based on its task
2. **`{camera_key}_embedding`**: Image embedding for each camera view
- Shape: `[384]`, `[768]`, or `[1024]` depending on DinoV2 model
- Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding`
### Using Embeddings in Training
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load dataset with embeddings
dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings")
# Access embeddings
item = dataset[0]
task_emb = item["task_embedding"] # Shape: [384]
img_emb = item["observation.images.top_embedding"] # Shape: [768]
# Use in your policy
# Instead of running encoders during training, use pre-computed embeddings
```
## Extending with New Encoders
The script is designed to be easily extensible. To add a new encoder:
### 1. Create Encoder Class
```python
class MyCustomImageEncoder(ImageEncoder):
"""Your custom image encoder."""
def __init__(self, device: str = "cuda"):
super().__init__(device)
# Load your model
self.model = load_my_model()
self.model = self.model.to(self.device)
self.model.eval()
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
# Your encoding logic here
embeddings = []
for img in images:
emb = self.model(img)
embeddings.append(emb)
return np.array(embeddings)
@property
def embedding_dim(self) -> int:
"""Return embedding dimension."""
return 512 # Your embedding dimension
```
### 2. Add to Factory Function
```python
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
# Add your encoder
"my_custom": lambda: MyCustomImageEncoder(device=device),
}
# ... rest of function
```
## Validating Embeddings
After generating embeddings, you can validate them using `validate_embeddings.py`:
```bash
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id pepijn223/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 20
```
@@ -1,147 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
from PIL import Image
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageEncoder:
"""Base class for image encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
raise NotImplementedError
class DinoV2Encoder(ImageEncoder):
"""DinoV2 image encoder.
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings.
Supports multiple model sizes (ViT-S/14, ViT-B/14, ViT-L/14).
"""
def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32):
super().__init__(device)
self.batch_size = batch_size
self.model_name = model_name
logger.info(f"Loading DinoV2 model: {model_name}")
self.model = torch.hub.load("facebookresearch/dinov2", model_name) # nosec B614
self.model = self.model.to(self.device)
self.model.eval()
# DinoV2 preprocessing
from torchvision import transforms
self.transform = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
embeddings = []
with torch.inference_mode():
for i in range(0, len(images), self.batch_size):
batch_images = images[i : i + self.batch_size]
# Convert numpy arrays to PIL Images and apply transforms
pil_images = [Image.fromarray(img.astype(np.uint8)) for img in batch_images]
tensors = torch.stack([self.transform(img) for img in pil_images]).to(self.device)
# Get embeddings
batch_embeddings = self.model(tensors).cpu().numpy()
embeddings.append(batch_embeddings)
return np.concatenate(embeddings, axis=0)
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension based on model size."""
if "vits14" in self.model_name:
return 384 # DinoV2 ViT-S/14
elif "vitb14" in self.model_name:
return 768 # DinoV2 ViT-B/14
elif "vitl14" in self.model_name:
return 1024 # DinoV2 ViT-L/14
else:
return 768 # Default to ViT-B/14
class LanguageEncoder:
"""Base class for language encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
raise NotImplementedError
class MiniLMEncoder(LanguageEncoder):
"""MiniLM language encoder.
MiniLM is a lightweight sentence transformer model that produces high-quality text embeddings.
Supports L6 and L12 model sizes.
"""
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"):
super().__init__(device)
self.model_name = model_name
logger.info(f"Loading MiniLM model: {model_name}")
from transformers import AutoModel, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
def _mean_pooling(self, model_output, attention_mask):
"""Mean pooling to get sentence embeddings."""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
with torch.inference_mode():
encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
model_output = self.model(**encoded_input)
embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])
return embeddings.cpu().numpy()
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension."""
return 384 # Both MiniLM-L6 and L12 output 384-dim embeddings
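A quick usage sketch for the two encoders above (illustrative; instantiating them downloads the model weights, and both fall back to CPU when CUDA is unavailable):

```python
import numpy as np

image_encoder = DinoV2Encoder(model_name="dinov2_vitb14", device="cuda")
language_encoder = MiniLMEncoder(device="cuda")  # defaults to MiniLM-L12-v2

# One dummy 480x640 RGB frame and one task description.
frames = [np.zeros((480, 640, 3), dtype=np.uint8)]
img_embs = image_encoder.encode(frames)                   # shape: (1, 768)
txt_embs = language_encoder.encode(["Pick up the cube"])  # shape: (1, 384)
print(img_embs.shape, txt_embs.shape)
```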
@@ -1,329 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate embeddings for LeRobot datasets to make them more lightweight and efficient.
This script:
1. Loads a v3.0 LeRobot dataset from the hub
2. Computes embeddings for tasks (language commands) and frames (images)
3. Stores embeddings as new features in the dataset
4. Optionally removes video files to reduce size
5. Pushes the converted dataset to the hub
Current supported encoders:
- Image: DinoV2 (dinov2_vits14, dinov2_vitb14, dinov2_vitl14)
- Language: MiniLM (minilm-l6, minilm-l12)
The architecture is extensible - you can add more encoders by:
1. Creating a new encoder class inheriting from ImageEncoder or LanguageEncoder
2. Implementing the encode() method and embedding_dim property
3. Adding it to the get_image_encoder() or get_language_encoder() factory function
Usage example:
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id lerobot/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
"""
import argparse
import shutil
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import (
DinoV2Encoder,
ImageEncoder,
LanguageEncoder,
MiniLMEncoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
"""Factory function to get image encoder.
To add a new encoder:
1. Create a new class inheriting from ImageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
}
if encoder_name not in encoders:
raise ValueError(f"Unknown image encoder: {encoder_name}. Available options: {list(encoders.keys())}")
return encoders[encoder_name]()
def get_language_encoder(encoder_name: str, device: str = "cuda") -> LanguageEncoder:
"""Factory function to get language encoder.
To add a new encoder:
1. Create a new class inheriting from LanguageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"minilm-l6": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L6-v2", device=device
),
"minilm-l12": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L12-v2", device=device
),
}
if encoder_name not in encoders:
raise ValueError(
f"Unknown language encoder: {encoder_name}. Available options: {list(encoders.keys())}"
)
return encoders[encoder_name]()
def generate_embeddings_for_dataset(
repo_id: str,
output_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
remove_videos: bool = False,
local_dir: Path | None = None,
output_local_dir: Path | None = None,
push_to_hub: bool = False,
):
"""Generate embeddings for a LeRobot dataset.
Args:
repo_id: Source dataset repository ID
output_repo_id: Output dataset repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
remove_videos: Whether to remove video files
local_dir: Local directory for source dataset
output_local_dir: Local directory for output dataset
push_to_hub: Whether to push to hub after conversion
"""
from lerobot.datasets.dataset_tools import modify_features
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=True)
print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
print("Computing task embeddings...")
unique_tasks = dataset.meta.tasks.index.tolist()
task_embeddings = {}
for task in tqdm(unique_tasks, desc="Encoding tasks"):
# Clean up task text
task_clean = task.strip().capitalize().strip(" .,!?-_")
embedding = language_encoder.encode([task_clean])[0]
task_embeddings[task] = embedding
print(f"Computed {len(task_embeddings)} task embeddings")
print("Processing frames and computing embeddings...")
all_task_embeddings = []
all_image_embeddings_dict = {cam_key: [] for cam_key in dataset.meta.camera_keys}
for frame_idx in tqdm(range(dataset.num_frames), desc="Processing frames"):
item = dataset.hf_dataset[frame_idx]
ep_idx = item["episode_index"].item()
task = dataset.meta.tasks.iloc[item["task_index"].item()].name
task_emb = task_embeddings[task]
all_task_embeddings.append(task_emb)
for cam_key in dataset.meta.camera_keys:
if cam_key in dataset.meta.video_keys:
current_ts = item["timestamp"].item()
video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx)
img = video_frames[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 4:
img = img[0] # (T, C, H, W) -> (C, H, W)
elif img.ndim != 3:
raise ValueError(f"Unexpected video frame shape {img.shape} for camera {cam_key}")
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
img_np = np.array(img)
else:
img = item[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 3:
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
else:
img_np = np.array(img)
all_image_embeddings_dict[cam_key].append(img_np)
print("Computing image embeddings...")
image_embeddings_dict = {}
for cam_key, images in all_image_embeddings_dict.items():
print(f" {cam_key}: {len(images)} images")
embeddings = image_encoder.encode(images)
image_embeddings_dict[cam_key] = embeddings
all_task_embeddings = np.array(all_task_embeddings)
for cam_key in dataset.meta.camera_keys:
image_embeddings_dict[cam_key] = np.array(image_embeddings_dict[cam_key])
img_emb_dim = image_encoder.embedding_dim
lang_emb_dim = language_encoder.embedding_dim
add_features_dict = {
"task_embedding": (
all_task_embeddings,
{"dtype": "float32", "shape": [lang_emb_dim], "names": None},
),
}
for cam_key in dataset.meta.camera_keys:
add_features_dict[f"{cam_key}_embedding"] = (
image_embeddings_dict[cam_key],
{"dtype": "float32", "shape": [img_emb_dim], "names": None},
)
print("Adding embeddings to dataset...")
remove_features_list = None
if remove_videos:
remove_features_list = dataset.meta.video_keys
output_dataset = modify_features(
dataset=dataset,
add_features=add_features_dict,
remove_features=remove_features_list,
output_dir=output_local_dir,
repo_id=output_repo_id,
)
if remove_videos:
print("Removing video files...")
videos_dir = output_dataset.root / "videos"
if videos_dir.exists():
shutil.rmtree(videos_dir)
print(f"Saved to: {output_dataset.root}")
if push_to_hub:
print(f"Pushing to hub: {output_repo_id}")
output_dataset.push_to_hub(push_videos=not remove_videos)
print("Done!")
def main():
parser = argparse.ArgumentParser(
description="Generate embeddings for LeRobot datasets",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic usage with default encoders (DinoV2 ViT-B/14 + MiniLM-L12)
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--push-to-hub
# Generate embeddings and remove videos
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--remove-videos \\
--push-to-hub
Available image encoders:
- dinov2_vits14: DinoV2 ViT-S/14 (384-dim, faster)
- dinov2_vitb14: DinoV2 ViT-B/14 (768-dim, recommended)
- dinov2_vitl14: DinoV2 ViT-L/14 (1024-dim, best quality)
Available language encoders:
- minilm-l6: MiniLM-L6-v2 (384-dim, faster)
- minilm-l12: MiniLM-L12-v2 (384-dim, recommended)
""",
)
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repository ID")
parser.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repository ID")
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--remove-videos",
action="store_true",
help="Remove video files after generating embeddings",
)
parser.add_argument("--local-dir", type=str, default=None, help="Local directory for source dataset")
parser.add_argument(
"--output-local-dir", type=str, default=None, help="Local directory for output dataset"
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the converted dataset to the hub",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Generate embeddings
generate_embeddings_for_dataset(
repo_id=args.repo_id,
output_repo_id=args.output_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
remove_videos=args.remove_videos,
local_dir=Path(args.local_dir) if args.local_dir else None,
output_local_dir=Path(args.output_local_dir) if args.output_local_dir else None,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()
@@ -1,222 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Validate pre-computed embeddings against on-the-fly computed embeddings.
Usage:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id <your_username>/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 10
"""
import argparse
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import ImageEncoder, LanguageEncoder
from lerobot.datasets.generating_embeddings.generate_embeddings import (
get_image_encoder,
get_language_encoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors."""
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
def validate_embeddings(
original_repo_id: str,
embeddings_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
num_samples: int = 10,
device: str = "cuda",
):
"""Validate pre-computed embeddings against on-the-fly embeddings.
Args:
original_repo_id: Original dataset repository ID
embeddings_repo_id: Dataset with pre-computed embeddings repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
num_samples: Number of samples to validate
device: Device to use for encoding
"""
# Load both datasets
print("Loading datasets...")
original_dataset = LeRobotDataset(original_repo_id, download_videos=True)
embeddings_dataset = LeRobotDataset(embeddings_repo_id, download_videos=False)
# Verify both datasets have the same number of frames
assert original_dataset.num_frames == embeddings_dataset.num_frames, (
f"Frame count mismatch: original={original_dataset.num_frames}, "
f"embeddings={embeddings_dataset.num_frames}"
)
camera_keys = original_dataset.meta.camera_keys
# Check embedding features exist
expected_features = ["task_embedding"] + [f"{cam}_embedding" for cam in camera_keys]
for feat in expected_features:
if feat not in embeddings_dataset.features:
raise ValueError(f"Embedding feature not found: {feat}")
# Select random sample indices
sample_indices = np.random.choice(
original_dataset.num_frames, size=min(num_samples, original_dataset.num_frames), replace=False
)
print(f"Validating {len(sample_indices)} samples...")
# Track statistics
task_similarities = []
image_similarities = {cam: [] for cam in camera_keys}
for idx in tqdm(sample_indices, desc="Validating"):
idx = int(idx)
embeddings_item = embeddings_dataset[idx]
precomputed_task_emb = embeddings_item["task_embedding"].numpy()
precomputed_image_embs = {cam: embeddings_item[f"{cam}_embedding"].numpy() for cam in camera_keys}
original_item = original_dataset[idx]
# Get task and compute embedding
task = original_item["task"]
# Clean up task text (same as in generate_embeddings.py)
task_clean = task.strip().capitalize().strip(" .,!?-_")
onthefly_task_emb = language_encoder.encode([task_clean])[0]
# Get images and compute embeddings
onthefly_image_embs = {}
for cam in camera_keys:
img = original_item[cam]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
if img.ndim == 3: # (C, H, W)
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape: {img.shape}")
else:
img_np = np.array(img)
onthefly_image_embs[cam] = image_encoder.encode([img_np])[0]
# Task embedding comparison
task_sim = cosine_similarity(precomputed_task_emb, onthefly_task_emb)
task_similarities.append(task_sim)
# Image embedding comparison
for cam in camera_keys:
img_sim = cosine_similarity(precomputed_image_embs[cam], onthefly_image_embs[cam])
image_similarities[cam].append(img_sim)
# Results
print("\nResults:")
task_sim_threshold = 0.99
img_sim_threshold = 0.99
task_mean_sim = np.mean(task_similarities)
task_pass = task_mean_sim >= task_sim_threshold
print(f" Task: {task_mean_sim:.4f} {'' if task_pass else ''}")
for cam in camera_keys:
cam_mean_sim = np.mean(image_similarities[cam])
cam_pass = cam_mean_sim >= img_sim_threshold
print(f" {cam}: {cam_mean_sim:.4f} {'' if cam_pass else ''}")
image_pass = all(np.mean(image_similarities[cam]) >= img_sim_threshold for cam in camera_keys)
print()
if task_pass and image_pass:
print("✓ PASSED")
else:
print("✗ FAILED")
def main():
parser = argparse.ArgumentParser(
description="Validate and compare pre-computed embeddings with on-the-fly embeddings",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Example:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \\
--original-repo-id lerobot/utokyo_xarm_bimanual \\
--embeddings-repo-id lerobot/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--num-samples 20
""",
)
parser.add_argument("--original-repo-id", type=str, required=True, help="Original dataset repository ID")
parser.add_argument(
"--embeddings-repo-id",
type=str,
required=True,
help="Dataset with pre-computed embeddings repository ID",
)
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--num-samples",
type=int,
default=10,
help="Number of samples to validate (default: 10)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Validate embeddings
validate_embeddings(
original_repo_id=args.original_repo_id,
embeddings_repo_id=args.embeddings_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
num_samples=args.num_samples,
device=args.device,
)
if __name__ == "__main__":
main()
+6 -24
View File
@@ -712,15 +712,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
# Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded
# Build a mapping: absolute_index -> relative_index_in_filtered_dataset
self._absolute_to_relative_idx = None
if self.episodes is not None:
self._absolute_to_relative_idx = {
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
}
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -839,7 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset = load_nested_dataset(self.root / "data", features=features)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -856,8 +847,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Determine requested episodes
if self.episodes is None:
# Requesting all episodes - check if we have all episodes from metadata
requested_episodes = set(range(self.meta.total_episodes))
else:
# Requesting specific episodes
requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data
@@ -939,11 +932,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_timestamps = {}
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
timestamps = self.hf_dataset[relative_indices]["timestamp"]
else:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
@@ -966,16 +955,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key, q_idx in query_indices.items():
if key in self.meta.video_keys:
continue
# Map absolute indices to relative indices if needed
relative_indices = (
q_idx
if self._absolute_to_relative_idx is None
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
)
try:
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
result[key] = torch.stack(self.hf_dataset[key][q_idx])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
result[key] = torch.stack(self.hf_dataset[q_idx][key])
return result
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
@@ -1515,7 +1498,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj._absolute_to_relative_idx = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.writer = None
obj.latest_episode = None
-151
View File
@@ -1,151 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM Temporal Sampler for reward model training.
Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern:
- 1 initial frame + 4 frames before + current + 3 frames after
Boundary handling: clamp to first/last frame when indices go out of bounds.
This enables truly uniform sampling across entire episodes.
"""
import logging
from typing import Iterator, Optional
import numpy as np
import torch
from torch.utils.data import Sampler
import random
class SARMTemporalSampler(Sampler):
"""
Temporal sampler for SARM reward model training with symmetric/bidirectional sampling.
SARM uses 9 frames per sample:
- Frame 0: Initial frame of the episode (always frame 0)
- Frames 1-8: Symmetric context around current frame
Pattern: [t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Early frames: backward indices clamp to 0 (e.g., [0,0,0,5,35,65,95,125])
- Late frames: forward indices clamp to last frame (e.g., [850,880,910,940,970,1000,1000,1000])
This enables truly uniform sampling across entire episodes.
Args:
dataset_from_index: Start indices of episodes (global dataset indices)
dataset_to_index: End indices of episodes (global dataset indices)
frame_gap: Gap between consecutive frames (default: 30 = 1 second at 30fps)
shuffle: Whether to shuffle sampling order
seed: Random seed for reproducibility
samples_per_epoch: Number of samples per epoch (default: 6400)
min_episode_length: Minimum episode length to include (default: 1)
"""
def __init__(
self,
dataset_from_index: np.ndarray,
dataset_to_index: np.ndarray,
frame_gap: int = 30,
shuffle: bool = True,
seed: Optional[int] = None,
samples_per_epoch: int = 6400,
min_episode_length: int = 1,
):
self.dataset_from_index = np.array(dataset_from_index)
self.dataset_to_index = np.array(dataset_to_index)
self.frame_gap = frame_gap
self.shuffle = shuffle
self.samples_per_epoch = samples_per_epoch
self.min_episode_length = min_episode_length
if seed is not None:
self.seed = seed
random.seed(seed)
np.random.seed(seed)
self.generator = torch.Generator().manual_seed(seed)
else:
self.generator = torch.Generator()
# Compute valid episodes and sampling positions (ALL frames for uniform sampling)
self._compute_valid_positions()
logging.info(
f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
f"{len(self.all_valid_positions)} positions (uniform sampling), "
f"{self.samples_per_epoch} samples per epoch, "
f"frame_gap={frame_gap}, symmetric bidirectional pattern"
)
def _compute_valid_positions(self):
"""Compute valid episodes and ALL sampling positions for uniform sampling.
With symmetric bidirectional sampling, we can sample from ANY frame:
- Early frames: backward indices clamp to first frame
- Late frames: forward indices clamp to last frame
"""
self.valid_episodes = []
self.all_valid_positions = []
for ep_idx in range(len(self.dataset_from_index)):
ep_start = self.dataset_from_index[ep_idx]
ep_end = self.dataset_to_index[ep_idx]
episode_length = ep_end - ep_start
# Include all episodes with at least min_episode_length frames
if episode_length >= self.min_episode_length:
self.valid_episodes.append((ep_idx, ep_start, ep_end))
# Include ALL positions in the episode (truly uniform sampling)
for pos in range(ep_start, ep_end):
self.all_valid_positions.append(pos)
self.valid_episodes = np.array(self.valid_episodes)
self.all_valid_positions = np.array(self.all_valid_positions)
if len(self.all_valid_positions) == 0:
raise ValueError(
f"No valid sampling positions found! "
f"Check that episodes have at least {self.min_episode_length} frames."
)
def __len__(self) -> int:
return self.samples_per_epoch
def __iter__(self) -> Iterator[int]:
"""
Yields global dataset indices for uniform sampling across episodes.
Each yielded index represents the "current frame" position.
The dataset's observation_delta_indices then handles loading:
- Frame 0: Episode initial frame (via large negative delta clamping)
- Frames 1-8: Symmetric context around current frame (with boundary clamping)
For early frames: backward indices clamp to first frame (progress ~0%)
For late frames: forward indices clamp to last frame (progress ~100%)
"""
if self.shuffle:
# Randomly sample from all valid positions
for _ in range(self.samples_per_epoch):
idx = np.random.randint(0, len(self.all_valid_positions))
yield int(self.all_valid_positions[idx])
else:
# Sequential sampling with wrap-around
for i in range(self.samples_per_epoch):
idx = i % len(self.all_valid_positions)
yield int(self.all_valid_positions[idx])
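A minimal sketch of driving the sampler directly, assuming two episodes with the index bounds below:

```python
import numpy as np

# Episode 0 covers frames [0, 300); episode 1 covers [300, 650).
sampler = SARMTemporalSampler(
    dataset_from_index=np.array([0, 300]),
    dataset_to_index=np.array([300, 650]),
    frame_gap=30,
    shuffle=True,
    seed=0,
    samples_per_epoch=128,
)

# Each yielded index is a "current frame"; the dataset's
# observation_delta_indices expands it into the 9-frame pattern.
indices = list(sampler)
assert len(indices) == 128 and all(0 <= i < 650 for i in indices)
```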
+4 -18
View File
@@ -28,7 +28,6 @@ import numpy as np
import packaging.version
import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
@@ -104,9 +103,7 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx
def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
@@ -114,26 +111,15 @@ def load_nested_dataset(
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
hf_dataset = Dataset.from_parquet([str(path) for path in paths], features=features)
return hf_dataset
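For reference, a sketch of the episode-filtering path that the removed branch implemented with PyArrow predicate pushdown (the helper name is hypothetical):

```python
import pyarrow.dataset as pa_ds
from datasets import Dataset

def load_filtered_episodes(paths, episodes, features=None) -> Dataset:
    arrow_dataset = pa_ds.dataset([str(p) for p in paths], format="parquet")
    # Predicate pushdown: only rows whose episode_index matches are materialized.
    table = arrow_dataset.to_table(filter=pa_ds.field("episode_index").isin(episodes))
    if features is not None:
        table = table.cast(features.arrow_schema)
    return Dataset(table)
```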
def get_parquet_num_frames(parquet_path: str | Path) -> int:
+9 -57
View File
@@ -21,22 +21,7 @@ import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.constants import (
ACTION,
LIBERO_KEY_EEF_MAT,
LIBERO_KEY_EEF_POS,
LIBERO_KEY_EEF_QUAT,
LIBERO_KEY_GRIPPER_QPOS,
LIBERO_KEY_GRIPPER_QVEL,
LIBERO_KEY_JOINTS_POS,
LIBERO_KEY_JOINTS_VEL,
LIBERO_KEY_PIXELS_AGENTVIEW,
LIBERO_KEY_PIXELS_EYE_IN_HAND,
OBS_ENV_STATE,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
@dataclass
@@ -261,61 +246,28 @@ class LiberoEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(3,),
)
self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(4,),
)
self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(3, 3),
)
self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
-39
View File
@@ -14,16 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -37,41 +33,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
This function creates processor pipelines that transform raw environment
observations and actions. By default, it returns identity processors that do nothing.
For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
"""
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
return preprocessor, postprocessor
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
+21 -69
View File
@@ -28,6 +28,7 @@ import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.transform_utils import quat2axisangle
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
@@ -174,36 +175,11 @@ class LiberoEnv(gym.Env):
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"robot_state": spaces.Dict(
{
"eef": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
"quat": spaces.Box(
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
),
"mat": spaces.Box(
low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
),
}
),
"gripper": spaces.Dict(
{
"qpos": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
"qvel": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
}
),
"joints": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
}
),
}
"agent_pos": spaces.Box(
low=AGENT_POS_LOW,
high=AGENT_POS_HIGH,
shape=(OBS_STATE_DIM,),
dtype=np.float64,
),
}
)
@@ -215,7 +191,6 @@ class LiberoEnv(gym.Env):
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
image = image[::-1, ::-1] # flip both H and W for visualization
return image
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
@@ -237,48 +212,23 @@ class LiberoEnv(gym.Env):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
eef_pos = raw_obs.get("robot0_eef_pos")
eef_quat = raw_obs.get("robot0_eef_quat")
# rotation matrix from controller
eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
gripper_qpos = raw_obs.get("robot0_gripper_qpos")
gripper_qvel = raw_obs.get("robot0_gripper_qvel")
joint_pos = raw_obs.get("robot0_joint_pos")
joint_vel = raw_obs.get("robot0_joint_vel")
obs = {
"pixels": images,
"robot_state": {
"eef": {
"pos": eef_pos, # (3,)
"quat": eef_quat, # (4,)
"mat": eef_mat, # (3, 3)
},
"gripper": {
"qpos": gripper_qpos, # (2,)
"qvel": gripper_qvel, # (2,)
},
"joints": {
"pos": joint_pos, # (7,)
"vel": joint_vel, # (7,)
},
},
}
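# agent_pos (8-D): eef position (3) + eef orientation as axis-angle (3) + gripper joint positions (2)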
state = np.concatenate(
(
raw_obs["robot0_eef_pos"],
quat2axisangle(raw_obs["robot0_eef_quat"]),
raw_obs["robot0_gripper_qpos"],
)
)
agent_pos = state
if self.obs_type == "pixels":
return {"pixels": images.copy()}
if self.obs_type == "pixels_agent_pos":
# Validate required fields are present
if eef_pos is None or eef_quat is None or gripper_qpos is None:
raise ValueError(
f"Missing required robot state fields in raw observation. "
f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
f"gripper_qpos={gripper_qpos is not None}"
)
return obs
return {
"pixels": images.copy(),
"agent_pos": agent_pos,
}
raise NotImplementedError(
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
@@ -405,10 +355,12 @@ def create_libero_envs(
print(f"Restricting to task_ids={task_ids_filter}")
out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names:
suite = _get_suite(suite_name)
total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter)
if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
+6 -20
View File
@@ -29,22 +29,10 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.utils import get_channel_first_image_shape
def _convert_nested_dict(d):
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _convert_nested_dict(v)
elif isinstance(v, np.ndarray):
result[k] = torch.from_numpy(v)
else:
result[k] = v
return result
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
@@ -90,14 +78,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[OBS_ENV_STATE] = env_state
if "agent_pos" in observations:
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "robot_state" in observations:
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
return return_observations
-20
View File
@@ -35,7 +35,6 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
@@ -104,10 +103,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
@@ -327,14 +322,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SARMConfig):
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
@@ -418,13 +405,6 @@ def make_policy(
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
if ds_meta is not None and hasattr(ds_meta, 'stats'):
kwargs["dataset_stats"] = ds_meta.stats
if ds_meta is not None:
kwargs["dataset_meta"] = ds_meta
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
+38 -27
View File
@@ -1,38 +1,49 @@
# Real-Time Chunking (RTC)
# Real-Time Chunking (RTC) Module
This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.
This module implements Real-Time Chunking and related adaptive inference techniques for robotics policies in LeRobot.
**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).
## Overview
---
Real-Time Chunking (RTC) addresses the challenge of real-time inference in action-chunking policies by treating chunk generation as an inpainting problem: it strategically handles the overlapping timesteps between consecutive action chunks using prefix attention. It is particularly effective for long-horizon inference in robotics policies.
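As a deliberately simplified sketch of the inpainting view (not the LeRobot API; the actual method uses soft prefix attention rather than hard freezing): the first overlapping steps of a new chunk correspond to actions already committed from the previous chunk, and the denoising loop keeps that prefix consistent while generating the rest.

```python
import torch

def inpaint_prefix(noisy_chunk: torch.Tensor, committed_prefix: torch.Tensor) -> torch.Tensor:
    # Hard-freeze variant: clamp the overlap to the committed actions.
    # noisy_chunk: (T, action_dim) current sample inside the flow-matching loop
    # committed_prefix: (h, action_dim) overlap carried over from the previous chunk
    out = noisy_chunk.clone()
    out[: committed_prefix.shape[0]] = committed_prefix
    return out
```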
## Citation
If you use Real-Time Chunking in your work, please cite:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{black2025realtimeexecutionactionchunking,
title={Real-Time Execution of Action Chunking Flow Policies},
author={Kevin Black and Manuel Y. Galliker and Sergey Levine},
year={2025},
eprint={2506.07339},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2506.07339},
}
```
## Integration with Policies

RTC can be integrated with any policy that supports flow matching for chunking:

- **SmolVLA**: Vision-language-action model with RTC support
- **Pi0**: Action prediction model with adaptive chunking
- **Pi05**: Action prediction model with adaptive chunking
## Original Implementation
This implementation is based on Physical Intelligence's Kinetix RTC:
- [Original RTC implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214)
- [Kinetix GitHub Repository](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
## References
- [Real Time Chunking Paper](https://www.physicalintelligence.company/research/real_time_chunking)
- [Physical Intelligence Kinetix](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
## How to run
### Evaluate with data from a dataset
```bash
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--seed=42
```
---
This script evaluates RTC on data from a dataset and saves the results to a file; you can inspect the results in the `rtc_debug_output` directory.

The example output should look like this:

![Flow Matching with RTC](./flow_matching.png)

The figure shows how flow matching behaves with and without RTC. The chart plots the predicted action values at each timestep, with colour indicating generation progress: blue for earlier denoising steps, yellow for later ones. The red line is the ground truth (the previous action chunk).

## License

This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.
@@ -111,3 +111,7 @@ class RTCDebugVisualizer:
if not ax.yaxis.get_label().get_text():
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
# Add legend if label provided and this is the first dimension
if label and dim_idx == 0:
ax.legend(loc="best", fontsize=8)
Binary file not shown: `flow_matching.png` added (1.3 MiB).
-34
View File
@@ -1,34 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.modeling_sarm import (
SARMRewardModel,
SARMTransformer,
)
from lerobot.policies.sarm.processor_sarm import (
SARMEncodingProcessorStep,
make_sarm_pre_post_processors,
)
__all__ = [
"SARMConfig",
"SARMRewardModel",
"SARMTransformer",
"SARMEncodingProcessorStep",
"make_sarm_pre_post_processors",
]
@@ -1,186 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("sarm")
@dataclass
class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
# CLIP params
image_dim: int = 512
text_dim: int = 512
num_frames: int = 9 # 1 initial + 8 consecutive frames
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
# Architecture params
hidden_dim: int = 768
num_heads: int = 12
num_layers: int = 8
max_state_dim: int = 32
num_stages: int = 5 # Number of task stages (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading
# Training params
batch_size: int = 64
clip_batch_size: int = 64 # Batch size for CLIP encoding
dropout: float = 0.1
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
pretrained_model_path: str | None = None
device: str | None = None
# Processor settings
image_key: str = "observation.images.top" # Key for image used from the dataset
# State key in the dataset (for normalization)
state_key: str = "observation.state"
# Populated by the processor (video_features, state_features, text_features)
input_features: dict = field(default_factory=lambda: {})
# Output features
output_features: dict = field(default_factory=lambda: {
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
})
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"LANGUAGE": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self):
super().__post_init__()
# Add the image_key as VISUAL
if self.image_key:
self.input_features[self.image_key] = PolicyFeature(
shape=(480, 640, 3),
type=FeatureType.VISUAL
)
# Add state_key as STATE
self.input_features[self.state_key] = PolicyFeature(
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
type=FeatureType.STATE
)
# Update output features with actual dimensions
self.output_features["stage"] = PolicyFeature(
shape=(self.num_frames, self.num_stages),
type=FeatureType.REWARD
)
self.output_features["progress"] = PolicyFeature(
shape=(self.num_frames, 1),
type=FeatureType.REWARD
)
# Validate configuration
if self.hidden_dim % self.num_heads != 0:
raise ValueError(
f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
)
if self.max_length != self.num_frames:
raise ValueError(
f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})"
)
if self.num_stages < 2:
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
def get_optimizer_preset(self) -> AdamWConfig:
"""Get default optimizer configuration for SARM training."""
return AdamWConfig(
lr=5e-5,
weight_decay=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Get default learning rate scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=5e-5,
decay_lr=5e-6,
num_warmup_steps=500,
num_decay_steps=50000,
)
def validate_features(self) -> None:
"""Validate input and output features."""
pass
@property
def observation_delta_indices(self) -> list[int]:
"""Load frames for SARM temporal sampling with SYMMETRIC/BIDIRECTIONAL pattern.
The model uses 9 frames with symmetric context around current frame:
- Frame 0: Initial frame of the episode (clamped via large negative delta)
- Frames 1-8: Symmetric context: 4 before + current + 3 after
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling (done by dataset loader):
- Early frames: backward indices clamp to 0 (first frame)
- Late frames: forward indices clamp to episode end (last frame)
This enables truly uniform sampling across entire episodes.
Returns:
9 delta indices: [-1_000_000, -4*gap, -3*gap, -2*gap, -gap, 0, gap, 2*gap, 3*gap]
"""
initial_frame_delta = -1_000_000
# Symmetric pattern: 4 frames before, current (0), 3 frames after = 8 context frames
symmetric_deltas = [
-4 * self.frame_gap,
-3 * self.frame_gap,
-2 * self.frame_gap,
-1 * self.frame_gap,
0, # current frame
1 * self.frame_gap,
2 * self.frame_gap,
3 * self.frame_gap,
]
return [initial_frame_delta] + symmetric_deltas
@property
def action_delta_indices(self) -> None:
"""SARM is a reward model, not an action policy."""
return None
@property
def reward_delta_indices(self) -> None:
"""SARM doesn't use delta rewards."""
return None
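As a quick worked example of `observation_delta_indices` above (a standalone sketch, independent of the removed file): with the default `frame_gap=30`, the symmetric pattern resolves to the following nine deltas.

```python
frame_gap = 30  # default: one second of context at 30 fps
deltas = [-1_000_000] + [k * frame_gap for k in (-4, -3, -2, -1, 0, 1, 2, 3)]
print(deltas)
# [-1000000, -120, -90, -60, -30, 0, 30, 60, 90]
```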
@@ -1,650 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Union, Optional
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from torch import Tensor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module):
"""
SARM Transformer model for stage-aware reward prediction.
This model has a dual-head architecture:
1. Stage estimator: Predicts the high-level task stage (classification)
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
"""
def __init__(
self,
video_dim: int = 512,
text_dim: int = 512,
max_state_dim: int = 32,
hidden_dim: int = 768,
num_heads: int = 12,
num_layers: int = 8,
num_stages: int = 5,
max_length: int = 9,
dropout: float = 0.1,
temporal_proportions: list[float] | None = None
):
super().__init__()
self.hidden_dim = hidden_dim
self.max_length = max_length
self.num_stages = num_stages
self.max_state_dim = max_state_dim
if temporal_proportions is None:
raise ValueError(
"temporal_proportions is required for SARM. "
"Provide subtask annotations in your dataset or set temporal_proportions in config."
)
# ᾱ_k: proportion for each stage
alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
# P_k: cumulative proportion up to stage k (P_0 = 0)
cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
self.register_buffer('alpha', alpha)
self.register_buffer('cumulative_prior', cumulative)
self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
self.state_proj = nn.Linear(max_state_dim, hidden_dim)
# Position embedding only for the first frame
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Stage estimator head (classification)
self.stage_head = nn.Sequential(
nn.Linear(hidden_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(512, num_stages)
)
# Subtask estimator head (regression)
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential(
nn.Linear(subtask_input_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(512, 1),
nn.Sigmoid()
)
# Attention mask
self.register_buffer("attention_mask", None, persistent=False)
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
"""Generate or retrieve cached causal attention mask."""
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
# Create causal mask
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
self.attention_mask = mask
return self.attention_mask
def forward(
self,
video_frames: torch.Tensor,
text_embed: torch.Tensor,
state_features: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass through the SARM transformer.
Args:
video_frames: Video frame embeddings (batch_size, seq_len, video_dim)
text_embed: Text embeddings (batch_size, text_dim)
state_features: Joint state features (batch_size, seq_len, state_dim)
Returns:
Tuple of:
- Stage logits for each frame (batch_size, seq_len, num_stages)
- Stage probabilities (batch_size, seq_len, num_stages)
- Progress predictions for each frame (batch_size, seq_len, 1)
"""
# Project inputs to common dimension
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
# Pad state features to max_state_dim before projection
state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim)
state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim]
# Fuse video and state features
video_embed = video_embed + state_embed
# Add positional embedding to first video frame
video_embed[:, 0] += self.first_pos_embed
# Combine sequence: [text, video_frames]
sequence = torch.cat([text_embed, video_embed], dim=1)
# Get causal attention mask
seq_length = sequence.shape[1]
attention_mask = self._get_attention_mask(seq_length, sequence.device)
# Pass through transformer with causal masking
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
# Get frame features
frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim]
# Stage estimation
stage_logits = self.stage_head(frame_features) # [batch_size, seq_len, num_stages]
stage_probs = F.softmax(stage_logits, dim=-1) # [batch_size, seq_len, num_stages]
# Get predicted stage indices
stage_indices = torch.argmax(stage_probs, dim=-1) # [batch_size, seq_len]
# Get stage embeddings for conditioning
stage_embeds = self.stage_embedding(stage_indices)
# Concatenate frame features with stage embeddings
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
# Subtask progress estimation (conditioned on stage)
# τ̂ = within-subtask progress (0-1)
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
# Convert τ̂ to cumulative progress ŷ using Paper Formula (2):
# ŷ = P_{k-1} + ᾱ_k × τ̂
progress_preds = compute_cumulative_progress_batch(
tau_preds, stage_indices, self.alpha, self.cumulative_prior
)
return stage_logits, stage_probs, progress_preds
class SARMRewardModel(PreTrainedPolicy):
"""
SARM Reward Model for stage-aware task completion rewards.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
This model combines:
- CLIP for encoding video frames AND text descriptions
- SARMTransformer for predicting task stage and progress
- Optional RA-BC (Reward-Aligned Behavior Cloning) for weighted training
"""
name = "sarm"
config_class = SARMConfig
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Load temporal proportions from dataset
if config.temporal_proportions is None and dataset_meta is not None:
self._load_temporal_proportions(dataset_meta)
logging.info("Loading CLIP encoder")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim,
text_dim=config.text_dim,
max_state_dim=config.max_state_dim,
hidden_dim=config.hidden_dim,
num_heads=config.num_heads,
num_layers=config.num_layers,
num_stages=config.num_stages,
max_length=config.max_length,
dropout=config.dropout,
temporal_proportions=config.temporal_proportions
)
self.sarm_transformer.to(self.device)
logging.info(f"SARM initialized on {self.device}")
def _load_temporal_proportions(self, dataset_meta) -> None:
"""
Load pre-computed temporal proportions from dataset metadata JSON file.
The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
"""
import json
proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"
if not proportions_path.exists():
raise ValueError(
f"Temporal proportions not found at {proportions_path}. "
"Run the subtask annotation tool first to compute and save temporal proportions."
)
with open(proportions_path, "r") as f:
temporal_proportions_dict = json.load(f)
# Sort subtask names for consistent ordering
subtask_names = sorted(temporal_proportions_dict.keys())
self.config.num_stages = len(subtask_names)
self.config.subtask_names = subtask_names
self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
def to(self, device):
"""Override to method to ensure all components move together."""
super().to(device)
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.clip_model.to(device)
self.sarm_transformer.to(device)
return self
@torch.no_grad()
def encode_images(self, images: np.ndarray) -> np.ndarray:
"""
Encode video frames using CLIP.
Args:
images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8.
Can also be (num_frames, H, W, C) for a single video.
Returns:
Encoded image features (num_videos, num_frames, 512) or (num_frames, 512).
"""
# Handle single video case
single_video = False
if len(images.shape) == 4:
images = images[np.newaxis, ...]
single_video = True
assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}"
all_embeddings = []
for video in images:
video_embeddings = []
# Convert frames to PIL images for CLIP processor
frames = []
for frame in video:
if frame.shape[0] == 3: # Channel first
frame = frame.transpose(1, 2, 0)
if frame.dtype != np.uint8:
frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
frames.append(Image.fromarray(frame))
# Batch process frames with CLIP
for i in range(0, len(frames), self.config.clip_batch_size):
batch = frames[i:i + self.config.clip_batch_size]
inputs = self.clip_processor(images=batch, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings from CLIP
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
video_embeddings.append(embeddings)
video_embeddings = torch.cat(video_embeddings)
all_embeddings.append(video_embeddings)
result = torch.stack(all_embeddings).numpy()
if single_video:
result = result[0]
return result
@torch.no_grad()
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""
Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Text string or list of text strings.
Returns:
Encoded text features (batch_size, 512) or (512,) for single text.
"""
if isinstance(text, str):
text = [text]
single_text = True
else:
single_text = False
# Use CLIP's tokenizer directly (avoids image processor validation issues)
tokenizer = self.clip_processor.tokenizer
# Process in batches
all_embeddings = []
for i in range(0, len(text), self.config.batch_size):
batch_text = text[i:i + self.config.batch_size]
inputs = tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
text_embeddings = self.clip_model.get_text_features(**inputs)
all_embeddings.append(text_embeddings.cpu())
result = torch.cat(all_embeddings).numpy()
if single_text:
result = result[0]
return result
@torch.no_grad()
def calculate_rewards(
self,
text_embeddings: Union[np.ndarray, torch.Tensor],
video_embeddings: Union[np.ndarray, torch.Tensor],
state_features: Optional[Union[np.ndarray, torch.Tensor]] = None,
return_all_frames: bool = False,
return_stages: bool = False
) -> Union[np.ndarray, tuple]:
"""
Calculate rewards for given text, video, and state representations.
Args:
text_embeddings: Encoded text representations (batch_size, 512)
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
state_features: Joint state features (batch_size, num_frames, state_dim)
return_all_frames: If True, return rewards for all frames
return_stages: If True, also return stage predictions
Returns:
If return_stages=False:
Reward values (batch_size,) or (batch_size, num_frames)
If return_stages=True:
Tuple of (rewards, stage_probs)
"""
if isinstance(text_embeddings, np.ndarray):
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
if isinstance(video_embeddings, np.ndarray):
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
if state_features is not None and isinstance(state_features, np.ndarray):
state_features = torch.tensor(state_features, dtype=torch.float32)
# Handle single sample case
if text_embeddings.dim() == 1:
text_embeddings = text_embeddings.unsqueeze(0)
video_embeddings = video_embeddings.unsqueeze(0)
if state_features is not None:
state_features = state_features.unsqueeze(0)
single_sample = True
else:
single_sample = False
# Process in batches
all_rewards = []
all_stage_probs = []
for i in range(0, len(video_embeddings), self.config.batch_size):
batch_texts = text_embeddings[i:i + self.config.batch_size].to(self.device)
batch_videos = video_embeddings[i:i + self.config.batch_size].to(self.device)
batch_states = None
if state_features is not None:
batch_states = state_features[i:i + self.config.batch_size].to(self.device)
# Get predictions
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
batch_videos.float(), batch_texts.float(), batch_states.float() if batch_states is not None else None
)
if return_all_frames:
all_rewards.append(progress_preds.squeeze(-1).cpu())
else:
# Return only last frame reward
all_rewards.append(progress_preds[:, -1, 0].cpu())
if return_stages:
all_stage_probs.append(stage_probs.cpu())
rewards = torch.cat(all_rewards).numpy()
if single_sample:
rewards = rewards[0]  # unbatch single sample (same for scalar and all-frames output)
if return_stages:
stage_probs = torch.cat(all_stage_probs).numpy()
if single_sample:
stage_probs = stage_probs[0]
return rewards, stage_probs
return rewards
def train(self, mode: bool = True):
"""Overwrite train method to ensure CLIP encoder stays frozen during training"""
super().train(mode)
self.clip_model.eval()
self.sarm_transformer.train(mode)
return self
def eval(self):
"""Overwrite eval method to ensure CLIP encoder stays frozen during evaluation"""
return self.train(False)
def parameters(self):
"""Override to return trainable parameters (only SARM transformer, not CLIP encoder)."""
return self.sarm_transformer.parameters()
def get_optim_params(self):
"""Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
return self.parameters()
def reset(self):
"""Required by PreTrainedPolicy but not used for reward models."""
pass
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for reward models."""
raise NotImplementedError("SARM model does not predict action chunks")
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for SARM."""
raise NotImplementedError("SARM model does not select actions")
def _apply_temporal_augmentation(
self,
video: torch.Tensor,
progress: torch.Tensor,
state: torch.Tensor | None,
max_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Apply temporal augmentation by appending reversed frames (SARM paper A.4).
This helps the model learn to handle non-monotonic progress (failures, recoveries).
Appends 1-4 reversed frames to simulate going backwards in task progress.
"""
num_reverse = random.randint(1, min(4, max_length - 1))
# Reverse and take frames (skip first which is last of original)
reversed_video = video.flip(0)[1:num_reverse + 1]
reversed_progress = progress.flip(0)[1:num_reverse + 1]
# Concatenate and trim
video = torch.cat([video, reversed_video], dim=0)[:max_length]
progress = torch.cat([progress, reversed_progress], dim=0)[:max_length]
if state is not None:
reversed_state = state.flip(0)[1:num_reverse + 1]
state = torch.cat([state, reversed_state], dim=0)[:max_length]
return video, progress, state
def _ensure_sequence_length(self, tensor: torch.Tensor, target_len: int) -> torch.Tensor:
"""Pad or trim tensor to target length."""
current_len = tensor.shape[0]
if current_len == target_len:
return tensor
if current_len < target_len:
padding = target_len - current_len
return torch.cat([tensor, tensor[-1:].expand(padding, *tensor.shape[1:])])
return tensor[:target_len]
def forward(self, batch):
"""
Forward pass for SARM reward model training.
Uses annotation-based progress targets following SARM paper Eq. 2:
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask normalized time
- P_{k-1} is the cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args:
batch: Dictionary with 'observation' containing:
- 'video_features': (B, T, 512) pre-encoded video features
- 'text_features': (B, 512) pre-encoded text features (CLIP)
- 'state_features': (B, T, state_dim) joint state features
- 'stage_labels': (B, T) stage labels from annotations
- 'progress_targets': (B, T, 1) progress targets from annotations
Returns:
Tuple of (total_loss, output_dict with loss components)
"""
observation = batch.get('observation', batch)
# Extract required features
video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device)
state_features = observation.get('state_features')
if state_features is not None:
state_features = state_features.to(self.device)
batch_size = video_features.shape[0]
max_length = self.config.num_frames
# Ensure 3D video features (B, T, D)
if video_features.dim() == 2:
video_features = video_features.unsqueeze(1).expand(-1, max_length, -1)
if state_features is not None and state_features.dim() == 2:
state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)
# Get annotation-based progress targets (required for SARM paper formula)
progress_from_annotations = observation.get('progress_targets')
if progress_from_annotations is None:
raise ValueError("progress_targets from annotations is required for SARM training")
progress_from_annotations = progress_from_annotations.to(self.device)
if progress_from_annotations.dim() == 2:
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
# Process each sample: apply temporal REWIND augmentation
processed_videos = []
processed_states = []
progress_targets = []
for i in range(batch_size):
video = video_features[i]
state = state_features[i] if state_features is not None else None
progress = progress_from_annotations[i].squeeze(-1) # (T,)
# Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
if random.random() < 0.5:
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
# Ensure correct sequence length
video = self._ensure_sequence_length(video, max_length)
progress = self._ensure_sequence_length(progress.unsqueeze(-1), max_length).squeeze(-1)
if state is not None:
state = self._ensure_sequence_length(state, max_length)
processed_videos.append(video)
progress_targets.append(progress)
if state is not None:
processed_states.append(state)
# Stack into batches
processed_videos = torch.stack(processed_videos)
progress_targets = torch.stack(progress_targets).unsqueeze(-1) # (B, T, 1)
processed_states = torch.stack(processed_states) if processed_states else None
# Get model predictions
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
processed_videos, text_features, processed_states
)
# Compute progress loss (MSE)
progress_loss = F.mse_loss(progress_preds, progress_targets)
output_dict = {'progress_loss': progress_loss.item()}
total_loss = progress_loss
# Compute stage loss (cross-entropy)
stage_labels = observation.get('stage_labels')
if stage_labels is None:
raise ValueError("stage_labels from annotations is required for SARM training")
stage_labels = stage_labels.to(self.device)
if stage_labels.dim() == 1:
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
stage_loss = compute_stage_loss(stage_logits, stage_labels)
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
output_dict['stage_loss'] = stage_loss.item()
# Misaligned loss: 20% probability
if random.random() < 0.2:
shuffle_idx = torch.randperm(batch_size, device=self.device)
_, _, misaligned_preds = self.sarm_transformer(
processed_videos, text_features[shuffle_idx], processed_states
)
misaligned_loss = F.mse_loss(misaligned_preds, torch.zeros_like(misaligned_preds))
total_loss = total_loss + misaligned_loss
output_dict['misaligned_loss'] = misaligned_loss.item()
output_dict['total_loss'] = total_loss.item()
return total_loss, output_dict
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
_, _, num_stages = stage_logits.shape
stage_logits_flat = stage_logits.reshape(-1, num_stages)
target_stages_flat = target_stages.reshape(-1)
loss = F.cross_entropy(stage_logits_flat, target_stages_flat)
return loss
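A standalone worked example of Paper Formula (2), y_t = P_{k-1} + ᾱ_k × τ_t, with illustrative proportions (not taken from any real dataset):

```python
alpha = [0.2, 0.3, 0.5]          # ᾱ_k for three subtasks (sums to 1.0)
stage_idx, tau = 1, 0.5          # halfway through the second subtask
P_prev = sum(alpha[:stage_idx])  # P_{k-1} = 0.2
y = P_prev + alpha[stage_idx] * tau
print(y)  # 0.35: stage 1 spans cumulative progress [0.2, 0.5]
```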
@@ -1,644 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import numpy as np
import torch
from PIL import Image
import pandas as pd
from transformers import CLIPModel, CLIPProcessor
from lerobot.processor.core import TransitionKey
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import (
ProcessorStep,
PolicyProcessorPipeline,
PolicyAction,
DeviceProcessorStep,
AddBatchDimensionProcessorStep,
NormalizerProcessorStep,
)
from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
from_tensor_to_numpy,
)
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.configs.types import PolicyFeature, FeatureType
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class SARMEncodingProcessorStep(ProcessorStep):
"""ProcessorStep that encodes images and text with CLIP."""
def __init__(
self,
config: SARMConfig,
image_key: str | None = None,
dataset_meta = None,
dataset_stats: dict | None = None,
):
super().__init__()
self.config = config
self.image_key = image_key or config.image_key
self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats
self.temporal_proportions = {name: prop for name, prop in zip(self.config.subtask_names, self.config.temporal_proportions)}
self.subtask_names = self.config.subtask_names
self.device = torch.device(
self.config.device if self.config.device
else "cuda" if torch.cuda.is_available() else "cpu"
)
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
return ep_idx
return 0
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
"""Get episode indices for each frame index."""
if episode_index is None:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
# If single episode but multiple frames, compute episode for each frame
if len(episode_indices) == 1 and len(frame_indices) > 1:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
return episode_indices
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, ep_end: int, num_frames: int) -> torch.Tensor:
"""Compute absolute frame indices for symmetric bidirectional pattern.
Pattern: [ep_start, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Backward indices clamp to ep_start (first frame)
- Forward indices clamp to ep_end - 1 (last frame)
"""
indices = []
indices.append(ep_start) # Initial frame is always episode start
# Symmetric context: 4 before, current, 3 after
num_before = 4
num_after = 3
last_valid_frame = ep_end - 1
# Frames before current (clamp to first frame)
for i in range(num_before, 0, -1):
idx = max(ep_start, frame_idx - i * self.config.frame_gap)
indices.append(idx)
# Current frame
indices.append(frame_idx)
# Frames after current (clamp to last frame)
for i in range(1, num_after + 1):
idx = min(last_valid_frame, frame_idx + i * self.config.frame_gap)
indices.append(idx)
return torch.tensor(indices)
def _compute_episode_metadata(
self,
frame_indices: np.ndarray,
episode_indices: np.ndarray,
num_frames: int,
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute episode metadata for all samples.
Returns:
Tuple of (absolute_frame_indices, remaining_lengths, episode_lengths)
"""
absolute_indices_list = []
remaining_lengths = []
episode_lengths = []
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_lengths.append(ep_end - ep_start)
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, ep_end, num_frames)
absolute_indices_list.append(abs_indices)
remaining_lengths.append(ep_end - abs_indices[0].item())
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
def _compute_stage_and_progress_for_frame(
self,
current_frame: int,
subtask_names: list,
subtask_start_frames: list,
subtask_end_frames: list,
transition_smoothing_frames: int = 15,
) -> tuple[int, float, dict[int, float] | None]:
"""Compute stage index, cumulative progress, and soft stage labels for a single frame.
Implements SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Additionally computes soft stage labels near transitions to mitigate discrete jumps
in the stage classifier. Near stage boundaries, labels are blended between adjacent
stages to encourage smoother predictions.
Args:
current_frame: Frame index relative to episode start
subtask_names: List of subtask names for this episode
subtask_start_frames: List of subtask start frames
subtask_end_frames: List of subtask end frames
transition_smoothing_frames: Number of frames over which to smooth labels near transitions
Returns:
Tuple of (stage_idx, cumulative_progress, soft_stage_labels)
- stage_idx: Hard stage index (for compatibility)
- cumulative_progress: Progress value in [0, 1]
- soft_stage_labels: Dict mapping stage_idx -> probability, or None if not near transition
"""
# Get temporal proportions as list for compute_cumulative_progress
temporal_proportions_list = [
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
]
num_stages = len(self.subtask_names)
# Find which subtask this frame belongs to
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
if current_frame >= start_frame and current_frame <= end_frame:
# Found the subtask, get its global index
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
# Compute τ_t using utility function (Paper Formula 2)
tau = compute_tau(current_frame, start_frame, end_frame)
# Compute cumulative progress using utility function (Paper Formula 2)
cumulative_progress = compute_cumulative_progress_batch(
tau, stage_idx, temporal_proportions_list
)
# Compute soft stage labels near transitions
soft_stage_labels = None
frames_from_start = current_frame - start_frame
frames_to_end = end_frame - current_frame
if frames_from_start < transition_smoothing_frames and j > 0:
# Near start of stage - blend with previous stage
blend = frames_from_start / transition_smoothing_frames
prev_name = subtask_names[j - 1]
prev_stage_idx = self.subtask_names.index(prev_name) if prev_name in self.subtask_names else max(0, stage_idx - 1)
soft_stage_labels = {prev_stage_idx: 1.0 - blend, stage_idx: blend}
elif frames_to_end < transition_smoothing_frames and j < len(subtask_names) - 1:
# Near end of stage - blend with next stage
blend = frames_to_end / transition_smoothing_frames
next_name = subtask_names[j + 1]
next_stage_idx = self.subtask_names.index(next_name) if next_name in self.subtask_names else min(num_stages - 1, stage_idx + 1)
soft_stage_labels = {stage_idx: blend, next_stage_idx: 1.0 - blend}
return stage_idx, cumulative_progress, soft_stage_labels
# No matching subtask found
if current_frame < subtask_start_frames[0]:
return 0, 0.0, None
elif current_frame > subtask_end_frames[-1]:
return len(self.subtask_names) - 1, 1.0, None
else:
# Between subtasks - use previous subtask's end state (tau = 1.0)
for j in range(len(subtask_names) - 1):
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j]
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
# Completed subtask, so tau = 1.0
cumulative_progress = compute_cumulative_progress_batch(
1.0, stage_idx, temporal_proportions_list
)
return stage_idx, cumulative_progress, None
return 0, 0.0, None
def _compute_labels_for_sample(
self,
frame_idx: int,
ep_idx: int,
seq_len: int,
episodes_df: pd.DataFrame,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | tuple[None, None, None]:
"""Compute stage labels, progress targets, and soft stage labels for symmetric bidirectional pattern.
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Before episode start: clamp to frame 0 (progress ~0%)
- After episode end: clamp to last frame (progress ~100%)
Soft stage labels are computed near stage transitions to mitigate discrete jumps.
Args:
frame_idx: The frame index for this sample
ep_idx: The episode index
seq_len: Number of frames in the sequence
episodes_df: DataFrame with episode metadata
Returns:
Tuple of (stage_labels, progress_targets, soft_stage_labels):
- stage_labels: (T,) hard stage indices
- progress_targets: (T, 1) progress values
- soft_stage_labels: (T, num_stages) soft probability labels, or None if no transitions nearby
"""
# Check if episode has valid annotations
if ep_idx >= len(episodes_df):
return None, None, None
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
return None, None, None
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
ep_length = ep_end - ep_start
last_valid_frame = ep_length - 1
num_stages = len(self.subtask_names)
# Generate labels for each frame in the sequence
stage_labels = []
progress_targets = []
soft_labels_list = [] # List of soft label dicts (or None)
has_any_soft_labels = False
# Symmetric pattern: initial + 4 before + current + 3 after = 9 frames
num_before = 4
num_after = 3
for i in range(seq_len):
if i == 0:
# Position 0: Initial frame of the episode
current_frame = 0 # Relative to episode start
elif i <= num_before:
# Positions 1-4: frames before current (with clamping to first frame)
offset = -(num_before - i + 1) * self.config.frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
elif i == num_before + 1:
# Position 5: current frame
current_frame = frame_idx - ep_start
else:
# Positions 6-8: frames after current (with clamping to last frame)
offset = (i - num_before - 1) * self.config.frame_gap
current_frame = min(last_valid_frame, frame_idx + offset - ep_start)
stage_idx, cumulative_progress, soft_stage_labels = self._compute_stage_and_progress_for_frame(
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
)
stage_labels.append(stage_idx)
progress_targets.append(cumulative_progress)
soft_labels_list.append(soft_stage_labels)
if soft_stage_labels is not None:
has_any_soft_labels = True
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
# Convert soft labels to tensor if any exist
soft_stage_labels_tensor = None
if has_any_soft_labels:
soft_stage_labels_tensor = torch.zeros(seq_len, num_stages, dtype=torch.float32)
for i, soft_dict in enumerate(soft_labels_list):
if soft_dict is not None:
for stage_idx, prob in soft_dict.items():
soft_stage_labels_tensor[i, stage_idx] = prob
else:
# Use hard one-hot label
soft_stage_labels_tensor[i, stage_labels[i]] = 1.0
return stage_labels, progress_targets, soft_stage_labels_tensor
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
"""Generate stage labels, progress targets, and soft stage labels from subtask annotations.
Args:
frame_index: Current frame index or tensor of indices
episode_index: Episode index or tensor of indices
video_features: Video features tensor to determine sequence length
Returns:
Tuple of (stage_labels, progress_targets, soft_stage_labels) or (None, None, None) if no annotations.
- stage_labels: (B, T) hard stage indices
- progress_targets: (B, T, 1) progress values
- soft_stage_labels: (B, T, num_stages) soft probability labels, or None
"""
if self.temporal_proportions is None or episode_index is None:
return None, None, None
# Normalize inputs to numpy arrays
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine sequence length
if video_features is not None and video_features.dim() >= 2:
seq_len = video_features.shape[1]
else:
seq_len = 1
episodes_df = self.dataset_meta.episodes.to_pandas()
num_stages = len(self.subtask_names)
all_stage_labels = []
all_progress_targets = []
all_soft_stage_labels = []
has_any_soft_labels = False
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
stage_labels, progress_targets, soft_labels = self._compute_labels_for_sample(
int(frame_idx), int(ep_idx), seq_len, episodes_df
)
if stage_labels is None:
all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long))
all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32))
all_soft_stage_labels.append(None)
else:
all_stage_labels.append(stage_labels)
all_progress_targets.append(progress_targets)
all_soft_stage_labels.append(soft_labels)
if soft_labels is not None:
has_any_soft_labels = True
stacked_stage_labels = torch.stack(all_stage_labels, dim=0)
stacked_progress_targets = torch.stack(all_progress_targets, dim=0)
# Stack soft labels if any exist
stacked_soft_labels = None
if has_any_soft_labels:
soft_labels_tensors = []
for i, soft_labels in enumerate(all_soft_stage_labels):
if soft_labels is not None:
soft_labels_tensors.append(soft_labels)
else:
# Create one-hot from hard labels
one_hot = torch.zeros(seq_len, num_stages, dtype=torch.float32)
for t in range(seq_len):
one_hot[t, all_stage_labels[i][t]] = 1.0
soft_labels_tensors.append(one_hot)
stacked_soft_labels = torch.stack(soft_labels_tensors, dim=0)
return stacked_stage_labels, stacked_progress_targets, stacked_soft_labels
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images, text, and normalize states in the transition."""
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
observation = new_transition.get(TransitionKey.OBSERVATION)
image = observation.get(self.image_key)
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
video_features = self._encode_images_batch(image)
observation['video_features'] = video_features
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
state_key = self.config.state_key
state_data = observation.get(state_key)
if isinstance(state_data, torch.Tensor):
state_tensor = state_data.float()
else:
state_tensor = torch.tensor(state_data, dtype=torch.float32)
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Get task description from dataset (complementary_data["task"])
task = comp_data.get('task')
if isinstance(task, list):
# If batch, take first task (assuming same task for all items in batch)
task = task[0] if task else ""
# Encode text with CLIP
batch_size = video_features.shape[0]
observation['text_features'] = self._encode_text_clip(task, batch_size)
frame_index = comp_data.get('index')
episode_index = comp_data.get('episode_index')
if frame_index is None:
raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
if episode_index is None:
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
# Compute episode metadata if dataset_meta is available
if self.dataset_meta is not None:
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine number of frames from video features
if video_features.dim() >= 2:
num_frames = video_features.shape[1]
else:
num_frames = 1
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
frame_indices, episode_indices, num_frames
)
observation['absolute_frame_indices'] = abs_indices
observation['remaining_length'] = remaining
observation['episode_length'] = ep_lengths
# Generate stage labels, progress targets, and soft stage labels from subtask annotations
if self.temporal_proportions is not None and self.dataset_meta is not None:
stage_labels, progress_targets, soft_stage_labels = self._generate_stage_and_progress_labels(
frame_index, episode_index, video_features
)
if stage_labels is not None:
observation['stage_labels'] = stage_labels
observation['progress_targets'] = progress_targets
if soft_stage_labels is not None:
observation['soft_stage_labels'] = soft_stage_labels
new_transition[TransitionKey.OBSERVATION] = observation
return new_transition
@torch.no_grad()
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
"""Encode a batch of images using CLIP.
Args:
images: Batched images with shape: (B, T, C, H, W)
Returns:
Encoded feature vectors with shape (B, T, 512)
"""
batch_size, seq_length = images.shape[0], images.shape[1]
images = images.reshape(batch_size * seq_length, *images.shape[2:])
# Convert to list of PIL images
num_frames = images.shape[0]
images_list = []
for i in range(num_frames):
img = images[i]
if img.shape[0] in [1, 3]: # Channel first (C, H, W)
img = img.transpose(1, 2, 0)
# Handle single channel
if img.shape[-1] == 1:
img = np.repeat(img, 3, axis=-1)
# Convert to uint8
if img.dtype != np.uint8:
img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
images_list.append(Image.fromarray(img))
# Encode each batch
all_embeddings = []
for i in range(0, num_frames, self.config.clip_batch_size):
batch_imgs = images_list[i:i + self.config.clip_batch_size]
# Process with CLIP
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
all_embeddings.append(embeddings)
# Concatenate all embeddings
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
# Reshape back
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
return all_embeddings
@torch.no_grad()
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
"""Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Task description text to encode
batch_size: Batch size to replicate for
Returns:
Encoded text features with shape (B, 512)
"""
# Use CLIP's tokenizer directly for text
tokenizer = self.clip_processor.tokenizer
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get text features from CLIP
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
# Replicate for batch (B, 512)
text_embedding = text_embedding.expand(batch_size, -1)
return text_embedding
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Add encoded features to the observation features."""
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(self.config.num_frames, self.config.image_dim)
)
features[PipelineFeatureType.OBSERVATION]['text_features'] = PolicyFeature(
type=FeatureType.LANGUAGE,
shape=(self.config.text_dim,)
)
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.config.num_frames, self.config.max_state_dim)
)
return features
def make_sarm_pre_post_processors(
config: SARMConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_meta = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Create pre-processor and post-processor pipelines for SARM.
The pre-processing pipeline:
1. Adds batch dimension
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
3. SARMEncodingProcessorStep:
- Encodes images with CLIP
- Pads states to max_state_dim
- Encodes text with CLIP
4. Moves data to device
The post-processing pipeline:
1. Moves data to CPU
"""
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
dataset_stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=[DeviceProcessorStep(device="cpu")],
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
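A minimal usage sketch for the factory above, assuming `config` is a `SARMConfig` and `dataset` is a `LeRobotDataset` with annotation metadata; invoking the pipeline directly on a raw dataloader batch is an assumption about how `PolicyProcessorPipeline` is called, not verified here:

```python
preprocessor, postprocessor = make_sarm_pre_post_processors(
    config=config,
    dataset_stats=dataset.meta.stats,
    dataset_meta=dataset.meta,
)
# Normalizes state, CLIP-encodes images and text, moves data to device.
processed = preprocessor(raw_batch)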
```
@@ -1,257 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn.functional as F
from typing import Sequence, Any
from pydantic import BaseModel, Field
# Pydantic Models for SARM-style Annotation
class Timestamp(BaseModel):
"""Timestamp in MM:SS or SS format"""
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
end: str = Field(description="End timestamp (MM:SS or just seconds)")
class Subtask(BaseModel):
"""Individual subtask/stage - must use EXACT names from provided list"""
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
timestamps: Timestamp
class SubtaskAnnotation(BaseModel):
"""Complete annotation for a robot manipulation episode"""
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
"""
Compute dataset-level temporal proportions (priors) for each subtask.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
where:
- M is the number of trajectories (episodes)
- L_{i,k} is the duration of subtask k in trajectory i
- T_i is the total duration of trajectory i
This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of their absolute length.
Args:
annotations: Dict mapping episode index to SubtaskAnnotation object.
Each annotation has a .subtasks list where each subtask has:
- .name: subtask name
- .timestamps.start: start time as "MM:SS" string
- .timestamps.end: end time as "MM:SS" string
fps: Frames per second (unused, kept for API compatibility)
Returns:
Dict mapping subtask name to its temporal proportion (ᾱ_k).
Proportions are normalized to sum to 1.0.
"""
subtask_proportions: dict[str, list[float]] = {}
for annotation in annotations.values():
total_duration = 0
durations: dict[str, int] = {}
for subtask in annotation.subtasks:
start_parts = subtask.timestamps.start.split(":")
end_parts = subtask.timestamps.end.split(":")
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
duration = end_seconds - start_seconds
durations[subtask.name] = duration
total_duration += duration
# Calculate L_{i,k} / T_i for each subtask in this trajectory
if total_duration > 0:
for name, duration in durations.items():
if name not in subtask_proportions:
subtask_proportions[name] = []
subtask_proportions[name].append(duration / total_duration)
if not subtask_proportions:
return {}
# Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
avg_proportions = {
name: sum(props) / len(props)
for name, props in subtask_proportions.items()
}
# Normalize to ensure sum = 1
total = sum(avg_proportions.values())
if total > 0:
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
return avg_proportions
def compute_tau(
current_frame: int | float,
subtask_start: int | float,
subtask_end: int | float,
) -> float:
"""
Compute within-subtask normalized time τ_t.
Implements part of SARM Paper Formula (2):
τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
where:
- t is the current frame
- s_k is the start frame of subtask k
- e_k is the end frame of subtask k
Args:
current_frame: Current frame index (t)
subtask_start: Start frame of the subtask (s_k)
subtask_end: End frame of the subtask (e_k)
Returns:
Within-subtask progress τ_t ∈ [0, 1]
"""
subtask_duration = subtask_end - subtask_start
if subtask_duration <= 0:
return 1.0
tau = (current_frame - subtask_start) / subtask_duration
return float(np.clip(tau, 0.0, 1.0))
def compute_cumulative_progress_batch(
tau: torch.Tensor | float,
stage_indices: torch.Tensor | int,
alpha: torch.Tensor | Sequence[float],
cumulative_prior: torch.Tensor | None = None,
) -> torch.Tensor | float:
"""
Compute cumulative normalized progress from within-subtask progress.
This function implements the core formula used in SARM for both:
**Formula 2 (Training labels):**
y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
Used to compute ground-truth progress labels from subtask annotations.
- τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k)
- k is the known subtask from annotations
**Formula 4 (Inference predictions):**
ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} ∈ [0, 1]
Used to convert model outputs to cumulative progress during inference.
- τ̂ comes from the subtask MLP head (conditioned on predicted stage)
- k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ))
The formulas are mathematically identical; only the source of inputs differs:
- Training: τ and k come from annotated ground-truth labels
- Inference: τ̂ and Ŝ come from the model's predictions
where:
- P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions)
- ᾱ_k is the temporal proportion for subtask k (from Formula 1)
- τ is within-subtask progress ∈ [0, 1]
This ensures:
- y at start of subtask k = P_{k-1}
- y at end of subtask k = P_k
Supports both scalar and batched tensor inputs:
- Scalar: tau (float), stage_indices (int), alpha (list/sequence)
- Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor)
Args:
tau: Within-subtask progress τ ∈ [0, 1].
For training: computed from frame position in annotated subtask.
For inference: predicted by subtask MLP head.
Scalar float or Tensor with shape (..., 1)
stage_indices: Index of current subtask k (0-indexed).
For training: known from annotations.
For inference: predicted via argmax(stage_probs) (Formula 3).
Scalar int or Tensor with shape (...)
alpha: Temporal proportions with shape (num_stages,) or Sequence[float].
Computed from dataset annotations using Formula 1.
cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,)
where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j.
If None, will be computed from alpha.
Returns:
Cumulative progress y ∈ [0, 1].
Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1)
"""
if not isinstance(tau, torch.Tensor):
# `if not alpha` is ambiguous for multi-element tensors; check length explicitly
if alpha is None or len(alpha) == 0:
raise ValueError("alpha (temporal_proportions) cannot be empty")
if isinstance(alpha, torch.Tensor):
alpha_list = alpha.tolist()
else:
alpha_list = list(alpha)
if stage_indices < 0 or stage_indices >= len(alpha_list):
raise ValueError(
f"stage_indices {stage_indices} out of range "
f"for {len(alpha_list)} subtasks"
)
# P_{k-1} = sum of proportions for subtasks 0 to k-1
P_k_minus_1 = sum(alpha_list[:stage_indices])
# ᾱ_k = proportion for current subtask
alpha_k = alpha_list[stage_indices]
# y_t = P_{k-1} + ᾱ_k × τ_t
y_t = P_k_minus_1 + alpha_k * tau
return float(np.clip(y_t, 0.0, 1.0))
if not isinstance(alpha, torch.Tensor):
alpha = torch.tensor(alpha, dtype=torch.float32)
# Compute cumulative_prior if not provided
if cumulative_prior is None:
cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device)
cumulative_prior[1:] = torch.cumsum(alpha, dim=0)
# P_{k-1} for each predicted stage
P_k_minus_1 = cumulative_prior[stage_indices]
# ᾱ_k for each predicted stage
alpha_k = alpha[stage_indices]
# ŷ = P_{k-1} + ᾱ_k × τ̂
progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau
return progress
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
"""Pad the state tensor's last dimension to max_state_dim with zeros."""
current_dim = state.shape[-1]
if current_dim >= max_state_dim:
return state[..., :max_state_dim] # Truncate if larger
# Pad with zeros on the right
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
return F.pad(state, padding, mode='constant', value=0)
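A quick illustration of the pad/truncate behavior above (shapes illustrative):
import torch

state = torch.randn(4, 6)  # (batch, state_dim)
assert pad_state_to_max_dim(state, max_state_dim=8).shape == (4, 8)  # right-padded with zeros
assert pad_state_to_max_dim(state, max_state_dim=4).shape == (4, 4)  # truncated when larger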
+1 -11
@@ -230,10 +230,6 @@ def validate_visual_features_consistency(
) -> None:
"""
Validates visual feature consistency between a policy config and provided dataset/environment features.
Validation passes if EITHER:
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
Args:
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
@@ -241,11 +237,5 @@ def validate_visual_features_consistency(
"""
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
# Accept if either direction is a subset
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
if not (policy_subset_of_dataset or dataset_subset_of_policy):
if not provided_visuals.issubset(expected_visuals):
raise_feature_mismatch_error(provided_visuals, expected_visuals)
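After this change, validation accepts any dataset whose visual keys are a subset of the policy's declared visuals, and fails only when the dataset provides a key the policy does not declare. A tiny sketch (key names illustrative):
expected_visuals = {"observation.images.top", "observation.images.wrist"}  # declared by the policy

# Dataset providing a subset of the declared cameras passes:
assert {"observation.images.top"}.issubset(expected_visuals)

# A key the policy does not declare fails the subset check and raises:
assert not {"observation.images.side"}.issubset(expected_visuals)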
+1 -2
@@ -170,9 +170,8 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
return {**pad_keys, **task_key, **index_key, **task_index_key}
def create_transition(
-154
@@ -1,154 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"]  # (B, 3)
eef_quat = robot_state["eef"]["quat"]  # (B, 4)
gripper_qpos = robot_state["gripper"]["qpos"]  # (B, 2)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into a single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def observation(self, observation):
return self._process_observation(observation)
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
Only accepts torch tensors of shape (B, 4).
Args:
quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format
Returns:
Tensor: (B, 3) axis-angle vectors
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 4)
"""
if not isinstance(quat, torch.Tensor):
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
if quat.ndim != 2 or quat.shape[1] != 4:
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
quat = quat.to(dtype=torch.float32)
device = quat.device
batch_size = quat.shape[0]
w = quat[:, 3].clamp(-1.0, 1.0)
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
result = torch.zeros((batch_size, 3), device=device)
mask = den > 1e-10
if mask.any():
angle = 2.0 * torch.acos(w[mask]) # (M,)
axis = quat[mask, :3] / den[mask].unsqueeze(1)
result[mask] = axis * angle.unsqueeze(1)
return result
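A small sanity check for the conversion above (a sketch; quaternions follow the (x, y, z, w) convention stated in the docstring):
import math
import torch

step = LiberoProcessorStep()

# Identity rotation -> zero axis-angle vector
identity = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
assert torch.allclose(step._quat2axisangle(identity), torch.zeros(1, 3))

# 90° rotation about z: (0, 0, sin(π/4), cos(π/4)) -> axis z, angle π/2
quat_z90 = torch.tensor([[0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)]])
expected = torch.tensor([[0.0, 0.0, math.pi / 2]])
assert torch.allclose(step._quat2axisangle(quat_z90), expected, atol=1e-6)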
+1 -33
@@ -71,7 +71,7 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.factory import make_env
from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
@@ -94,8 +94,6 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
seeds: list[int] | None = None,
@@ -167,19 +165,11 @@ def rollout(
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
@@ -249,8 +239,6 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -331,8 +319,6 @@ def eval_policy(
rollout_data = rollout(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
seeds=list(seeds) if seeds else None,
@@ -531,16 +517,10 @@ def eval_main(cfg: EvalPipelineConfig):
pretrained_path=cfg.policy.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -581,8 +561,6 @@ def eval_one(
env: gym.vector.VectorEnv,
*,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -598,8 +576,6 @@ def eval_one(
task_result = eval_policy(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -624,8 +600,6 @@ def run_one(
env,
*,
policy,
env_preprocessor,
env_postprocessor,
preprocessor,
postprocessor,
n_episodes: int,
@@ -648,8 +622,6 @@ def run_one(
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -667,8 +639,6 @@ def run_one(
def eval_policy_all(
envs: dict[str, dict[int, gym.vector.VectorEnv]],
policy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -724,8 +694,6 @@ def eval_policy_all(
task_runner = partial(
run_one,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
+4 -69
@@ -29,7 +29,7 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.factory import make_env
from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
@@ -61,7 +61,6 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
rabc_weight_computer=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -86,22 +85,10 @@ def update_policy(
"""
start_time = time.perf_counter()
policy.train()
# Compute RA-BC weights if enabled
rabc_weights = None
if rabc_weight_computer is not None:
rabc_weights = rabc_weight_computer.compute_batch_weights(batch)
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
# Apply RA-BC weights if enabled
if rabc_weights is not None:
# Weight the loss
loss = loss * rabc_weights.mean()
output_dict['rabc_mean_weight'] = rabc_weights.mean().item()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
@@ -153,6 +140,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
cfg.validate()
# Create Accelerator if not provided
# It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
@@ -169,8 +158,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# When using accelerate, only the main process should log to avoid duplicate outputs
is_main_process = accelerator.is_main_process
cfg.validate()
# Only log on main process
if is_main_process:
logging.info(pformat(cfg.to_dict()))
@@ -228,10 +215,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For SARM, always provide dataset_meta for progress normalization
if cfg.policy.type == "sarm":
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
@@ -263,28 +246,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# Load reward model for RA-BC if enabled
rabc_weight_computer = None
if cfg.use_rabc:
logging.info(f"Loading reward model for RA-BC from {cfg.reward_model_path}")
from lerobot.policies.factory import get_policy_class
from lerobot.utils.rabc import RABCWeightComputer
# Detect reward model type from path
# For now, assume SARM if not specified
reward_model_class = get_policy_class("sarm")
reward_model = reward_model_class.from_pretrained(cfg.reward_model_path)
reward_model.to(device)
reward_model.eval()
rabc_weight_computer = RABCWeightComputer(
reward_model=reward_model,
kappa=cfg.rabc_kappa,
epsilon=cfg.rabc_epsilon,
device=device,
)
logging.info("RA-BC weight computer initialized")
step = 0 # number of policy updates (forward + backward + optim)
@@ -298,8 +259,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors")
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
@@ -315,22 +274,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
elif cfg.policy.type == "sarm" and getattr(cfg.policy, "use_temporal_sampler", False):
# Use SARM temporal sampler for reward model training
from lerobot.datasets.temporal_sampler import SARMTemporalSampler
shuffle = False
sampler = SARMTemporalSampler(
dataset_from_index=dataset.meta.episodes["dataset_from_index"],
dataset_to_index=dataset.meta.episodes["dataset_to_index"],
frame_gap=getattr(cfg.policy, "frame_gap", 30),
shuffle=True,
seed=cfg.seed,
)
else:
shuffle = True
sampler = None
@@ -375,7 +321,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
if is_main_process:
logging.info(f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}")
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
@@ -391,7 +337,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weight_computer=rabc_weight_computer,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -408,14 +353,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weight_computer is not None:
rabc_stats = rabc_weight_computer.get_stats()
wandb_log_dict.update({
'rabc_progress_mean': rabc_stats['mean'],
'rabc_progress_std': rabc_stats['std'],
'rabc_samples_seen': rabc_stats['count'],
})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
@@ -447,8 +384,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=accelerator.unwrap_model(policy),
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
-12
@@ -70,15 +70,3 @@ LOOKAHEAD_BACKTRACKTABLE = 100
# openpi
OPENPI_ATTENTION_MASK_VALUE = -2.3819763e38 # TODO(pepijn): Modify this when extending support to fp8 models
# Constants for LIBERO observation keys
LIBERO_KEY_EEF_POS = "robot_state/eef/pos"
LIBERO_KEY_EEF_QUAT = "robot_state/eef/quat"
LIBERO_KEY_EEF_MAT = "robot_state/eef/mat"
LIBERO_KEY_EEF_AXISANGLE = "robot_state/eef/axisangle"
LIBERO_KEY_GRIPPER_QPOS = "robot_state/gripper/qpos"
LIBERO_KEY_GRIPPER_QVEL = "robot_state/gripper/qvel"
LIBERO_KEY_JOINTS_POS = "robot_state/joints/pos"
LIBERO_KEY_JOINTS_VEL = "robot_state/joints/vel"
LIBERO_KEY_PIXELS_AGENTVIEW = "pixels/agentview_image"
LIBERO_KEY_PIXELS_EYE_IN_HAND = "pixels/robot0_eye_in_hand_image"
+206
@@ -0,0 +1,206 @@
"""
Profiling utilities for performance analysis.
Usage:
from lerobot.utils.profiling import profile_method, get_profiling_stats, print_profiling_summary
@profile_method
def my_slow_function(x):
return x * 2
# At end of execution:
print_profiling_summary()
"""
import functools
import logging
import time
from collections import defaultdict
from threading import Lock
from typing import Any, Callable
logger = logging.getLogger(__name__)
# Global profiling statistics storage
_profiling_stats: dict[str, list[float]] = defaultdict(list)
_profiling_lock = Lock()
_profiling_enabled = False
def enable_profiling():
"""Enable profiling globally."""
global _profiling_enabled
_profiling_enabled = True
logger.info("Profiling enabled")
def disable_profiling():
"""Disable profiling globally."""
global _profiling_enabled
_profiling_enabled = False
logger.info("Profiling disabled")
def is_profiling_enabled() -> bool:
"""Check if profiling is enabled."""
return _profiling_enabled
def record_timing(name: str, duration: float):
"""Record a timing measurement.
Args:
name: Name/identifier for this timing
duration: Duration in seconds
"""
if not _profiling_enabled:
return
with _profiling_lock:
_profiling_stats[name].append(duration)
def profile_method(func: Callable) -> Callable:
"""Decorator to profile a method or function.
Args:
func: Function to profile
Returns:
Wrapped function that records execution time
"""
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
if not _profiling_enabled:
return func(*args, **kwargs)
start = time.perf_counter()
try:
result = func(*args, **kwargs)
return result
finally:
duration = time.perf_counter() - start
# Use fully qualified name
name = f"{func.__module__}.{func.__qualname__}"
record_timing(name, duration)
return wrapper
class ProfileContext:
"""Context manager for profiling code blocks.
Usage:
with ProfileContext("my_operation"):
# ... code to profile ...
"""
def __init__(self, name: str):
self.name = name
self.start = None
def __enter__(self):
if _profiling_enabled:
self.start = time.perf_counter()
return self
def __exit__(self, *args):
if _profiling_enabled and self.start is not None:
duration = time.perf_counter() - self.start
record_timing(self.name, duration)
def get_profiling_stats() -> dict[str, dict[str, float]]:
"""Get summary statistics for all profiled functions.
Returns:
Dictionary mapping function names to their stats (count, mean, min, max, total)
"""
with _profiling_lock:
summary = {}
for name, times in _profiling_stats.items():
if times:
summary[name] = {
"count": len(times),
"mean": sum(times) / len(times),
"min": min(times),
"max": max(times),
"total": sum(times),
"mean_ms": (sum(times) / len(times)) * 1000,
"min_ms": min(times) * 1000,
"max_ms": max(times) * 1000,
}
return summary
def clear_profiling_stats():
"""Clear all profiling statistics."""
with _profiling_lock:
_profiling_stats.clear()
logger.info("Profiling stats cleared")
def print_profiling_summary(sort_by: str = "total"):
"""Print formatted summary of profiling statistics.
Args:
sort_by: Sort key ('total', 'mean', 'count', 'max')
"""
summary = get_profiling_stats()
if not summary:
logger.info("No profiling data available")
return
logger.info("\n" + "=" * 100)
logger.info("PROFILING SUMMARY")
logger.info("=" * 100)
# Sort by requested key
sorted_items = sorted(summary.items(), key=lambda x: x[1].get(sort_by, 0), reverse=True)
# Print header
logger.info(
f"{'Function':<60} {'Count':>8} {'Mean (ms)':>12} {'Min (ms)':>12} {'Max (ms)':>12} {'Total (s)':>12}"
)
logger.info("-" * 100)
# Print each function's stats
for name, stats in sorted_items:
# Shorten long names
display_name = name if len(name) <= 60 else "..." + name[-57:]
logger.info(
f"{display_name:<60} "
f"{stats['count']:>8} "
f"{stats['mean_ms']:>12.2f} "
f"{stats['min_ms']:>12.2f} "
f"{stats['max_ms']:>12.2f} "
f"{stats['total']:>12.2f}"
)
logger.info("=" * 100)
# Print summary
total_time = sum(s["total"] for s in summary.values())
total_calls = sum(s["count"] for s in summary.values())
logger.info(f"\nTotal profiled time: {total_time:.2f}s across {total_calls} calls")
logger.info("=" * 100 + "\n")
def profile_section(name: str):
"""Return a context manager for profiling a code section.
Args:
name: Name for this section
Returns:
ProfileContext instance
Usage:
with profile_section("data_loading"):
data = load_data()
"""
return ProfileContext(name)
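A minimal end-to-end sketch of the module above (the profiled function is illustrative; the summary is emitted through the module logger, so logging must be configured to see it):
import time

from lerobot.utils.profiling import (
    enable_profiling,
    profile_method,
    profile_section,
    print_profiling_summary,
)

@profile_method
def slow_step():
    time.sleep(0.01)

enable_profiling()  # decorators and contexts are no-ops until profiling is enabled
for _ in range(3):
    slow_step()

with profile_section("data_loading"):
    time.sleep(0.005)

print_profiling_summary(sort_by="mean")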
-183
@@ -1,183 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reward-Aligned Behavior Cloning (RA-BC) utilities.
RA-BC uses a pre-trained reward model (e.g., SARM) to compute progress-based weights
for training samples, emphasizing high-quality demonstrations and down-weighting
suboptimal ones.
"""
import logging
import torch
import torch.nn as nn
class RABCWeightComputer:
"""
Computes RA-BC weights for training batches using a pre-trained reward model.
Uses Welford's online algorithm for numerically stable running statistics
and applies soft weighting based on progress deltas.
Args:
reward_model: Pre-trained reward model (e.g., SARM)
kappa: Hard threshold for high-quality samples (default: 0.01)
epsilon: Small constant for numerical stability (default: 1e-6)
device: Device to run reward model on
"""
def __init__(
self,
reward_model: nn.Module,
kappa: float = 0.01,
epsilon: float = 1e-6,
device: torch.device | None = None,
):
self.reward_model = reward_model
self.reward_model.eval() # Always in eval mode
self.kappa = kappa
self.epsilon = epsilon
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Running statistics (Welford's algorithm)
self.mean = 0.0
self.m2 = 0.0
self.count = 0
logging.info(f"RA-BC WeightComputer initialized with kappa={kappa}, epsilon={epsilon}")
def _update_stats(self, deltas: torch.Tensor):
"""Update running statistics using Welford's online algorithm."""
for delta in deltas:
self.count += 1
delta_val = delta.item()
delta_mean = delta_val - self.mean
self.mean += delta_mean / self.count
delta_m2 = delta_val - self.mean
self.m2 += delta_mean * delta_m2
def _compute_weights(self, deltas: torch.Tensor) -> torch.Tensor:
"""Compute RA-BC weights from progress deltas."""
if self.count < 2:
# Not enough data, use uniform weights
return torch.ones_like(deltas)
# Get running statistics
mean = max(self.mean, 0.0) # Clamp mean to non-negative
variance = self.m2 / (self.count - 1)
std = torch.tensor(variance).sqrt().item()
# Compute soft weights
lower_bound = mean - 2 * std
upper_bound = mean + 2 * std
weights = (deltas - lower_bound) / (4 * std + self.epsilon)
weights = torch.clamp(weights, 0.0, 1.0)
# Apply hard threshold
high_quality_mask = deltas > self.kappa
weights = torch.where(high_quality_mask, torch.ones_like(weights), weights)
return weights
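To make the soft-weighting rule in _compute_weights concrete, a small numeric sketch (values illustrative): deltas map linearly from the mean − 2σ bound to the mean + 2σ bound, then anything above kappa is forced to full weight.
import torch

mean, std, epsilon, kappa = 0.0, 1.0, 1e-6, 0.01
deltas = torch.tensor([-3.0, -2.0, 0.0, 2.0, 3.0])
weights = ((deltas - (mean - 2 * std)) / (4 * std + epsilon)).clamp(0.0, 1.0)
# -> [0.00, 0.00, 0.50, 1.00, 1.00] before the hard threshold
weights = torch.where(deltas > kappa, torch.ones_like(weights), weights)
# deltas 2.0 and 3.0 exceed kappa and keep weight 1.0; 0.0 stays at 0.5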
@torch.no_grad()
def compute_batch_weights(self, batch: dict, chunk_size: int = 1) -> torch.Tensor:
"""
Compute RA-BC weights for a training batch.
This function:
1. Extracts current and next observations from the batch
2. Computes rewards using the reward model
3. Calculates progress deltas
4. Updates running statistics
5. Returns normalized weights
Args:
batch: Training batch containing observations
chunk_size: Size of action chunks for computing deltas (default: 1)
Returns:
Weights tensor (batch_size,) normalized to sum to batch_size
"""
observation = batch.get('observation', batch)
batch_size = next(iter(observation.values())).shape[0]
# Extract features needed for reward computation
# These should already be encoded by the preprocessor
if 'video_features' not in observation or 'text_features' not in observation:
logging.warning("RA-BC: Missing video/text features, using uniform weights")
return torch.ones(batch_size, device=self.device)
video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device)
state_features = observation.get('state_features', None)
if state_features is not None:
state_features = state_features.to(self.device)
# Compute rewards for current observations
# Handle both single-frame and multi-frame features
if video_features.dim() == 3: # (B, T, D)
# Multi-frame: use last frame reward
if hasattr(self.reward_model, 'calculate_rewards'):
current_rewards = self.reward_model.calculate_rewards(
text_features, video_features, state_features,
return_all_frames=False
)
else:
# Fallback for models without calculate_rewards
current_rewards = torch.zeros(batch_size, device=self.device)
else: # (B, D)
# Single frame
if hasattr(self.reward_model, 'calculate_rewards'):
current_rewards = self.reward_model.calculate_rewards(
text_features, video_features.unsqueeze(1), state_features,
return_all_frames=False
)
else:
current_rewards = torch.zeros(batch_size, device=self.device)
if isinstance(current_rewards, tuple):
current_rewards = current_rewards[0]
if isinstance(current_rewards, (list, tuple)):
    current_rewards = torch.tensor(current_rewards, device=self.device)
# For simplicity, assume progress delta is proportional to reward
# In practice, you'd want to compute next_frame rewards and take differences
# For now, use current reward as a proxy for progress delta
progress_deltas = current_rewards
# Update running statistics
self._update_stats(progress_deltas)
# Compute weights
weights = self._compute_weights(progress_deltas)
# Normalize weights to sum to batch_size (maintains effective batch size)
weight_sum = weights.sum() + self.epsilon
weights = weights * batch_size / weight_sum
return weights
def get_stats(self) -> dict:
"""Get current running statistics."""
std = torch.tensor(self.m2 / (self.count - 1)).sqrt().item() if self.count > 1 else 0.0
return {
'mean': self.mean,
'std': std,
'count': self.count,
}
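A quick check that the Welford update above matches a direct computation (a sketch; _DummyReward is a hypothetical stand-in, since only eval() is needed here):
import torch

class _DummyReward(torch.nn.Module):  # hypothetical stand-in for a real reward model
    pass

computer = RABCWeightComputer(reward_model=_DummyReward(), device=torch.device("cpu"))
deltas = torch.tensor([0.1, 0.3, 0.2, 0.4])
computer._update_stats(deltas)
assert abs(computer.mean - deltas.mean().item()) < 1e-6
# m2 / (count - 1) is the unbiased sample variance
assert abs(computer.m2 / (computer.count - 1) - deltas.var(unbiased=True).item()) < 1e-5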
+2 -16
@@ -23,15 +23,13 @@ from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedu
from lerobot.policies.factory import make_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda, require_package # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_package("transformers")
@require_cuda
def test_smolvla_rtc_initialization():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy can initialize RTC processor."""
set_seed(42)
@@ -65,11 +63,8 @@ def test_smolvla_rtc_initialization():
print("✓ SmolVLA RTC initialization: Test passed")
@require_package("transformers")
@require_cuda
def test_smolvla_rtc_initialization_without_rtc_config():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy can initialize without RTC config."""
set_seed(42)
@@ -87,12 +82,9 @@ def test_smolvla_rtc_initialization_without_rtc_config():
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
@require_package("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_with_prev_chunk():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy inference with RTC and previous chunk."""
set_seed(42)
@@ -170,12 +162,9 @@ def test_smolvla_rtc_inference_with_prev_chunk():
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
@require_package("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_without_prev_chunk():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect)."""
set_seed(42)
@@ -244,12 +233,9 @@ def test_smolvla_rtc_inference_without_prev_chunk():
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
@require_package("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_validation_rules():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy with RTC follows all three validation rules."""
set_seed(42)
-378
@@ -1,378 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for SARM utility functions.
Tests the implementation of SARM paper formulas:
- Formula (1): compute_temporal_proportions - dataset-level temporal proportions
- Formula (2): compute_tau, compute_cumulative_progress_batch - progress labels
"""
import pytest
import numpy as np
import torch
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
from lerobot.policies.sarm.sarm_utils import (
compute_temporal_proportions,
compute_tau,
compute_cumulative_progress_batch,
)
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
"""Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
return SubtaskAnnotation(
subtasks=[
Subtask(
name=name,
timestamps=Timestamp(
start=f"{start // 60:02d}:{start % 60:02d}",
end=f"{end // 60:02d}:{end % 60:02d}"
)
)
for name, start, end in subtasks
]
)
class TestComputeTemporalProportions:
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
Key insight: This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of absolute length.
"""
def test_basic_two_trajectories_equal_proportions(self):
"""Test with two trajectories that have equal proportions."""
# Both trajectories: subtask1 = 50%, subtask2 = 50%
# Traj 1: T=100s, subtask1=50s, subtask2=50s
# Traj 2: T=200s, subtask1=100s, subtask2=100s
annotations = {
0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
}
result = compute_temporal_proportions(annotations)
# Both should be 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_paper_example_different_from_avg_durations(self):
"""Test that compute_temporal_proportions differs from naive average duration approach.
This is the key test showing the difference between:
- Paper formula: average of (L_i,k / T_i)
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
"""
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
annotations = {
0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
}
result = compute_temporal_proportions(annotations)
# Paper formula:
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
# ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_single_trajectory(self):
"""Test with a single trajectory."""
# T=100s, reach=30s, grasp=20s, lift=50s
annotations = {
0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
}
result = compute_temporal_proportions(annotations)
assert abs(result['reach'] - 0.3) < 1e-6
assert abs(result['grasp'] - 0.2) < 1e-6
assert abs(result['lift'] - 0.5) < 1e-6
def test_sum_to_one(self):
"""Test that proportions always sum to 1."""
# Three episodes with varying proportions
annotations = {
0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]), # 0.1, 0.4, 0.5
1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]), # 0.2, 0.5, 0.3
2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]), # 0.3, 0.6, 0.1
}
result = compute_temporal_proportions(annotations)
total = sum(result.values())
assert abs(total - 1.0) < 1e-6
def test_empty_annotations_returns_empty(self):
"""Test that empty annotations returns empty dict."""
result = compute_temporal_proportions({})
assert result == {}
def test_uniform_proportions(self):
"""Test with uniform proportions across subtasks."""
# Each subtask takes 25% of each episode
annotations = {
0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
}
result = compute_temporal_proportions(annotations)
for name in ['a', 'b', 'c', 'd']:
assert abs(result[name] - 0.25) < 1e-6
class TestComputeTau:
"""Tests for compute_tau (within-subtask progress).
Formula: τ_t = (t - s_k) / (e_k - s_k) [0, 1]
"""
def test_at_start(self):
"""τ should be 0 at subtask start."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_at_end(self):
"""τ should be 1 at subtask end."""
tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_at_middle(self):
"""τ should be 0.5 at subtask midpoint."""
tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
assert abs(tau - 0.5) < 1e-6
def test_quarter_progress(self):
"""Test τ at 25% through subtask."""
tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
assert abs(tau - 0.25) < 1e-6
def test_zero_duration_subtask(self):
"""τ should be 1.0 for zero-duration subtask."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
assert tau == 1.0
def test_clamps_below_zero(self):
"""τ should be clamped to 0 if frame is before subtask."""
tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_clamps_above_one(self):
"""τ should be clamped to 1 if frame is after subtask."""
tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_float_inputs(self):
"""Test with float frame indices (from interpolation)."""
tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
expected = (25.5 - 10.0) / (50.0 - 10.0)
assert abs(tau - expected) < 1e-6
class TestComputeCumulativeProgressBatchScalar:
"""Tests for compute_cumulative_progress_batch with scalar inputs (normalized progress y_t).
Formula: y_t = P_{k-1} + ᾱ_k × τ_t [0, 1]
"""
def test_first_subtask_start(self):
"""y should be 0 at start of first subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=0, alpha=proportions)
assert y == 0.0
def test_first_subtask_end(self):
"""y should equal ᾱ_1 at end of first subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=0, alpha=proportions)
assert abs(y - 0.3) < 1e-6
def test_second_subtask_start(self):
"""y should equal P_1 at start of second subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=1, alpha=proportions)
assert abs(y - 0.3) < 1e-6
def test_second_subtask_end(self):
"""y should equal P_2 at end of second subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=1, alpha=proportions)
assert abs(y - 0.8) < 1e-6 # 0.3 + 0.5
def test_third_subtask_end(self):
"""y should be 1.0 at end of last subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=2, alpha=proportions)
assert abs(y - 1.0) < 1e-6
def test_midpoint_of_subtask(self):
"""Test progress at midpoint of a subtask."""
proportions = [0.4, 0.6]
# At τ=0.5 in subtask 1: y = P_0 + ᾱ_1 × 0.5 = 0 + 0.4 × 0.5 = 0.2
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=0, alpha=proportions)
assert abs(y - 0.2) < 1e-6
# At τ=0.5 in subtask 2: y = P_1 + ᾱ_2 × 0.5 = 0.4 + 0.6 × 0.5 = 0.7
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=1, alpha=proportions)
assert abs(y - 0.7) < 1e-6
def test_uniform_proportions(self):
"""Test with uniform proportions."""
proportions = [0.25, 0.25, 0.25, 0.25]
# At end of each subtask, progress should be 0.25, 0.5, 0.75, 1.0
for i in range(4):
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=i, alpha=proportions)
expected = (i + 1) * 0.25
assert abs(y - expected) < 1e-6
class TestComputeCumulativeProgressBatchTensor:
"""Tests for compute_cumulative_progress_batch with tensor inputs (GPU batch version)."""
def test_tensor_matches_scalar_version(self):
"""Test that tensor version matches scalar version."""
proportions = [0.3, 0.5, 0.2]
alpha = torch.tensor(proportions, dtype=torch.float32)
cumulative = torch.zeros(len(proportions) + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
test_cases = [
(0.0, 0), # start of subtask 0
(1.0, 0), # end of subtask 0
(0.0, 1), # start of subtask 1
(0.5, 1), # middle of subtask 1
(1.0, 2), # end of subtask 2
]
for tau_val, stage_idx in test_cases:
# Scalar version
expected = compute_cumulative_progress_batch(tau_val, stage_idx, proportions)
# Tensor version (single element)
tau = torch.tensor([[[tau_val]]]) # (1, 1, 1)
stages = torch.tensor([[stage_idx]]) # (1, 1)
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
assert abs(result[0, 0, 0].item() - expected) < 1e-6
def test_batch_processing(self):
"""Test batch processing with multiple samples."""
proportions = [0.4, 0.6]
alpha = torch.tensor(proportions, dtype=torch.float32)
cumulative = torch.zeros(3, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
# Batch of 2 samples, sequence length 3
tau = torch.tensor([
[[0.0], [0.5], [1.0]], # sample 1
[[0.0], [0.5], [1.0]], # sample 2
])
stages = torch.tensor([
[0, 0, 0], # sample 1: all in subtask 0
[1, 1, 1], # sample 2: all in subtask 1
])
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
# Sample 1: subtask 0 with tau 0, 0.5, 1.0 -> y = 0, 0.2, 0.4
assert abs(result[0, 0, 0].item() - 0.0) < 1e-6
assert abs(result[0, 1, 0].item() - 0.2) < 1e-6
assert abs(result[0, 2, 0].item() - 0.4) < 1e-6
# Sample 2: subtask 1 with tau 0, 0.5, 1.0 -> y = 0.4, 0.7, 1.0
assert abs(result[1, 0, 0].item() - 0.4) < 1e-6
assert abs(result[1, 1, 0].item() - 0.7) < 1e-6
assert abs(result[1, 2, 0].item() - 1.0) < 1e-6
def test_auto_compute_cumulative_prior(self):
"""Test that cumulative_prior is auto-computed when not provided."""
proportions = [0.3, 0.5, 0.2]
alpha = torch.tensor(proportions, dtype=torch.float32)
tau = torch.tensor([[[0.5]]])
stages = torch.tensor([[1]])
# Without cumulative_prior (should auto-compute)
result = compute_cumulative_progress_batch(tau, stages, alpha)
# Expected: P_0 + alpha_1 * 0.5 = 0.3 + 0.5 * 0.5 = 0.55
assert abs(result[0, 0, 0].item() - 0.55) < 1e-6
class TestEndToEndProgressLabeling:
"""End-to-end tests for progress label computation."""
def test_consistent_semantic_meaning(self):
"""Test that same subtask completion maps to same progress across trajectories.
This is the key semantic property: "end of subtask 1" should always
mean the same progress value regardless of trajectory speed.
"""
proportions = [0.3, 0.5, 0.2]
# Fast trajectory: subtask 1 ends at frame 30 (of 100)
tau_fast = compute_tau(30, 0, 30) # = 1.0
y_fast = compute_cumulative_progress_batch(tau_fast, 0, proportions)
# Slow trajectory: subtask 1 ends at frame 90 (of 300)
tau_slow = compute_tau(90, 0, 90) # = 1.0
y_slow = compute_cumulative_progress_batch(tau_slow, 0, proportions)
# Both should map to same progress (0.3 = end of subtask 1)
assert abs(y_fast - y_slow) < 1e-6
assert abs(y_fast - 0.3) < 1e-6
def test_monotonic_within_subtask(self):
"""Test that progress is monotonically increasing within a subtask."""
proportions = [0.4, 0.6]
prev_y = -1
for tau in np.linspace(0, 1, 11):
y = compute_cumulative_progress_batch(tau, 0, proportions)
assert y > prev_y or (tau == 0 and y == 0)
prev_y = y
def test_continuous_across_subtasks(self):
"""Test that progress is continuous at subtask boundaries."""
proportions = [0.3, 0.5, 0.2]
# End of subtask 0 (tau=1.0)
y_end_0 = compute_cumulative_progress_batch(1.0, 0, proportions)
# Start of subtask 1 (tau=0.0)
y_start_1 = compute_cumulative_progress_batch(0.0, 1, proportions)
# Should be equal (P_1 = 0.3)
assert abs(y_end_0 - y_start_1) < 1e-6
# End of subtask 1
y_end_1 = compute_cumulative_progress_batch(1.0, 1, proportions)
# Start of subtask 2
y_start_2 = compute_cumulative_progress_batch(0.0, 2, proportions)
# Should be equal (P_2 = 0.8)
assert abs(y_end_1 - y_start_2) < 1e-6
-72
@@ -1,72 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from lerobot.envs.utils import preprocess_observation
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
seed = 42
np.random.seed(seed)
B = 5
obs1 = {
"pixels": {
"image": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
"image2": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
},
"robot_state": {
"eef": {
"pos": np.random.randn(B, 3),
"quat": np.random.randn(B, 4),
"mat": np.random.randn(B, 3, 3),
},
"gripper": {
"qpos": np.random.randn(B, 2),
"qvel": np.random.randn(B, 2),
},
"joints": {
"pos": np.random.randn(B, 7),
"vel": np.random.randn(B, 7),
},
},
}
observation = preprocess_observation(obs1)
libero_preprocessor = PolicyProcessorPipeline(
steps=[
LiberoProcessorStep(),
]
)
processed_obs = libero_preprocessor(observation)
assert "observation.state" in processed_obs
state = processed_obs["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.dtype == torch.float32
assert state.shape[0] == B
assert state.shape[1] == 8
assert "observation.images.image" in processed_obs
assert "observation.images.image2" in processed_obs
assert isinstance(processed_obs["observation.images.image"], torch.Tensor)
assert isinstance(processed_obs["observation.images.image2"], torch.Tensor)
assert processed_obs["observation.images.image"].shape == (B, 3, 256, 256)
assert processed_obs["observation.images.image2"].shape == (B, 3, 256, 256)