mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
[RTC] Real Time Chunking for Pi0, Smolvla, Pi0.5 (#1698)
* Add Real-Time Chunking (RTC) support for flow matching models Implement Real-Time Chunking (RTC) for action chunking policies using flow matching denoising. RTC enables smooth action transitions between consecutive chunks by using prefix guidance during denoising. Key features: - RTCProcessor class with denoise_step method for RTC guidance - Tracker system for debug tracking using time-based dictionary storage - RTCDebugVisualizer with comprehensive visualization utilities - Integration with SmolVLA policy for flow matching models - Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP) - Configurable execution horizon and max guidance weight - Example scripts for dataset evaluation and real-time control Technical details: - Uses autograd-based gradient computation for RTC corrections - Time-based tracking eliminates duplicate step issues - Proxy methods in RTCProcessor for cleaner API - Full integration with LeRobot's policy and dataset systems Files added/modified: - src/lerobot/configs/types.py: Add RTCAttentionSchedule enum - src/lerobot/policies/rtc/: Core RTC implementation - configuration_rtc.py: RTC configuration - modeling_rtc.py: RTCProcessor with denoise_step - debug_handler.py: Tracker for debug information - debug_visualizer.py: Visualization utilities - src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration - examples/rtc/: Example scripts and evaluation tools 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * Fix rtc_config attribute access in SmolVLA Use getattr() to safely check for rtc_config attribute existence instead of direct attribute access. This fixes AttributeError when loading policies without rtc_config in their config. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * fixup! Fix rtc_config attribute access in SmolVLA * Add RTCConfig field to SmolVLAConfig Add rtc_config as an optional field in SmolVLAConfig to properly support Real-Time Chunking configuration. This replaces the previous getattr() workarounds with direct attribute access, making the code cleaner and more maintainable. Changes: - Import RTCConfig in configuration_smolvla.py - Add rtc_config: RTCConfig | None = None field - Revert getattr() calls to direct attribute access in modeling_smolvla.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * Refactor RTC enabled checks to use _rtc_enabled helper Add _rtc_enabled() helper method in VLAFlowMatching class to simplify and clean up RTC enabled checks throughout the code. This reduces code duplication and improves readability. Changes: - Add _rtc_enabled() method in VLAFlowMatching - Replace verbose rtc_config checks with _rtc_enabled() calls - Maintain exact same functionality with cleaner code 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * Rename track_debug method to track Simplify the method name from track_debug to just track for better readability and consistency. The method already has clear documentation about its debug tracking purpose. Changes: - Rename RTCProcessor.track_debug() to track() - Update all call sites in modeling_smolvla.py and modeling_rtc.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * Use output_dir for saving all evaluation images Update eval_dataset.py to save all comparison images to the configured output_dir instead of the current directory. This provides better organization and allows users to specify where outputs should be saved. Changes: - Add os import at top level - Create output_dir at start of run_evaluation() - Save all comparison images to output_dir - Remove duplicate os imports - Update init_rtc_processor() docstring to be more concise 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * fixup! Use output_dir for saving all evaluation images * Fix logging buffering and enable tracking when RTC config provided - Add force=True to logging.basicConfig to override existing configuration - Enable line buffering for stdout/stderr for real-time log output - Modify init_rtc_processor to create processor when rtc_config exists even if RTC is disabled, allowing tracking of denoising data 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> * Refactor SmolVLA plotting to use tracker data instead of local variables Remove local tracking variables (correction, x1_t, error) from the denoising loop and instead retrieve plotting data from the RTC tracker after each denoise step. This makes the code cleaner and uses the tracker as the single source of truth for debug/visualization data. Changes: - Remove initialization of correction, x1_t, error before denoising loop - After each Euler step, retrieve most recent debug step from tracker - Extract correction, x1_t, err from debug step for plotting - Update tracking condition to use is_debug_enabled() method 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> * Move plotting logic from modeling_smolvla to eval_dataset script Refactor to improve separation of concerns: modeling_smolvla.py changes: - Remove all plotting logic from sample_actions method - Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters - Remove matplotlib and RTCDebugVisualizer imports - Remove viz_fig, viz_axs, denoise_step_counter instance variables - Simplify denoising loop to only track data in rtc_processor eval_dataset.py changes: - Add _plot_denoising_steps_from_tracker helper method - Retrieve debug steps from tracker after inference - Plot x_t, v_t, x1_t, correction, and error from tracker data - Enable debug tracking (cfg.rtc.debug = True) for visualization - Remove viz axes parameters from predict_action_chunk calls modeling_rtc.py changes: - Remove v_t from track() call (handled by user change) Benefits: - Cleaner modeling code focused on inference - Evaluation script owns all visualization logic - Better separation of concerns - Tracker is single source of truth for debug data 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> * Refactor plotting loging * fixup! Refactor plotting loging * Improve visualization: separate correction plot and fix axis scaling Changes: - Create separate figure for correction data instead of overlaying on v_t - Add _rescale_axes helper method to properly scale all axes - Add 10% margin to y-axis for better visualization - Fix v_t chart vertical compression issue Benefits: - Clearer v_t plot without correction overlay - Better axis scaling with proper margins - Separate correction figure for focused analysis - Improved readability of all denoising visualizations Output files: - denoising_xt_comparison.png (x_t trajectories) - denoising_vt_comparison.png (v_t velocity - now cleaner) - denoising_correction_comparison.png (NEW - separate corrections) - denoising_x1t_comparison.png (x1_t state with error) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> * fixup! Improve visualization: separate correction plot and fix axis scaling * fixup! fixup! Improve visualization: separate correction plot and fix axis scaling * fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling * Fix traacking * Right kwargs for the policy * Add tests for tracker * Fix tests * Drop not required methods * Add torch compilation for eval_dataset * delete policies * Add matplotliv to dev * fixup! Add matplotliv to dev * Experiemnt with late detach * Debug * Fix compilation * Add RTC to PI0 * Pi0 * Pi0 eval dataset * fixup! Pi0 eval dataset * Turn off compilation for pi0/pi05 * fixup! Turn off compilation for pi0/pi05 * fixup! fixup! Turn off compilation for pi0/pi05 * fixup! fixup! fixup! Turn off compilation for pi0/pi05 * fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 * fixup! fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 * Add workable flow * Small fixes * Add more tests * Add validatio at the end * Update README * Silent validation * Fix tests * Add tests for modeling_rtc * Add tests for flow matching models with RTC * fixup! Add tests for flow matching models with RTC * fixup! fixup! Add tests for flow matching models with RTC * Add one more test * fixup! Add one more test * Fix test to use _rtc_enabled() instead of is_rtc_enabled() 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() * fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() * Add RTC initialization tests without config for PI0.5 and SmolVLA Add test_pi05_rtc_initialization_without_rtc_config and test_smolvla_rtc_initialization_without_rtc_config to verify that policies can initialize without RTC config and that _rtc_enabled() returns False in this case. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix PI0.5 init_rtc_processor to use getattr instead of direct model access 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix SmolVLA init_rtc_processor to use getattr instead of direct model access 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization * Fixup eval with real robot * fixup! Fixup eval with real robot * fixup! fixup! Fixup eval with real robot * Extract simulator logic from eval_with real robot and add proper headers to files * Update images * Fix tests * fixup! Fix tests * add docs for rtc * enhance doc and add images * Fix instal instructions --------- Co-authored-by: Ben Zhang <benzhangniu@gmail.com> Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -43,3 +43,10 @@ class NormalizationMode(str, Enum):
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple[int, ...]
|
||||
|
||||
|
||||
class RTCAttentionSchedule(str, Enum):
|
||||
ZEROS = "ZEROS"
|
||||
ONES = "ONES"
|
||||
LINEAR = "LINEAR"
|
||||
EXP = "EXP"
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -47,6 +48,9 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
|
||||
@@ -19,11 +19,12 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -42,6 +43,7 @@ else:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
|
||||
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI0 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI0Config):
|
||||
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the core PI0 model
|
||||
self.model = PI0Pytorch(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if config.gradient_checkpointing:
|
||||
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Create processor if config provided
|
||||
# If RTC is not enabled - we can still track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
# Sample actions using the model
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||
# Sample actions using the model (pass through RTC kwargs)
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@@ -46,6 +47,9 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
|
||||
@@ -19,11 +19,12 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -42,6 +43,7 @@ else:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
|
||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI05 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI05Config):
|
||||
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the core PI05 model
|
||||
self.model = PI05Pytorch(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if config.gradient_checkpointing:
|
||||
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Create processor if config provided
|
||||
# If RTC is not enabled - we can still track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Sample actions using the model (no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Real-Time Chunking (RTC)
|
||||
|
||||
This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.
|
||||
|
||||
**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use Real-Time Chunking in your work, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{black2025realtimeexecutionactionchunking,
|
||||
title={Real-Time Execution of Action Chunking Flow Policies},
|
||||
author={Kevin Black and Manuel Y. Galliker and Sergey Levine},
|
||||
year={2025},
|
||||
eprint={2506.07339},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2506.07339},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.
|
||||
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action queue management for Real-Time Chunking (RTC).
|
||||
|
||||
This module provides ActionQueue, a thread-safe queue for managing action chunks
|
||||
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
|
||||
handling action merging and leftover tracking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from threading import Lock
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionQueue:
|
||||
"""Thread-safe queue for managing action chunks in real-time control.
|
||||
|
||||
This queue handles two types of action sequences:
|
||||
- Original actions: Used for RTC to compute leftovers from previous chunks
|
||||
- Processed actions: Post-processed actions ready for robot execution
|
||||
|
||||
The queue operates in two modes:
|
||||
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
|
||||
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
|
||||
|
||||
Args:
|
||||
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
|
||||
|
||||
Attributes:
|
||||
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
|
||||
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
|
||||
last_index (int): Current consumption index in the queue.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: RTCConfig):
|
||||
"""Initialize the action queue.
|
||||
|
||||
Args:
|
||||
cfg: RTC configuration controlling queue behavior.
|
||||
"""
|
||||
self.queue = None # Processed actions for robot rollout
|
||||
self.original_queue = None # Original actions for RTC
|
||||
self.lock = Lock()
|
||||
self.last_index = 0
|
||||
self.cfg = cfg
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next action from the queue.
|
||||
|
||||
Returns:
|
||||
Tensor | None: The next action (action_dim,) or None if queue is empty.
|
||||
Returns a clone to prevent external modifications.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.queue is None or self.last_index >= len(self.queue):
|
||||
return None
|
||||
|
||||
action = self.queue[self.last_index]
|
||||
self.last_index += 1
|
||||
return action.clone()
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Get the number of remaining actions in the queue.
|
||||
|
||||
Returns:
|
||||
int: Number of unconsumed actions.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return 0
|
||||
length = len(self.queue)
|
||||
return length - self.last_index
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
bool: True if no actions remain, False otherwise.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return True
|
||||
|
||||
length = len(self.queue)
|
||||
return length - self.last_index <= 0
|
||||
|
||||
def get_action_index(self) -> int:
|
||||
"""Get the current action consumption index.
|
||||
|
||||
Returns:
|
||||
int: Index of the next action to be consumed.
|
||||
"""
|
||||
return self.last_index
|
||||
|
||||
def get_left_over(self) -> Tensor | None:
|
||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||
|
||||
These are the unconsumed actions from the current chunk, which will be
|
||||
used by RTC to compute corrections for the next chunk.
|
||||
|
||||
Returns:
|
||||
Tensor | None: Remaining original actions (remaining_steps, action_dim),
|
||||
or None if no original queue exists.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.original_queue is None:
|
||||
return None
|
||||
return self.original_queue[self.last_index :]
|
||||
|
||||
def merge(
|
||||
self,
|
||||
original_actions: Tensor,
|
||||
processed_actions: Tensor,
|
||||
real_delay: int,
|
||||
action_index_before_inference: int | None = 0,
|
||||
):
|
||||
"""Merge new actions into the queue.
|
||||
|
||||
This method operates differently based on RTC mode:
|
||||
- RTC enabled: Replaces the queue, accounting for inference delay
|
||||
- RTC disabled: Appends to the queue, maintaining continuity
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy (time_steps, action_dim).
|
||||
processed_actions: Post-processed actions for robot (time_steps, action_dim).
|
||||
real_delay: Number of time steps of inference delay.
|
||||
action_index_before_inference: Index before inference started, for validation.
|
||||
"""
|
||||
with self.lock:
|
||||
self._check_delays(real_delay, action_index_before_inference)
|
||||
|
||||
if self.cfg.enabled:
|
||||
self._replace_actions_queue(original_actions, processed_actions, real_delay)
|
||||
return
|
||||
|
||||
self._append_actions_queue(original_actions, processed_actions)
|
||||
|
||||
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
|
||||
"""Replace the queue with new actions (RTC mode).
|
||||
|
||||
Discards the first `real_delay` actions since they correspond to the time
|
||||
spent during inference, when the robot was executing previous actions.
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy.
|
||||
processed_actions: Post-processed actions for robot.
|
||||
real_delay: Number of time steps to skip due to inference delay.
|
||||
"""
|
||||
self.original_queue = original_actions[real_delay:].clone()
|
||||
self.queue = processed_actions[real_delay:].clone()
|
||||
|
||||
logger.debug(f"original_actions shape: {self.original_queue.shape}")
|
||||
logger.debug(f"processed_actions shape: {self.queue.shape}")
|
||||
logger.debug(f"real_delay: {real_delay}")
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
|
||||
"""Append new actions to the queue (non-RTC mode).
|
||||
|
||||
Removes already-consumed actions and appends new ones, maintaining
|
||||
queue continuity without replacement.
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy.
|
||||
processed_actions: Post-processed actions for robot.
|
||||
"""
|
||||
if self.queue is None:
|
||||
self.original_queue = original_actions.clone()
|
||||
self.queue = processed_actions.clone()
|
||||
return
|
||||
|
||||
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
|
||||
self.original_queue = self.original_queue[self.last_index :]
|
||||
|
||||
self.queue = torch.cat([self.queue, processed_actions.clone()])
|
||||
self.queue = self.queue[self.last_index :]
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
|
||||
"""Validate that computed delays match expectations.
|
||||
|
||||
Compares the delay computed from inference latency with the actual
|
||||
number of actions consumed during inference.
|
||||
|
||||
Args:
|
||||
real_delay: Delay computed from inference latency.
|
||||
action_index_before_inference: Action index when inference started.
|
||||
"""
|
||||
if action_index_before_inference is None:
|
||||
return
|
||||
|
||||
indexes_diff = self.last_index - action_index_before_inference
|
||||
if indexes_diff != real_delay:
|
||||
# Let's check that action index difference (real delay calculated based on action queue)
|
||||
# is the same as delay calculated based on inference latency
|
||||
logger.warning(
|
||||
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
|
||||
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
|
||||
)
|
||||
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
|
||||
|
||||
Based on:
|
||||
- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCConfig:
|
||||
"""Configuration for Real Time Chunking (RTC) inference.
|
||||
|
||||
RTC improves real-time inference by treating chunk generation as an inpainting problem,
|
||||
strategically handling overlapping timesteps between action chunks using prefix attention.
|
||||
"""
|
||||
|
||||
# Infrastructure
|
||||
enabled: bool = False
|
||||
|
||||
# Core RTC settings
|
||||
# Todo change to exp
|
||||
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
|
||||
max_guidance_weight: float = 10.0
|
||||
execution_horizon: int = 10
|
||||
|
||||
# Debug settings
|
||||
debug: bool = False
|
||||
debug_maxlen: int = 100
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate RTC configuration parameters."""
|
||||
if self.max_guidance_weight <= 0:
|
||||
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
|
||||
if self.debug_maxlen <= 0:
|
||||
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
|
||||
@@ -0,0 +1,233 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Debug information handler for Real-Time Chunking (RTC)."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugStep:
|
||||
"""Container for debug information from a single denoising step.
|
||||
|
||||
Attributes:
|
||||
step_idx (int): Step index/counter.
|
||||
x_t (Tensor | None): Current latent/state tensor.
|
||||
v_t (Tensor | None): Velocity from denoiser.
|
||||
x1_t (Tensor | None): Denoised prediction (x_t - time * v_t).
|
||||
correction (Tensor | None): Correction gradient tensor.
|
||||
err (Tensor | None): Weighted error term.
|
||||
weights (Tensor | None): Prefix attention weights.
|
||||
guidance_weight (float | Tensor | None): Applied guidance weight.
|
||||
time (float | Tensor | None): Time parameter.
|
||||
inference_delay (int | None): Inference delay parameter.
|
||||
execution_horizon (int | None): Execution horizon parameter.
|
||||
metadata (dict[str, Any]): Additional metadata.
|
||||
"""
|
||||
|
||||
step_idx: int = 0
|
||||
x_t: Tensor | None = None
|
||||
v_t: Tensor | None = None
|
||||
x1_t: Tensor | None = None
|
||||
correction: Tensor | None = None
|
||||
err: Tensor | None = None
|
||||
weights: Tensor | None = None
|
||||
guidance_weight: float | Tensor | None = None
|
||||
time: float | Tensor | None = None
|
||||
inference_delay: int | None = None
|
||||
execution_horizon: int | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
|
||||
"""Convert debug step to dictionary.
|
||||
|
||||
Args:
|
||||
include_tensors (bool): If True, include tensor values. If False, only include
|
||||
tensor statistics (shape, mean, std, min, max).
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the debug step.
|
||||
"""
|
||||
result = {
|
||||
"step_idx": self.step_idx,
|
||||
"guidance_weight": (
|
||||
self.guidance_weight.item()
|
||||
if isinstance(self.guidance_weight, Tensor)
|
||||
else self.guidance_weight
|
||||
),
|
||||
"time": self.time.item() if isinstance(self.time, Tensor) else self.time,
|
||||
"inference_delay": self.inference_delay,
|
||||
"execution_horizon": self.execution_horizon,
|
||||
"metadata": self.metadata.copy(),
|
||||
}
|
||||
|
||||
# Add tensor information
|
||||
tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"]
|
||||
for field_name in tensor_fields:
|
||||
tensor = getattr(self, field_name)
|
||||
if tensor is not None:
|
||||
if include_tensors:
|
||||
result[field_name] = tensor.detach().cpu()
|
||||
else:
|
||||
result[f"{field_name}_stats"] = {
|
||||
"shape": tuple(tensor.shape),
|
||||
"mean": tensor.mean().item(),
|
||||
"std": tensor.std().item(),
|
||||
"min": tensor.min().item(),
|
||||
"max": tensor.max().item(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""Collects and manages debug information for RTC processing.
|
||||
|
||||
This tracker stores debug information from recent denoising steps in a dictionary,
|
||||
using time as the key for efficient lookups and updates.
|
||||
|
||||
Args:
|
||||
enabled (bool): Whether debug collection is enabled.
|
||||
maxlen (int | None): Optional sliding window size. If provided, only the
|
||||
most recent ``maxlen`` debug steps are kept. If ``None``, keeps all.
|
||||
"""
|
||||
|
||||
def __init__(self, enabled: bool = False, maxlen: int = 100):
|
||||
self.enabled = enabled
|
||||
self._steps = {} if enabled else None # Dictionary with time as key
|
||||
self._maxlen = maxlen
|
||||
self._step_counter = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all recorded debug information."""
|
||||
if self.enabled and self._steps is not None:
|
||||
self._steps.clear()
|
||||
self._step_counter = 0
|
||||
|
||||
@torch._dynamo.disable
|
||||
def track(
|
||||
self,
|
||||
time: float | Tensor,
|
||||
x_t: Tensor | None = None,
|
||||
v_t: Tensor | None = None,
|
||||
x1_t: Tensor | None = None,
|
||||
correction: Tensor | None = None,
|
||||
err: Tensor | None = None,
|
||||
weights: Tensor | None = None,
|
||||
guidance_weight: float | Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
**metadata,
|
||||
) -> None:
|
||||
"""Track debug information for a denoising step at a given time.
|
||||
|
||||
If a step with the given time already exists, it will be updated with the new data.
|
||||
Otherwise, a new step will be created. Only non-None fields are updated/set.
|
||||
|
||||
Note: This method is excluded from torch.compile to avoid graph breaks from
|
||||
operations like .item() which are incompatible with compiled graphs.
|
||||
|
||||
Args:
|
||||
time (float | Tensor): Time parameter - used as the key to identify the step.
|
||||
x_t (Tensor | None): Current latent/state tensor.
|
||||
v_t (Tensor | None): Velocity from denoiser.
|
||||
x1_t (Tensor | None): Denoised prediction.
|
||||
correction (Tensor | None): Correction gradient tensor.
|
||||
err (Tensor | None): Weighted error term.
|
||||
weights (Tensor | None): Prefix attention weights.
|
||||
guidance_weight (float | Tensor | None): Applied guidance weight.
|
||||
inference_delay (int | None): Inference delay parameter.
|
||||
execution_horizon (int | None): Execution horizon parameter.
|
||||
**metadata: Additional metadata to store.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
# Convert time to float and round to avoid float precision issues
|
||||
time_value = time.item() if isinstance(time, Tensor) else time
|
||||
time_key = round(time_value, 6) # Use rounded time as dictionary key
|
||||
|
||||
# Check if step with this time already exists
|
||||
if time_key in self._steps:
|
||||
# Update existing step with non-None fields
|
||||
existing_step = self._steps[time_key]
|
||||
if x_t is not None:
|
||||
existing_step.x_t = x_t.detach().clone()
|
||||
if v_t is not None:
|
||||
existing_step.v_t = v_t.detach().clone()
|
||||
if x1_t is not None:
|
||||
existing_step.x1_t = x1_t.detach().clone()
|
||||
if correction is not None:
|
||||
existing_step.correction = correction.detach().clone()
|
||||
if err is not None:
|
||||
existing_step.err = err.detach().clone()
|
||||
if weights is not None:
|
||||
existing_step.weights = weights.detach().clone()
|
||||
if guidance_weight is not None:
|
||||
existing_step.guidance_weight = guidance_weight
|
||||
if inference_delay is not None:
|
||||
existing_step.inference_delay = inference_delay
|
||||
if execution_horizon is not None:
|
||||
existing_step.execution_horizon = execution_horizon
|
||||
if metadata:
|
||||
existing_step.metadata.update(metadata)
|
||||
else:
|
||||
# Create new step
|
||||
step = DebugStep(
|
||||
step_idx=self._step_counter,
|
||||
x_t=x_t.detach().clone() if x_t is not None else None,
|
||||
v_t=v_t.detach().clone() if v_t is not None else None,
|
||||
x1_t=x1_t.detach().clone() if x1_t is not None else None,
|
||||
correction=correction.detach().clone() if correction is not None else None,
|
||||
err=err.detach().clone() if err is not None else None,
|
||||
weights=weights.detach().clone() if weights is not None else None,
|
||||
guidance_weight=guidance_weight,
|
||||
time=time_value,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Add to dictionary
|
||||
self._steps[time_key] = step
|
||||
self._step_counter += 1
|
||||
|
||||
# Enforce maxlen if set
|
||||
if self._maxlen is not None and len(self._steps) > self._maxlen:
|
||||
# Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order)
|
||||
oldest_key = next(iter(self._steps))
|
||||
del self._steps[oldest_key]
|
||||
|
||||
def get_all_steps(self) -> list[DebugStep]:
|
||||
"""Get all recorded debug steps.
|
||||
|
||||
Returns:
|
||||
List of all DebugStep objects (may be empty if disabled).
|
||||
"""
|
||||
if not self.enabled or self._steps is None:
|
||||
return []
|
||||
|
||||
return list(self._steps.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of recorded debug steps."""
|
||||
if not self.enabled or self._steps is None:
|
||||
return 0
|
||||
return len(self._steps)
|
||||
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Visualization utilities for RTC debug information."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class RTCDebugVisualizer:
|
||||
"""Visualizer for RTC debug information.
|
||||
|
||||
This class provides methods to visualize debug information collected by the Tracker,
|
||||
including corrections, errors, weights, and guidance weights over denoising steps.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def plot_waypoints(
|
||||
axes,
|
||||
tensor,
|
||||
start_from: int = 0,
|
||||
color: str = "blue",
|
||||
label: str = "",
|
||||
alpha: float = 0.7,
|
||||
linewidth: float = 2,
|
||||
marker: str | None = None,
|
||||
markersize: int = 4,
|
||||
):
|
||||
"""Plot trajectories across multiple dimensions.
|
||||
|
||||
This function plots a tensor's values across time for multiple dimensions,
|
||||
with each dimension plotted on a separate axis.
|
||||
|
||||
Args:
|
||||
axes: Array of matplotlib axes (one for each dimension).
|
||||
tensor: The tensor to plot (can be torch.Tensor or numpy array).
|
||||
Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims).
|
||||
start_from: Starting index for the x-axis.
|
||||
color: Color for the plot lines.
|
||||
label: Label for the plot legend.
|
||||
alpha: Transparency level for the plot.
|
||||
linewidth: Width of the plot lines.
|
||||
marker: Marker style for data points (e.g., 'o', 's', '^').
|
||||
markersize: Size of the markers.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Handle None tensor
|
||||
if tensor is None:
|
||||
return
|
||||
|
||||
# Convert tensor to numpy if needed
|
||||
tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
|
||||
|
||||
# Handle different tensor shapes
|
||||
if tensor_np.ndim == 3:
|
||||
# If batch dimension present, take first batch
|
||||
tensor_np = tensor_np[0]
|
||||
elif tensor_np.ndim == 1:
|
||||
# If 1D, reshape to (time_steps, 1)
|
||||
tensor_np = tensor_np.reshape(-1, 1)
|
||||
|
||||
# Get dimensions
|
||||
time_steps, num_dims = tensor_np.shape
|
||||
|
||||
# Create x-axis indices
|
||||
x_indices = np.arange(start_from, start_from + time_steps)
|
||||
|
||||
# Plot each dimension on its corresponding axis
|
||||
num_axes = len(axes) if hasattr(axes, "__len__") else 1
|
||||
for dim_idx in range(min(num_dims, num_axes)):
|
||||
ax = axes[dim_idx] if hasattr(axes, "__len__") else axes
|
||||
|
||||
# Plot the trajectory
|
||||
if marker:
|
||||
ax.plot(
|
||||
x_indices,
|
||||
tensor_np[:, dim_idx],
|
||||
color=color,
|
||||
label=label if dim_idx == 0 else "", # Only show label once
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
marker=marker,
|
||||
markersize=markersize,
|
||||
)
|
||||
else:
|
||||
ax.plot(
|
||||
x_indices,
|
||||
tensor_np[:, dim_idx],
|
||||
color=color,
|
||||
label=label if dim_idx == 0 else "", # Only show label once
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
)
|
||||
|
||||
# Add grid and labels if not already present
|
||||
if not ax.xaxis.get_label().get_text():
|
||||
ax.set_xlabel("Step", fontsize=10)
|
||||
if not ax.yaxis.get_label().get_text():
|
||||
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Latency tracking utilities for Real-Time Chunking (RTC)."""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LatencyTracker:
|
||||
"""Tracks recent latencies and provides max/percentile queries.
|
||||
|
||||
Args:
|
||||
maxlen (int | None): Optional sliding window size. If provided, only the
|
||||
most recent ``maxlen`` latencies are kept. If ``None``, keeps all.
|
||||
"""
|
||||
|
||||
def __init__(self, maxlen: int = 100):
|
||||
self._values = deque(maxlen=maxlen)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all recorded latencies."""
|
||||
self._values.clear()
|
||||
self.max_latency = 0.0
|
||||
|
||||
def add(self, latency: float) -> None:
|
||||
"""Add a latency sample (seconds)."""
|
||||
# Ensure numeric and non-negative
|
||||
val = float(latency)
|
||||
|
||||
if val < 0:
|
||||
return
|
||||
self._values.append(val)
|
||||
self.max_latency = max(self.max_latency, val)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._values)
|
||||
|
||||
def max(self) -> float | None:
|
||||
"""Return the maximum latency or None if empty."""
|
||||
return self.max_latency
|
||||
|
||||
def percentile(self, q: float) -> float | None:
|
||||
"""Return the q-quantile (q in [0,1]) of recorded latencies or None if empty."""
|
||||
if not self._values:
|
||||
return 0.0
|
||||
q = float(q)
|
||||
if q <= 0.0:
|
||||
return min(self._values)
|
||||
if q >= 1.0:
|
||||
return self.max_latency
|
||||
vals = np.array(list(self._values), dtype=np.float32)
|
||||
return float(np.quantile(vals, q))
|
||||
|
||||
def p95(self) -> float | None:
|
||||
"""Return the 95th percentile latency or None if empty."""
|
||||
return self.percentile(0.95)
|
||||
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Real-Time Chunking (RTC) implementation for LeRobot.
|
||||
|
||||
Based on Physical Intelligence's Kinetix implementation:
|
||||
https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_tracker import Tracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RTCProcessor:
|
||||
"""Real-Time Chunking processor for action chunking policies.
|
||||
|
||||
This class implements RTC techniques including velocity calculation,
|
||||
prefix attention, and adaptive chunk processing.
|
||||
"""
|
||||
|
||||
def __init__(self, rtc_config: RTCConfig):
|
||||
self.rtc_config = rtc_config
|
||||
|
||||
self.tracker = None
|
||||
|
||||
if rtc_config.debug:
|
||||
self.tracker = Tracker(
|
||||
enabled=rtc_config.debug,
|
||||
maxlen=rtc_config.debug_maxlen,
|
||||
)
|
||||
|
||||
# ====================== Tracker Proxy Methods ======================
|
||||
def track(
|
||||
self,
|
||||
time: float | Tensor,
|
||||
x_t: Tensor | None = None,
|
||||
v_t: Tensor | None = None,
|
||||
x1_t: Tensor | None = None,
|
||||
correction: Tensor | None = None,
|
||||
err: Tensor | None = None,
|
||||
weights: Tensor | None = None,
|
||||
guidance_weight: float | Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
**metadata,
|
||||
) -> None:
|
||||
"""Proxy method to track debug information.
|
||||
|
||||
If tracker is None or disabled, this method does nothing.
|
||||
Otherwise, it forwards the call to tracker.track().
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
self.tracker.track(
|
||||
time=time,
|
||||
x_t=x_t,
|
||||
v_t=v_t,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
weights=weights,
|
||||
guidance_weight=guidance_weight,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
**metadata,
|
||||
)
|
||||
|
||||
def get_all_debug_steps(self) -> list:
|
||||
"""Get all debug steps from tracker.
|
||||
|
||||
Returns empty list if tracker is disabled or None.
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
return self.tracker.get_all_steps()
|
||||
return []
|
||||
|
||||
def is_debug_enabled(self) -> bool:
|
||||
"""Check if debug tracking is enabled.
|
||||
|
||||
Returns True if tracker exists and is enabled.
|
||||
"""
|
||||
return self.tracker is not None and self.tracker.enabled
|
||||
|
||||
def reset_tracker(self) -> None:
|
||||
"""Reset the tracker, clearing all recorded steps.
|
||||
|
||||
Does nothing if tracker is None.
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
self.tracker.reset()
|
||||
|
||||
# ====================== End Tracker Proxy Methods ======================
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
x_t,
|
||||
prev_chunk_left_over,
|
||||
inference_delay,
|
||||
time,
|
||||
original_denoise_step_partial,
|
||||
execution_horizon=None,
|
||||
) -> Tensor:
|
||||
"""RTC guidance wrapper around an existing denoiser.
|
||||
|
||||
This method wraps an original denoising callable that only takes ``x_t`` and
|
||||
returns a base denoised velocity ``v_t``. It then applies Real-Time Chunking
|
||||
(RTC) prefix guidance using the leftover prefix from the previous chunk.
|
||||
|
||||
Args:
|
||||
x_t (Tensor): Current latent/state to denoise. Shape ``(B, T, A)`` or ``(T, A)``.
|
||||
prev_chunk_left_over (Tensor | None): Unexecuted prefix from the previous
|
||||
chunk. Shape ``(B, T_prev, A)`` or ``(T_prev, A)``. If ``None``, no guidance
|
||||
is applied and the method returns ``v_t`` from the original denoiser.
|
||||
inference_delay (int): Number of timesteps from the prefix to use for guidance.
|
||||
time (float | Tensor): Scalar in [0, 1] indicating normalized time. Must be
|
||||
broadcastable with ``x_t``.
|
||||
original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
|
||||
computes the base denoised velocity given only ``x_t``.
|
||||
execution_horizon (int | None): Horizon used to build prefix weights. If
|
||||
``None``, defaults to ``self.rtc_config.execution_horizon``.
|
||||
|
||||
Returns:
|
||||
Tensor: Guided velocity with the same shape as ``v_t``.
|
||||
|
||||
Notes:
|
||||
- If inputs are 2D, a batch dimension is temporarily added and removed at the end.
|
||||
- If ``prev_chunk_left_over`` is shorter than the current chunk length ``T``, it is
|
||||
right-padded with zeros to match ``T``.
|
||||
- Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)``
|
||||
and broadcast to ``(B, T, A)``.
|
||||
- Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and
|
||||
``error = (prev_chunk_left_over - x1_t) * weights``.
|
||||
- The final guidance weight is clamped by ``max_guidance_weight`` from the config.
|
||||
|
||||
Reference:
|
||||
https://www.physicalintelligence.company/download/real_time_chunking.pdf
|
||||
"""
|
||||
|
||||
# In the original implementation, the time goes from 0 to 1 and
|
||||
# In our implementation, the time goes from 1 to 0
|
||||
# So we need to invert the time
|
||||
tau = 1 - time
|
||||
|
||||
if prev_chunk_left_over is None:
|
||||
# First step, no guidance - return v_t
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
return v_t
|
||||
|
||||
x_t = x_t.clone().detach()
|
||||
|
||||
squeezed = False
|
||||
if len(x_t.shape) < 3:
|
||||
# Add batch dimension
|
||||
x_t = x_t.unsqueeze(0)
|
||||
squeezed = True
|
||||
|
||||
if len(prev_chunk_left_over.shape) < 3:
|
||||
# Add batch dimension
|
||||
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
|
||||
|
||||
if execution_horizon is None:
|
||||
execution_horizon = self.rtc_config.execution_horizon
|
||||
|
||||
# If the previous action chunk is to short then it doesn't make sense to use long execution horizon
|
||||
# because there is nothing to merge
|
||||
if execution_horizon > prev_chunk_left_over.shape[1]:
|
||||
execution_horizon = prev_chunk_left_over.shape[1]
|
||||
|
||||
batch_size = x_t.shape[0]
|
||||
action_chunk_size = x_t.shape[1]
|
||||
action_dim = x_t.shape[2]
|
||||
|
||||
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
|
||||
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
|
||||
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
|
||||
prev_chunk_left_over = padded
|
||||
|
||||
assert prev_chunk_left_over.shape == x_t.shape, (
|
||||
"The padded previous chunk must be the same size as the input tensor"
|
||||
)
|
||||
|
||||
weights = (
|
||||
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
|
||||
.to(x_t.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
|
||||
with torch.enable_grad():
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
x_t.requires_grad_(True)
|
||||
|
||||
x1_t = x_t - time * v_t # noqa: N806
|
||||
err = (prev_chunk_left_over - x1_t) * weights
|
||||
grad_outputs = err.clone().detach()
|
||||
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
||||
|
||||
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
||||
tau_tensor = torch.as_tensor(tau)
|
||||
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
||||
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
|
||||
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
|
||||
|
||||
result = v_t - guidance_weight * correction
|
||||
|
||||
# Remove the batch dimension if it was added
|
||||
if squeezed:
|
||||
result = result.squeeze(0)
|
||||
correction = correction.squeeze(0)
|
||||
x1_t = x1_t.squeeze(0)
|
||||
err = err.squeeze(0)
|
||||
|
||||
self.track(
|
||||
time=time,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
weights=weights,
|
||||
guidance_weight=guidance_weight,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_prefix_weights(self, start, end, total):
|
||||
start = min(start, end)
|
||||
|
||||
if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS:
|
||||
weights = torch.zeros(total)
|
||||
weights[:start] = 1.0
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES:
|
||||
weights = torch.ones(total)
|
||||
weights[end:] = 0.0
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR:
|
||||
lin_weights = self._linweights(start, end, total)
|
||||
weights = self._add_trailing_zeros(lin_weights, total, end)
|
||||
weights = self._add_leading_ones(weights, start, total)
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP:
|
||||
lin_weights = self._linweights(start, end, total)
|
||||
lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1)
|
||||
weights = self._add_trailing_zeros(lin_weights, total, end)
|
||||
weights = self._add_leading_ones(weights, start, total)
|
||||
|
||||
return weights
|
||||
|
||||
def _linweights(self, start, end, total):
|
||||
skip_steps_at_end = max(total - end, 0)
|
||||
|
||||
linspace_steps = total - skip_steps_at_end - start
|
||||
|
||||
if end <= start or linspace_steps <= 0:
|
||||
return torch.tensor([])
|
||||
|
||||
return torch.linspace(1, 0, linspace_steps + 2)[1:-1]
|
||||
|
||||
def _add_trailing_zeros(self, weights, total, end):
|
||||
zeros_len = total - end
|
||||
|
||||
if zeros_len <= 0:
|
||||
return weights
|
||||
|
||||
zeros = torch.zeros(zeros_len)
|
||||
return torch.cat([weights, zeros])
|
||||
|
||||
def _add_leading_ones(self, weights, start, total):
|
||||
ones_len = min(start, total)
|
||||
|
||||
if ones_len <= 0:
|
||||
return weights
|
||||
|
||||
ones = torch.ones(ones_len)
|
||||
return torch.cat([ones, weights])
|
||||
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@@ -54,12 +54,15 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
@@ -69,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
@@ -232,8 +241,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = VLAFlowMatching(config, rtc_processor=self.rtc_processor)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
@@ -242,10 +251,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Lets create processor if the config provided
|
||||
# If RTC is not enabled - we still can track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
# In case of calling init_rtc_processor after the model is created
|
||||
# We need to set the rtc_processor to the model
|
||||
# During the normal initialization process the model is not created yet
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def _get_action_chunk(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
# TODO: Check if this for loop is needed.
|
||||
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
|
||||
# In the case of offline inference, we have the action in the batch
|
||||
@@ -260,7 +287,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise, **kwargs
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
@@ -278,30 +307,37 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def predict_action_chunk(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
batch = self._prepare_batch(batch)
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
actions = self._get_action_chunk(batch, noise)
|
||||
actions = self._get_action_chunk(batch, noise, **kwargs)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def select_action(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
batch = self._prepare_batch(batch)
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
if self._check_get_actions_condition():
|
||||
actions = self._get_action_chunk(batch, noise)
|
||||
|
||||
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
@@ -310,6 +346,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
def _check_get_actions_condition(self) -> bool:
|
||||
return len(self._queues[ACTION]) == 0
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
@@ -471,7 +513,7 @@ class VLAFlowMatching(nn.Module):
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config: SmolVLAConfig):
|
||||
def __init__(self, config: SmolVLAConfig, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -485,7 +527,6 @@ class VLAFlowMatching(nn.Module):
|
||||
num_vlm_layers=self.config.num_vlm_layers,
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
device=self.config.device,
|
||||
)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
@@ -510,6 +551,10 @@ class VLAFlowMatching(nn.Module):
|
||||
self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
@@ -706,7 +751,16 @@ class VLAFlowMatching(nn.Module):
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
noise=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
@@ -734,17 +788,45 @@ class VLAFlowMatching(nn.Module):
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
|
||||
Reference in New Issue
Block a user