[RTC] Real Time Chunking for Pi0, Smolvla, Pi0.5 (#1698)

* Add Real-Time Chunking (RTC) support for flow matching models

Implement Real-Time Chunking (RTC) for action chunking policies using flow
matching denoising. RTC enables smooth action transitions between consecutive
chunks by using prefix guidance during denoising.

Key features:
- RTCProcessor class with denoise_step method for RTC guidance
- Tracker system for debug tracking using time-based dictionary storage
- RTCDebugVisualizer with comprehensive visualization utilities
- Integration with SmolVLA policy for flow matching models
- Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP)
- Configurable execution horizon and max guidance weight
- Example scripts for dataset evaluation and real-time control

Technical details:
- Uses autograd-based gradient computation for RTC corrections
- Time-based tracking eliminates duplicate step issues
- Proxy methods in RTCProcessor for cleaner API
- Full integration with LeRobot's policy and dataset systems

Files added/modified:
- src/lerobot/configs/types.py: Add RTCAttentionSchedule enum
- src/lerobot/policies/rtc/: Core RTC implementation
  - configuration_rtc.py: RTC configuration
  - modeling_rtc.py: RTCProcessor with denoise_step
  - debug_handler.py: Tracker for debug information
  - debug_visualizer.py: Visualization utilities
- src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration
- examples/rtc/: Example scripts and evaluation tools

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Fix rtc_config attribute access in SmolVLA

Use getattr() to safely check for rtc_config attribute existence
instead of direct attribute access. This fixes AttributeError when
loading policies without rtc_config in their config.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix rtc_config attribute access in SmolVLA

* Add RTCConfig field to SmolVLAConfig

Add rtc_config as an optional field in SmolVLAConfig to properly
support Real-Time Chunking configuration. This replaces the previous
getattr() workarounds with direct attribute access, making the code
cleaner and more maintainable.

Changes:
- Import RTCConfig in configuration_smolvla.py
- Add rtc_config: RTCConfig | None = None field
- Revert getattr() calls to direct attribute access in modeling_smolvla.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Refactor RTC enabled checks to use _rtc_enabled helper

Add _rtc_enabled() helper method in VLAFlowMatching class to simplify
and clean up RTC enabled checks throughout the code. This reduces
code duplication and improves readability.

Changes:
- Add _rtc_enabled() method in VLAFlowMatching
- Replace verbose rtc_config checks with _rtc_enabled() calls
- Maintain exact same functionality with cleaner code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Rename track_debug method to track

Simplify the method name from track_debug to just track for better
readability and consistency. The method already has clear documentation
about its debug tracking purpose.

Changes:
- Rename RTCProcessor.track_debug() to track()
- Update all call sites in modeling_smolvla.py and modeling_rtc.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Use output_dir for saving all evaluation images

Update eval_dataset.py to save all comparison images to the
configured output_dir instead of the current directory. This provides
better organization and allows users to specify where outputs should be
saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Use output_dir for saving all evaluation images

* Fix logging buffering and enable tracking when RTC config provided

- Add force=True to logging.basicConfig to override existing configuration
- Enable line buffering for stdout/stderr for real-time log output
- Modify init_rtc_processor to create processor when rtc_config exists
  even if RTC is disabled, allowing tracking of denoising data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Refactor SmolVLA plotting to use tracker data instead of local variables

Remove local tracking variables (correction, x1_t, error) from the
denoising loop and instead retrieve plotting data from the RTC tracker
after each denoise step. This makes the code cleaner and uses the
tracker as the single source of truth for debug/visualization data.

Changes:
- Remove initialization of correction, x1_t, error before denoising loop
- After each Euler step, retrieve most recent debug step from tracker
- Extract correction, x1_t, err from debug step for plotting
- Update tracking condition to use is_debug_enabled() method

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Move plotting logic from modeling_smolvla to eval_dataset script

Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Refactor plotting loging

* fixup! Refactor plotting loging

* Improve visualization: separate correction plot and fix axis scaling

Changes:
- Create separate figure for correction data instead of overlaying on v_t
- Add _rescale_axes helper method to properly scale all axes
- Add 10% margin to y-axis for better visualization
- Fix v_t chart vertical compression issue

Benefits:
- Clearer v_t plot without correction overlay
- Better axis scaling with proper margins
- Separate correction figure for focused analysis
- Improved readability of all denoising visualizations

Output files:
- denoising_xt_comparison.png (x_t trajectories)
- denoising_vt_comparison.png (v_t velocity - now cleaner)
- denoising_correction_comparison.png (NEW - separate corrections)
- denoising_x1t_comparison.png (x1_t state with error)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* fixup! Improve visualization: separate correction plot and fix axis scaling

* fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

* fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

* Fix traacking

* Right kwargs for the policy

* Add tests for tracker

* Fix tests

* Drop not required methods

* Add torch compilation for eval_dataset

* delete policies

* Add matplotliv to dev

* fixup! Add matplotliv to dev

* Experiemnt with late detach

* Debug

* Fix compilation

* Add RTC to PI0

* Pi0

* Pi0 eval dataset

* fixup! Pi0 eval dataset

* Turn off compilation for pi0/pi05

* fixup! Turn off compilation for pi0/pi05

* fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05

* Add workable flow

* Small fixes

* Add more tests

* Add validatio at the end

* Update README

* Silent validation

* Fix tests

* Add tests for modeling_rtc

* Add tests for flow matching models with RTC

* fixup! Add tests for flow matching models with RTC

* fixup! fixup! Add tests for flow matching models with RTC

* Add one more test

* fixup! Add one more test

* Fix test to use _rtc_enabled() instead of is_rtc_enabled()

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()

* fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()

* Add RTC initialization tests without config for PI0.5 and SmolVLA

Add test_pi05_rtc_initialization_without_rtc_config and
test_smolvla_rtc_initialization_without_rtc_config to verify that
policies can initialize without RTC config and that _rtc_enabled()
returns False in this case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix PI0.5 init_rtc_processor to use getattr instead of direct model access

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix SmolVLA init_rtc_processor to use getattr instead of direct model access

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

* Fixup eval with real robot

* fixup! Fixup eval with real robot

* fixup! fixup! Fixup eval with real robot

* Extract simulator logic from eval_with real robot and add proper headers to files

* Update images

* Fix tests

* fixup! Fix tests

* add docs for rtc

* enhance doc and add images

* Fix instal instructions

---------
Co-authored-by: Ben Zhang <benzhangniu@gmail.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Eugene Mironov
2025-11-19 17:19:48 +07:00
committed by GitHub
parent b464d9f8bc
commit 8a915c6b6f
26 changed files with 6517 additions and 50 deletions
+7
View File
@@ -43,3 +43,10 @@ class NormalizationMode(str, Enum):
class PolicyFeature:
type: FeatureType
shape: tuple[int, ...]
class RTCAttentionSchedule(str, Enum):
ZEROS = "ZEROS"
ONES = "ONES"
LINEAR = "LINEAR"
EXP = "EXP"
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -47,6 +48,9 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
+83 -15
View File
@@ -19,11 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI0 PyTorch model."""
def __init__(self, config: PI0Config):
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI0 model
self.model = PI0Pytorch(config)
self.init_rtc_processor()
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
state = self.prepare_state(batch)
# Sample actions using the model
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
# Sample actions using the model (pass through RTC kwargs)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
@PreTrainedConfig.register_subclass("pi05")
@@ -46,6 +47,9 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
+83 -14
View File
@@ -19,11 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model."""
def __init__(self, config: PI05Config):
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
def sample_actions(
self,
images,
img_masks,
tokens,
masks,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
num_steps = self.config.num_inference_steps
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI05 model
self.model = PI05Pytorch(config)
self.init_rtc_processor()
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks)
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
+38
View File
@@ -0,0 +1,38 @@
# Real-Time Chunking (RTC)
This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.
**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).
---
## Citation
If you use Real-Time Chunking in your work, please cite:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{black2025realtimeexecutionactionchunking,
title={Real-Time Execution of Action Chunking Flow Policies},
author={Kevin Black and Manuel Y. Galliker and Sergey Levine},
year={2025},
eprint={2506.07339},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2506.07339},
}
```
---
## License
This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.
+219
View File
@@ -0,0 +1,219 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action queue management for Real-Time Chunking (RTC).
This module provides ActionQueue, a thread-safe queue for managing action chunks
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
handling action merging and leftover tracking.
"""
import logging
from threading import Lock
import torch
from torch import Tensor
from lerobot.policies.rtc.configuration_rtc import RTCConfig
logger = logging.getLogger(__name__)
class ActionQueue:
"""Thread-safe queue for managing action chunks in real-time control.
This queue handles two types of action sequences:
- Original actions: Used for RTC to compute leftovers from previous chunks
- Processed actions: Post-processed actions ready for robot execution
The queue operates in two modes:
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
Args:
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
Attributes:
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
last_index (int): Current consumption index in the queue.
"""
def __init__(self, cfg: RTCConfig):
"""Initialize the action queue.
Args:
cfg: RTC configuration controlling queue behavior.
"""
self.queue = None # Processed actions for robot rollout
self.original_queue = None # Original actions for RTC
self.lock = Lock()
self.last_index = 0
self.cfg = cfg
def get(self) -> Tensor | None:
"""Get the next action from the queue.
Returns:
Tensor | None: The next action (action_dim,) or None if queue is empty.
Returns a clone to prevent external modifications.
"""
with self.lock:
if self.queue is None or self.last_index >= len(self.queue):
return None
action = self.queue[self.last_index]
self.last_index += 1
return action.clone()
def qsize(self) -> int:
"""Get the number of remaining actions in the queue.
Returns:
int: Number of unconsumed actions.
"""
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
def empty(self) -> bool:
"""Check if the queue is empty.
Returns:
bool: True if no actions remain, False otherwise.
"""
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index <= 0
def get_action_index(self) -> int:
"""Get the current action consumption index.
Returns:
int: Index of the next action to be consumed.
"""
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over.
These are the unconsumed actions from the current chunk, which will be
used by RTC to compute corrections for the next chunk.
Returns:
Tensor | None: Remaining original actions (remaining_steps, action_dim),
or None if no original queue exists.
"""
with self.lock:
if self.original_queue is None:
return None
return self.original_queue[self.last_index :]
def merge(
self,
original_actions: Tensor,
processed_actions: Tensor,
real_delay: int,
action_index_before_inference: int | None = 0,
):
"""Merge new actions into the queue.
This method operates differently based on RTC mode:
- RTC enabled: Replaces the queue, accounting for inference delay
- RTC disabled: Appends to the queue, maintaining continuity
Args:
original_actions: Unprocessed actions from policy (time_steps, action_dim).
processed_actions: Post-processed actions for robot (time_steps, action_dim).
real_delay: Number of time steps of inference delay.
action_index_before_inference: Index before inference started, for validation.
"""
with self.lock:
self._check_delays(real_delay, action_index_before_inference)
if self.cfg.enabled:
self._replace_actions_queue(original_actions, processed_actions, real_delay)
return
self._append_actions_queue(original_actions, processed_actions)
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
"""Replace the queue with new actions (RTC mode).
Discards the first `real_delay` actions since they correspond to the time
spent during inference, when the robot was executing previous actions.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
real_delay: Number of time steps to skip due to inference delay.
"""
self.original_queue = original_actions[real_delay:].clone()
self.queue = processed_actions[real_delay:].clone()
logger.debug(f"original_actions shape: {self.original_queue.shape}")
logger.debug(f"processed_actions shape: {self.queue.shape}")
logger.debug(f"real_delay: {real_delay}")
self.last_index = 0
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
"""Append new actions to the queue (non-RTC mode).
Removes already-consumed actions and appends new ones, maintaining
queue continuity without replacement.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
"""
if self.queue is None:
self.original_queue = original_actions.clone()
self.queue = processed_actions.clone()
return
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
self.original_queue = self.original_queue[self.last_index :]
self.queue = torch.cat([self.queue, processed_actions.clone()])
self.queue = self.queue[self.last_index :]
self.last_index = 0
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
"""Validate that computed delays match expectations.
Compares the delay computed from inference latency with the actual
number of actions consumed during inference.
Args:
real_delay: Delay computed from inference latency.
action_index_before_inference: Action index when inference started.
"""
if action_index_before_inference is None:
return
indexes_diff = self.last_index - action_index_before_inference
if indexes_diff != real_delay:
# Let's check that action index difference (real delay calculated based on action queue)
# is the same as delay calculated based on inference latency
logger.warning(
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
)
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
Based on:
- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
"""
from dataclasses import dataclass
from lerobot.configs.types import RTCAttentionSchedule
@dataclass
class RTCConfig:
"""Configuration for Real Time Chunking (RTC) inference.
RTC improves real-time inference by treating chunk generation as an inpainting problem,
strategically handling overlapping timesteps between action chunks using prefix attention.
"""
# Infrastructure
enabled: bool = False
# Core RTC settings
# Todo change to exp
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
max_guidance_weight: float = 10.0
execution_horizon: int = 10
# Debug settings
debug: bool = False
debug_maxlen: int = 100
def __post_init__(self):
"""Validate RTC configuration parameters."""
if self.max_guidance_weight <= 0:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
+233
View File
@@ -0,0 +1,233 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Debug information handler for Real-Time Chunking (RTC)."""
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
@dataclass
class DebugStep:
"""Container for debug information from a single denoising step.
Attributes:
step_idx (int): Step index/counter.
x_t (Tensor | None): Current latent/state tensor.
v_t (Tensor | None): Velocity from denoiser.
x1_t (Tensor | None): Denoised prediction (x_t - time * v_t).
correction (Tensor | None): Correction gradient tensor.
err (Tensor | None): Weighted error term.
weights (Tensor | None): Prefix attention weights.
guidance_weight (float | Tensor | None): Applied guidance weight.
time (float | Tensor | None): Time parameter.
inference_delay (int | None): Inference delay parameter.
execution_horizon (int | None): Execution horizon parameter.
metadata (dict[str, Any]): Additional metadata.
"""
step_idx: int = 0
x_t: Tensor | None = None
v_t: Tensor | None = None
x1_t: Tensor | None = None
correction: Tensor | None = None
err: Tensor | None = None
weights: Tensor | None = None
guidance_weight: float | Tensor | None = None
time: float | Tensor | None = None
inference_delay: int | None = None
execution_horizon: int | None = None
metadata: dict[str, Any] = field(default_factory=dict)
def to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
"""Convert debug step to dictionary.
Args:
include_tensors (bool): If True, include tensor values. If False, only include
tensor statistics (shape, mean, std, min, max).
Returns:
Dictionary representation of the debug step.
"""
result = {
"step_idx": self.step_idx,
"guidance_weight": (
self.guidance_weight.item()
if isinstance(self.guidance_weight, Tensor)
else self.guidance_weight
),
"time": self.time.item() if isinstance(self.time, Tensor) else self.time,
"inference_delay": self.inference_delay,
"execution_horizon": self.execution_horizon,
"metadata": self.metadata.copy(),
}
# Add tensor information
tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"]
for field_name in tensor_fields:
tensor = getattr(self, field_name)
if tensor is not None:
if include_tensors:
result[field_name] = tensor.detach().cpu()
else:
result[f"{field_name}_stats"] = {
"shape": tuple(tensor.shape),
"mean": tensor.mean().item(),
"std": tensor.std().item(),
"min": tensor.min().item(),
"max": tensor.max().item(),
}
return result
class Tracker:
"""Collects and manages debug information for RTC processing.
This tracker stores debug information from recent denoising steps in a dictionary,
using time as the key for efficient lookups and updates.
Args:
enabled (bool): Whether debug collection is enabled.
maxlen (int | None): Optional sliding window size. If provided, only the
most recent ``maxlen`` debug steps are kept. If ``None``, keeps all.
"""
def __init__(self, enabled: bool = False, maxlen: int = 100):
self.enabled = enabled
self._steps = {} if enabled else None # Dictionary with time as key
self._maxlen = maxlen
self._step_counter = 0
def reset(self) -> None:
"""Clear all recorded debug information."""
if self.enabled and self._steps is not None:
self._steps.clear()
self._step_counter = 0
@torch._dynamo.disable
def track(
self,
time: float | Tensor,
x_t: Tensor | None = None,
v_t: Tensor | None = None,
x1_t: Tensor | None = None,
correction: Tensor | None = None,
err: Tensor | None = None,
weights: Tensor | None = None,
guidance_weight: float | Tensor | None = None,
inference_delay: int | None = None,
execution_horizon: int | None = None,
**metadata,
) -> None:
"""Track debug information for a denoising step at a given time.
If a step with the given time already exists, it will be updated with the new data.
Otherwise, a new step will be created. Only non-None fields are updated/set.
Note: This method is excluded from torch.compile to avoid graph breaks from
operations like .item() which are incompatible with compiled graphs.
Args:
time (float | Tensor): Time parameter - used as the key to identify the step.
x_t (Tensor | None): Current latent/state tensor.
v_t (Tensor | None): Velocity from denoiser.
x1_t (Tensor | None): Denoised prediction.
correction (Tensor | None): Correction gradient tensor.
err (Tensor | None): Weighted error term.
weights (Tensor | None): Prefix attention weights.
guidance_weight (float | Tensor | None): Applied guidance weight.
inference_delay (int | None): Inference delay parameter.
execution_horizon (int | None): Execution horizon parameter.
**metadata: Additional metadata to store.
"""
if not self.enabled:
return
# Convert time to float and round to avoid float precision issues
time_value = time.item() if isinstance(time, Tensor) else time
time_key = round(time_value, 6) # Use rounded time as dictionary key
# Check if step with this time already exists
if time_key in self._steps:
# Update existing step with non-None fields
existing_step = self._steps[time_key]
if x_t is not None:
existing_step.x_t = x_t.detach().clone()
if v_t is not None:
existing_step.v_t = v_t.detach().clone()
if x1_t is not None:
existing_step.x1_t = x1_t.detach().clone()
if correction is not None:
existing_step.correction = correction.detach().clone()
if err is not None:
existing_step.err = err.detach().clone()
if weights is not None:
existing_step.weights = weights.detach().clone()
if guidance_weight is not None:
existing_step.guidance_weight = guidance_weight
if inference_delay is not None:
existing_step.inference_delay = inference_delay
if execution_horizon is not None:
existing_step.execution_horizon = execution_horizon
if metadata:
existing_step.metadata.update(metadata)
else:
# Create new step
step = DebugStep(
step_idx=self._step_counter,
x_t=x_t.detach().clone() if x_t is not None else None,
v_t=v_t.detach().clone() if v_t is not None else None,
x1_t=x1_t.detach().clone() if x1_t is not None else None,
correction=correction.detach().clone() if correction is not None else None,
err=err.detach().clone() if err is not None else None,
weights=weights.detach().clone() if weights is not None else None,
guidance_weight=guidance_weight,
time=time_value,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
metadata=metadata,
)
# Add to dictionary
self._steps[time_key] = step
self._step_counter += 1
# Enforce maxlen if set
if self._maxlen is not None and len(self._steps) > self._maxlen:
# Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order)
oldest_key = next(iter(self._steps))
del self._steps[oldest_key]
def get_all_steps(self) -> list[DebugStep]:
"""Get all recorded debug steps.
Returns:
List of all DebugStep objects (may be empty if disabled).
"""
if not self.enabled or self._steps is None:
return []
return list(self._steps.values())
def __len__(self) -> int:
"""Return the number of recorded debug steps."""
if not self.enabled or self._steps is None:
return 0
return len(self._steps)
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Visualization utilities for RTC debug information."""
import torch
class RTCDebugVisualizer:
"""Visualizer for RTC debug information.
This class provides methods to visualize debug information collected by the Tracker,
including corrections, errors, weights, and guidance weights over denoising steps.
"""
@staticmethod
def plot_waypoints(
axes,
tensor,
start_from: int = 0,
color: str = "blue",
label: str = "",
alpha: float = 0.7,
linewidth: float = 2,
marker: str | None = None,
markersize: int = 4,
):
"""Plot trajectories across multiple dimensions.
This function plots a tensor's values across time for multiple dimensions,
with each dimension plotted on a separate axis.
Args:
axes: Array of matplotlib axes (one for each dimension).
tensor: The tensor to plot (can be torch.Tensor or numpy array).
Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims).
start_from: Starting index for the x-axis.
color: Color for the plot lines.
label: Label for the plot legend.
alpha: Transparency level for the plot.
linewidth: Width of the plot lines.
marker: Marker style for data points (e.g., 'o', 's', '^').
markersize: Size of the markers.
"""
import numpy as np
# Handle None tensor
if tensor is None:
return
# Convert tensor to numpy if needed
tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
# Handle different tensor shapes
if tensor_np.ndim == 3:
# If batch dimension present, take first batch
tensor_np = tensor_np[0]
elif tensor_np.ndim == 1:
# If 1D, reshape to (time_steps, 1)
tensor_np = tensor_np.reshape(-1, 1)
# Get dimensions
time_steps, num_dims = tensor_np.shape
# Create x-axis indices
x_indices = np.arange(start_from, start_from + time_steps)
# Plot each dimension on its corresponding axis
num_axes = len(axes) if hasattr(axes, "__len__") else 1
for dim_idx in range(min(num_dims, num_axes)):
ax = axes[dim_idx] if hasattr(axes, "__len__") else axes
# Plot the trajectory
if marker:
ax.plot(
x_indices,
tensor_np[:, dim_idx],
color=color,
label=label if dim_idx == 0 else "", # Only show label once
alpha=alpha,
linewidth=linewidth,
marker=marker,
markersize=markersize,
)
else:
ax.plot(
x_indices,
tensor_np[:, dim_idx],
color=color,
label=label if dim_idx == 0 else "", # Only show label once
alpha=alpha,
linewidth=linewidth,
)
# Add grid and labels if not already present
if not ax.xaxis.get_label().get_text():
ax.set_xlabel("Step", fontsize=10)
if not ax.yaxis.get_label().get_text():
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Latency tracking utilities for Real-Time Chunking (RTC)."""
from collections import deque
import numpy as np
class LatencyTracker:
"""Tracks recent latencies and provides max/percentile queries.
Args:
maxlen (int | None): Optional sliding window size. If provided, only the
most recent ``maxlen`` latencies are kept. If ``None``, keeps all.
"""
def __init__(self, maxlen: int = 100):
self._values = deque(maxlen=maxlen)
self.reset()
def reset(self) -> None:
"""Clear all recorded latencies."""
self._values.clear()
self.max_latency = 0.0
def add(self, latency: float) -> None:
"""Add a latency sample (seconds)."""
# Ensure numeric and non-negative
val = float(latency)
if val < 0:
return
self._values.append(val)
self.max_latency = max(self.max_latency, val)
def __len__(self) -> int:
return len(self._values)
def max(self) -> float | None:
"""Return the maximum latency or None if empty."""
return self.max_latency
def percentile(self, q: float) -> float | None:
"""Return the q-quantile (q in [0,1]) of recorded latencies or None if empty."""
if not self._values:
return 0.0
q = float(q)
if q <= 0.0:
return min(self._values)
if q >= 1.0:
return self.max_latency
vals = np.array(list(self._values), dtype=np.float32)
return float(np.quantile(vals, q))
def p95(self) -> float | None:
"""Return the 95th percentile latency or None if empty."""
return self.percentile(0.95)
+297
View File
@@ -0,0 +1,297 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Real-Time Chunking (RTC) implementation for LeRobot.
Based on Physical Intelligence's Kinetix implementation:
https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214
"""
import logging
import math
import torch
from torch import Tensor
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_tracker import Tracker
logger = logging.getLogger(__name__)
class RTCProcessor:
"""Real-Time Chunking processor for action chunking policies.
This class implements RTC techniques including velocity calculation,
prefix attention, and adaptive chunk processing.
"""
def __init__(self, rtc_config: RTCConfig):
self.rtc_config = rtc_config
self.tracker = None
if rtc_config.debug:
self.tracker = Tracker(
enabled=rtc_config.debug,
maxlen=rtc_config.debug_maxlen,
)
# ====================== Tracker Proxy Methods ======================
def track(
self,
time: float | Tensor,
x_t: Tensor | None = None,
v_t: Tensor | None = None,
x1_t: Tensor | None = None,
correction: Tensor | None = None,
err: Tensor | None = None,
weights: Tensor | None = None,
guidance_weight: float | Tensor | None = None,
inference_delay: int | None = None,
execution_horizon: int | None = None,
**metadata,
) -> None:
"""Proxy method to track debug information.
If tracker is None or disabled, this method does nothing.
Otherwise, it forwards the call to tracker.track().
"""
if self.tracker is not None:
self.tracker.track(
time=time,
x_t=x_t,
v_t=v_t,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
**metadata,
)
def get_all_debug_steps(self) -> list:
"""Get all debug steps from tracker.
Returns empty list if tracker is disabled or None.
"""
if self.tracker is not None:
return self.tracker.get_all_steps()
return []
def is_debug_enabled(self) -> bool:
"""Check if debug tracking is enabled.
Returns True if tracker exists and is enabled.
"""
return self.tracker is not None and self.tracker.enabled
def reset_tracker(self) -> None:
"""Reset the tracker, clearing all recorded steps.
Does nothing if tracker is None.
"""
if self.tracker is not None:
self.tracker.reset()
# ====================== End Tracker Proxy Methods ======================
def denoise_step(
self,
x_t,
prev_chunk_left_over,
inference_delay,
time,
original_denoise_step_partial,
execution_horizon=None,
) -> Tensor:
"""RTC guidance wrapper around an existing denoiser.
This method wraps an original denoising callable that only takes ``x_t`` and
returns a base denoised velocity ``v_t``. It then applies Real-Time Chunking
(RTC) prefix guidance using the leftover prefix from the previous chunk.
Args:
x_t (Tensor): Current latent/state to denoise. Shape ``(B, T, A)`` or ``(T, A)``.
prev_chunk_left_over (Tensor | None): Unexecuted prefix from the previous
chunk. Shape ``(B, T_prev, A)`` or ``(T_prev, A)``. If ``None``, no guidance
is applied and the method returns ``v_t`` from the original denoiser.
inference_delay (int): Number of timesteps from the prefix to use for guidance.
time (float | Tensor): Scalar in [0, 1] indicating normalized time. Must be
broadcastable with ``x_t``.
original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
computes the base denoised velocity given only ``x_t``.
execution_horizon (int | None): Horizon used to build prefix weights. If
``None``, defaults to ``self.rtc_config.execution_horizon``.
Returns:
Tensor: Guided velocity with the same shape as ``v_t``.
Notes:
- If inputs are 2D, a batch dimension is temporarily added and removed at the end.
- If ``prev_chunk_left_over`` is shorter than the current chunk length ``T``, it is
right-padded with zeros to match ``T``.
- Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)``
and broadcast to ``(B, T, A)``.
- Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and
``error = (prev_chunk_left_over - x1_t) * weights``.
- The final guidance weight is clamped by ``max_guidance_weight`` from the config.
Reference:
https://www.physicalintelligence.company/download/real_time_chunking.pdf
"""
# In the original implementation, the time goes from 0 to 1 and
# In our implementation, the time goes from 1 to 0
# So we need to invert the time
tau = 1 - time
if prev_chunk_left_over is None:
# First step, no guidance - return v_t
v_t = original_denoise_step_partial(x_t)
return v_t
x_t = x_t.clone().detach()
squeezed = False
if len(x_t.shape) < 3:
# Add batch dimension
x_t = x_t.unsqueeze(0)
squeezed = True
if len(prev_chunk_left_over.shape) < 3:
# Add batch dimension
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
if execution_horizon is None:
execution_horizon = self.rtc_config.execution_horizon
# If the previous action chunk is to short then it doesn't make sense to use long execution horizon
# because there is nothing to merge
if execution_horizon > prev_chunk_left_over.shape[1]:
execution_horizon = prev_chunk_left_over.shape[1]
batch_size = x_t.shape[0]
action_chunk_size = x_t.shape[1]
action_dim = x_t.shape[2]
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
prev_chunk_left_over = padded
assert prev_chunk_left_over.shape == x_t.shape, (
"The padded previous chunk must be the same size as the input tensor"
)
weights = (
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
.to(x_t.device)
.unsqueeze(0)
.unsqueeze(-1)
)
with torch.enable_grad():
v_t = original_denoise_step_partial(x_t)
x_t.requires_grad_(True)
x1_t = x_t - time * v_t # noqa: N806
err = (prev_chunk_left_over - x1_t) * weights
grad_outputs = err.clone().detach()
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
result = v_t - guidance_weight * correction
# Remove the batch dimension if it was added
if squeezed:
result = result.squeeze(0)
correction = correction.squeeze(0)
x1_t = x1_t.squeeze(0)
err = err.squeeze(0)
self.track(
time=time,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
return result
def get_prefix_weights(self, start, end, total):
start = min(start, end)
if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS:
weights = torch.zeros(total)
weights[:start] = 1.0
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES:
weights = torch.ones(total)
weights[end:] = 0.0
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR:
lin_weights = self._linweights(start, end, total)
weights = self._add_trailing_zeros(lin_weights, total, end)
weights = self._add_leading_ones(weights, start, total)
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP:
lin_weights = self._linweights(start, end, total)
lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1)
weights = self._add_trailing_zeros(lin_weights, total, end)
weights = self._add_leading_ones(weights, start, total)
return weights
def _linweights(self, start, end, total):
skip_steps_at_end = max(total - end, 0)
linspace_steps = total - skip_steps_at_end - start
if end <= start or linspace_steps <= 0:
return torch.tensor([])
return torch.linspace(1, 0, linspace_steps + 2)[1:-1]
def _add_trailing_zeros(self, weights, total, end):
zeros_len = total - end
if zeros_len <= 0:
return weights
zeros = torch.zeros(zeros_len)
return torch.cat([weights, zeros])
def _add_leading_ones(self, weights, start, total):
ones_len = min(start, total)
if ones_len <= 0:
return weights
ones = torch.ones(ones_len)
return torch.cat([ones, weights])
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
def __post_init__(self):
super().__post_init__()
+101 -19
View File
@@ -54,12 +54,15 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
import math
from collections import deque
from typing import TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -69,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
from lerobot.utils.utils import get_safe_dtype
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
@@ -232,8 +241,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
super().__init__(config)
config.validate_features()
self.config = config
self.model = VLAFlowMatching(config)
self.init_rtc_processor()
self.model = VLAFlowMatching(config, rtc_processor=self.rtc_processor)
self.reset()
def reset(self):
@@ -242,10 +251,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Lets create processor if the config provided
# If RTC is not enabled - we still can track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
# In case of calling init_rtc_processor after the model is created
# We need to set the rtc_processor to the model
# During the normal initialization process the model is not created yet
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def _get_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
# TODO: Check if this for loop is needed.
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
# In the case of offline inference, we have the action in the batch
@@ -260,7 +287,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise, **kwargs
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
@@ -278,30 +307,37 @@ class SmolVLAPolicy(PreTrainedPolicy):
return batch
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def predict_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
actions = self._get_action_chunk(batch, noise)
actions = self._get_action_chunk(batch, noise, **kwargs)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def select_action(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._queues[ACTION]) == 0:
if self._check_get_actions_condition():
actions = self._get_action_chunk(batch, noise)
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -310,6 +346,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
return self._queues[ACTION].popleft()
def _check_get_actions_condition(self) -> bool:
return len(self._queues[ACTION]) == 0
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
@@ -471,7 +513,7 @@ class VLAFlowMatching(nn.Module):
"""
def __init__(self, config: SmolVLAConfig):
def __init__(self, config: SmolVLAConfig, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
@@ -485,7 +527,6 @@ class VLAFlowMatching(nn.Module):
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
device=self.config.device,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
@@ -510,6 +551,10 @@ class VLAFlowMatching(nn.Module):
self.add_image_special_tokens = self.config.add_image_special_tokens
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
@@ -706,7 +751,16 @@ class VLAFlowMatching(nn.Module):
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
def sample_actions(
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
@@ -734,17 +788,45 @@ class VLAFlowMatching(nn.Module):
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(