diff --git a/examples/rtc/README.md b/examples/rtc/README.md new file mode 100644 index 000000000..988f997a9 --- /dev/null +++ b/examples/rtc/README.md @@ -0,0 +1,281 @@ +# Real-Time Chunking (RTC) Examples + +This directory contains examples and evaluation scripts for Real-Time Chunking (RTC), a technique for improving action chunking policies in real-time robot control. + +## Overview + +Real-Time Chunking addresses the challenge of maintaining consistency and reactivity when using action chunking policies with non-negligible inference latency. It uses a guidance technique during diffusion sampling to blend new action predictions with previously planned actions. + +**Key Benefits:** + +- Maintains consistency between consecutive action chunks +- Reduces jitter and improves smoothness +- Adapts to inference delays dynamically + +**Reference:** [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/download/real_time_chunking.pdf) + +## Scripts + +### 1. `real_time_chunking_evaluate.py` + +Real-time evaluation on physical robots or simulation environments. + +**Features:** + +- Run policy with RTC on real robot or simulation +- Compare RTC vs non-RTC actions in real-time +- Multi-threaded action execution and inference +- Support for torch.compile() optimization + +**Usage:** + +```bash +# With real robot +uv run python examples/rtc/real_time_chunking_evaluate.py \ + --policy.path=lerobot/smolvla_base \ + --robot.type=so100 \ + --task="pick up the cup" + +# With simulation environment +uv run python examples/rtc/real_time_chunking_evaluate.py \ + --policy.path=lerobot/smolvla_base \ + --env.type=pusht \ + --duration=60.0 + +# Disable verbose comparison (faster) +uv run python examples/rtc/real_time_chunking_evaluate.py \ + --policy.path=lerobot/smolvla_base \ + --robot.type=so100 \ + --verbose_rtc_comparison=false + +# With policy compilation (CUDA only, not MPS) +uv run python examples/rtc/real_time_chunking_evaluate.py \ + --policy.path=lerobot/smolvla_base \ + --robot.type=so100 \ + --compile_policy=true \ + --compile_mode=max-autotune +``` + +**Key Parameters:** + +- `--policy.path`: Path to pretrained policy +- `--robot.type` or `--env.type`: Robot or environment to use +- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10) +- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0) +- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP) +- `--verbose_rtc_comparison`: Enable detailed RTC comparison logging (default: true) +- `--duration`: How long to run (seconds, default: 30.0) +- `--fps`: Action execution frequency (Hz, default: 10.0) + +### 2. `evaluate_rtc_on_dataset.py` + +Offline evaluation on dataset samples to measure RTC effectiveness. 
+ +**Features:** + +- Evaluate RTC on dataset without running robot +- Compare RTC vs non-RTC predictions +- Measure consistency and ground truth alignment +- Simulate different inference delays +- Save detailed metrics to JSON + +**Usage:** + +```bash +# Basic evaluation +uv run python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path=lerobot/smolvla_base \ + --dataset.repo_id=lerobot/pusht \ + --num_iterations=100 + +# Simulate inference delay (every 3rd step) +uv run python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path=lerobot/smolvla_base \ + --dataset.repo_id=lerobot/pusht \ + --num_iterations=200 \ + --skip_steps=3 + +# Custom RTC configuration +uv run python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path=lerobot/smolvla_base \ + --dataset.repo_id=lerobot/pusht \ + --num_iterations=100 \ + --rtc.execution_horizon=12 \ + --rtc.max_guidance_weight=5.0 \ + --rtc.prefix_attention_schedule=LINEAR + +# Save results to file +uv run python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path=lerobot/smolvla_base \ + --dataset.repo_id=lerobot/pusht \ + --num_iterations=100 \ + --output_path=results/rtc_evaluation.json + +# Verbose mode with detailed logging +uv run python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path=lerobot/smolvla_base \ + --dataset.repo_id=lerobot/pusht \ + --num_iterations=50 \ + --verbose=true +``` + +**Key Parameters:** + +- `--policy.path`: Path to pretrained policy +- `--dataset.repo_id`: Dataset to evaluate on +- `--num_iterations`: Number of samples to evaluate (default: 100) +- `--skip_steps`: Steps to skip between inferences, simulates inference delay (default: 1) +- `--start_episode`: Episode to start from (default: 0) +- `--output_path`: Path to save results JSON +- `--verbose`: Enable detailed per-sample logging +- `--device`: Device to use (cuda, cpu, mps, auto) + +**Metrics Reported:** + +- **RTC vs Ground Truth MSE**: How close RTC predictions are to actual actions +- **No-RTC vs Ground Truth MSE**: Baseline without RTC +- **RTC Improvement**: Absolute and relative improvement over baseline +- **RTC Consistency**: How well RTC maintains consistency in prefix region + - Prefix MSE + - Mean/Max error in overlap region + +### 3. `run_dataset_evaluation.sh` + +Convenience script with multiple evaluation scenarios. + +**Usage:** + +```bash +# Edit the script to set your policy and dataset +# Then run all examples: +./examples/rtc/run_dataset_evaluation.sh + +# Or run individual examples from the script +``` + +## Understanding RTC Parameters + +### `execution_horizon` + +Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity. + +**Typical values:** 8-12 steps + +### `max_guidance_weight` + +Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions. + +**Typical values:** 1.0-10.0 + +### `prefix_attention_schedule` + +How to weight consistency across the overlap region: + +- `ZEROS`: Binary (full weight up to inference_delay, then zero) +- `ONES`: Full weight across entire execution_horizon +- `LINEAR`: Linear decay from inference_delay to execution_horizon +- `EXP`: Exponential decay (recommended) + +**Recommended:** `EXP` + +### `skip_steps` (evaluation only) + +Simulates inference delay by evaluating every N-th step. This helps understand how RTC performs with realistic delays. 
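+
+For intuition, the sketch below shows the timing arithmetic this simulates. It is a hedged illustration only: the variable names are made up for the example and are not flags or internals of the evaluation script.
+
+```python
+# Illustrative only: how skip_steps relates to the inference latency it simulates.
+fps = 10.0      # action execution frequency (Hz), as in the commands above
+skip_steps = 3  # the policy is queried once per 3 executed actions
+
+simulated_latency_s = skip_steps / fps  # wall-clock latency being simulated
+inference_delay = skip_steps            # the same delay, measured in action steps
+
+print(
+    f"skip_steps={skip_steps} simulates ~{simulated_latency_s:.2f}s of latency "
+    f"({inference_delay} action steps per inference call)"
+)
+```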
+
+**Example:** `skip_steps=3` means the policy runs inference every 3 steps, simulating an action execution frequency that is 3x the inference frequency.
+
+## Output Format (Dataset Evaluation)
+
+When using `--output_path`, results are saved in JSON format:
+
+```json
+{
+  "summary": {
+    "rtc_vs_ground_truth_mse": {
+      "mean": 0.00123,
+      "std": 0.00045,
+      "min": 0.00012,
+      "max": 0.00456
+    },
+    "improvement": {
+      "absolute": 0.00034,
+      "relative_percent": 12.5
+    },
+    ...
+  },
+  "config": {
+    "num_iterations": 100,
+    "skip_steps": 3,
+    "execution_horizon": 10,
+    ...
+  },
+  "detailed_results": [
+    {
+      "sample_idx": 0,
+      "rtc_vs_ground_truth_mse": 0.00112,
+      "no_rtc_vs_ground_truth_mse": 0.00145,
+      ...
+    },
+    ...
+  ]
+}
+```
+
+## Tips
+
+1. **Start with dataset evaluation** to understand RTC behavior before running on a robot
+2. **Use verbose mode** for debugging unexpected behavior
+3. **Tune `execution_horizon`** based on your inference latency and action frequency
+4. **Monitor consistency metrics** - very low consistency might indicate that `execution_horizon` is too small
+5. **Compare different schedules** - EXP usually works best, but LINEAR can be more interpretable
+
+## Troubleshooting
+
+### High RTC vs No-RTC difference but no improvement
+
+- Try reducing `max_guidance_weight`
+- Check if `execution_horizon` is too large
+
+### Poor consistency metrics
+
+- Increase `execution_horizon`
+- Check that `skip_steps` is not larger than your action chunk size
+- Verify episodes are being reset correctly
+
+### RTC worse than No-RTC
+
+- RTC may not help if inference is faster than action execution
+- Try a different `prefix_attention_schedule`
+- Ensure `execution_horizon` matches your use case
+
+## Example Results
+
+Example output from dataset evaluation:
+
+```
+================================================================================
+EVALUATION SUMMARY
+================================================================================
+
+Ground Truth Alignment:
+  RTC MSE:    0.001234 ± 0.000456
+  No-RTC MSE: 0.001567 ± 0.000512
+
+RTC Improvement:
+  Absolute: 0.000333
+  Relative: 21.23%
+
+RTC vs No-RTC Difference:
+  MSE: 0.000112 ± 0.000034
+
+RTC Consistency (Prefix Region):
+  MSE:        0.000089 ± 0.000023
+  Mean Error: 0.007654 ± 0.002341
+  Max Error:  0.023456 ± 0.008765
+```
+
+## Related Documentation
+
+- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
+- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
+- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py
new file mode 100644
index 000000000..71e2304d2
--- /dev/null
+++ b/examples/rtc/eval_dataset.py
@@ -0,0 +1,418 @@
+#!/usr/bin/env python
+
+"""
+Evaluate Real-Time Chunking (RTC) performance on dataset samples.
+
+This script takes two random samples from a dataset:
+- It uses the actions from the first sample as the previous chunk
+- It generates new actions for the second sample, both with and without RTC
+
+It compares action predictions with and without RTC on dataset samples,
+measuring consistency and ground truth alignment.
+ +Usage: + python eval_dataset.py \ + --policy.path=helper2424/smolvla_check_rtc_last3 \ + --dataset.repo_id=helper2424/check_rtc \ + --rtc.execution_horizon=8 \ + --device=mps +""" + +import logging +import random +from dataclasses import dataclass, field + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from lerobot.configs import parser +from lerobot.configs.default import DatasetConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import RTCAttentionSchedule +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer +from lerobot.utils.hub import HubMixin + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +def set_seed(seed: int): + """Set random seed for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if torch.backends.mps.is_available(): + torch.mps.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + logger.info(f"Random seed set to: {seed}") + + +@dataclass +class RTCEvalConfig(HubMixin): + """Configuration for RTC evaluation.""" + + # Policy configuration + policy: PreTrainedConfig | None = None + + # Dataset configuration + dataset: DatasetConfig = field(default_factory=DatasetConfig) + + # RTC configuration + rtc: RTCConfig = field( + default_factory=lambda: RTCConfig( + enabled=True, + execution_horizon=20, + max_guidance_weight=5.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=True, + debug_maxlen=1000, + ) + ) + + # Device configuration + device: str | None = field( + default=None, + metadata={"help": "Device to run on (cuda, cpu, mps, auto)"}, + ) + + # Output configuration + output_dir: str = field( + default="rtc_debug_output", + metadata={"help": "Directory to save debug visualizations"}, + ) + verbose: bool = field( + default=False, + metadata={"help": "Enable verbose logging"}, + ) + enable_debug_viz: bool = field( + default=True, + metadata={"help": "Enable debug visualization"}, + ) + + # Seed configuration + seed: int = field( + default=42, + metadata={"help": "Random seed for reproducibility"}, + ) + + inference_delay: int = field( + default=4, + metadata={"help": "Inference delay for RTC"}, + ) + + def __post_init__(self): + # Parse policy path + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + else: + raise ValueError("Policy path is required (--policy.path)") + + # Auto-detect device if not specified + if self.device is None or self.device == "auto": + if torch.cuda.is_available(): + self.device = "cuda" + elif torch.backends.mps.is_available(): + self.device = "mps" + else: + self.device = "cpu" + logger.info(f"Auto-detected device: {self.device}") + + @classmethod + def __get_path_fields__(cls) -> list[str]: + """This enables the parser to load config from the policy using `--policy.path=local/dir`""" + return ["policy"] + + +class RTCEvaluator: + """Evaluator for RTC on dataset 
samples.""" + + def __init__(self, cfg: RTCEvalConfig): + self.cfg = cfg + self.device = cfg.device + + # Load policy + logger.info(f"Loading policy from {cfg.policy.pretrained_path}") + policy_class = get_policy_class(cfg.policy.type) + self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path) + self.policy = self.policy.to(self.device) + self.policy.eval() + + # Configure RTC + cfg.rtc.enabled = True + self.policy.config.rtc_config = cfg.rtc + self.policy.init_rtc_processor(verbose=cfg.verbose) + + logger.info(f"Policy loaded: {self.policy.name}") + logger.info(f"RTC enabled: {cfg.rtc.enabled}") + logger.info(f"Execution horizon: {cfg.rtc.execution_horizon}") + + # Load dataset + logger.info(f"Loading dataset: {cfg.dataset.repo_id}") + self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30}) + logger.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes") + + # Create preprocessor/postprocessor + self.preprocessor, self.postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + preprocessor_overrides={ + "device_processor": {"device": self.device}, + }, + ) + + def run_evaluation(self): + """Run evaluation on two random dataset samples.""" + logger.info("Starting RTC evaluation") + logger.info(f"Inference delay: {self.cfg.inference_delay}") + + # Get two random samples from the dataset + idx1, idx2 = random.sample(range(len(self.dataset)), 2) + logger.info(f"Selected samples: {idx1}, {idx2}") + + # Get first sample - use its actions as prev_chunk + sample1 = self.dataset[idx1] + for key, value in sample1.items(): + if isinstance(value, torch.Tensor): + sample1[key] = value.unsqueeze(0).to(self.device) + + preprocessed_sample1 = self.preprocessor(sample1) + prev_chunk_left_over = preprocessed_sample1["action"][0, :, :25] + logger.info(f"Using actions from sample {idx1} as previous chunk: shape={prev_chunk_left_over.shape}") + + # Get second sample - generate actions for this one + sample2 = self.dataset[idx2] + for key, value in sample2.items(): + if isinstance(value, torch.Tensor): + sample2[key] = value.unsqueeze(0).to(self.device) + + preprocessed_sample2 = self.preprocessor(sample2) + logger.info(f"Generating actions for sample {idx2}") + + # Sample noise (use same noise for both RTC and non-RTC for fair comparison) + noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim) + noise = self.policy.model.sample_noise(noise_size, self.device) + noise_clone = noise.clone() + + # Create side-by-side figures for denoising visualization + fig_xt, axs_xt = plt.subplots(6, 2, figsize=(24, 12)) + fig_xt.suptitle("x_t Denoising: No RTC (left) vs RTC (right)", fontsize=16) + + fig_vt, axs_vt = plt.subplots(6, 2, figsize=(24, 12)) + fig_vt.suptitle("v_t Denoising: No RTC (left) vs RTC (right)", fontsize=16) + + fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12)) + fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16) + + # Generate actions WITHOUT RTC (plot on left column) + logger.info("Generating actions WITHOUT RTC") + self.policy.config.rtc_config.enabled = False + with torch.no_grad(): + no_rtc_actions = self.policy.predict_action_chunk( + preprocessed_sample2, + noise=noise, + inference_delay=self.cfg.inference_delay, + prev_chunk_left_over=prev_chunk_left_over, + viz_xt_axs=axs_xt[:, 0], # Left column for x_t + viz_vt_axs=axs_vt[:, 0], # Left column for v_t + ) + + # 
Generate actions WITH RTC (plot on right column) + logger.info("Generating actions WITH RTC") + self.policy.config.rtc_config.enabled = True + with torch.no_grad(): + rtc_actions = self.policy.predict_action_chunk( + preprocessed_sample2, + noise=noise_clone, + inference_delay=self.cfg.inference_delay, + prev_chunk_left_over=prev_chunk_left_over, + execution_horizon=self.cfg.rtc.execution_horizon, + viz_xt_axs=axs_xt[:, 1], # Right column for x_t + viz_vt_axs=axs_vt[:, 1], # Right column for v_t + viz_x1t_axs=axs_x1t[:, 1], # Right column for x1_t + ) + + # Set titles for denoising plots + for ax in axs_xt[:, 0]: + ax.set_title("No RTC" if ax == axs_xt[0, 0] else "", fontsize=12) + for ax in axs_xt[:, 1]: + ax.set_title("RTC" if ax == axs_xt[0, 1] else "", fontsize=12) + + for ax in axs_vt[:, 0]: + ax.set_title("No RTC" if ax == axs_vt[0, 0] else "", fontsize=12) + for ax in axs_vt[:, 1]: + ax.set_title("RTC" if ax == axs_vt[0, 1] else "", fontsize=12) + + for ax in axs_x1t[:, 0]: + ax.set_title("No RTC (N/A)" if ax == axs_x1t[0, 0] else "", fontsize=12) + for ax in axs_x1t[:, 1]: + ax.set_title("RTC" if ax == axs_x1t[0, 1] else "", fontsize=12) + + # Save denoising plots + fig_xt.tight_layout() + fig_xt.savefig("denoising_xt_comparison.png", dpi=150) + logger.info("Saved x_t denoising comparison to denoising_xt_comparison.png") + plt.close(fig_xt) + + fig_vt.tight_layout() + fig_vt.savefig("denoising_vt_comparison.png", dpi=150) + logger.info("Saved v_t denoising comparison to denoising_vt_comparison.png") + plt.close(fig_vt) + + fig_x1t.tight_layout() + fig_x1t.savefig("denoising_x1t_comparison.png", dpi=150) + logger.info("Saved x1_t predicted state & error comparison to denoising_x1t_comparison.png") + plt.close(fig_x1t) + + # Create side-by-side comparison: No RTC (left) vs RTC (right) + fig, axs = plt.subplots(6, 2, figsize=(24, 12)) + fig.suptitle("Final Action Comparison: No RTC (left) vs RTC (right)", fontsize=16) + + # Plot on left column (No RTC) + self._plot_actions( + axs[:, 0], + prev_chunk_left_over[0].cpu().numpy(), + no_rtc_actions[0].cpu().numpy(), + "No RTC", + ) + + # Plot on right column (RTC) + self._plot_actions( + axs[:, 1], + prev_chunk_left_over[0].cpu().numpy(), + rtc_actions[0].detach().cpu().numpy(), + "RTC", + ) + + plt.tight_layout() + plt.savefig("final_actions_comparison.png", dpi=150) + logger.info("Saved final actions comparison to final_actions_comparison.png") + plt.close(fig) + + # Visualize debug information if enabled + if self.cfg.enable_debug_viz and self.policy.rtc_processor is not None: + self._visualize_debug_info() + + logger.info("Evaluation completed successfully") + + def _plot_actions(self, axs, prev_chunk, predicted_actions, title): + """Plot actions comparison on given axes.""" + # Ensure arrays are 2D + if prev_chunk.ndim == 1: + prev_chunk = prev_chunk.reshape(1, -1) + if predicted_actions.ndim == 1: + predicted_actions = predicted_actions.reshape(1, -1) + + for j in range(min(prev_chunk.shape[-1], 6)): # Limit to 6 dimensions + axs[j].plot( + np.arange(prev_chunk.shape[0]), + prev_chunk[:, j], + color="green", + label="Previous Chunk", + ) + axs[j].plot( + np.arange(predicted_actions.shape[0]), + predicted_actions[:, j], + color="red" if "RTC" in title else "blue", + label=title, + ) + axs[j].set_ylabel("Joint angle", fontsize=14) + axs[j].grid() + axs[j].legend(loc="upper right", fontsize=14) + axs[j].set_title(title if j == 0 else "", fontsize=12) + if j == 2: + axs[j].set_xlabel("Step #", fontsize=16) + + def 
_visualize_debug_info(self): + """Visualize debug information from the RTC processor.""" + import os + + # Use proxy method to check if debug is enabled + if not self.policy.rtc_processor.is_debug_enabled(): + logger.warning("Debug tracking is disabled. Skipping debug visualization.") + return + + # Get tracker length using proxy method + if self.policy.rtc_processor.get_tracker_length() == 0: + logger.warning("No debug steps recorded. Skipping debug visualization.") + return + + # Create output directory + os.makedirs(self.cfg.output_dir, exist_ok=True) + logger.info(f"Saving debug visualizations to {self.cfg.output_dir}") + + # Still need direct access to tracker for visualization functions + # This is acceptable since RTCDebugVisualizer is part of the RTC package + tracker = self.policy.rtc_processor.tracker + + # Print statistics + RTCDebugVisualizer.print_debug_statistics(tracker) + + # Plot debug summary + summary_path = os.path.join(self.cfg.output_dir, "debug_summary.png") + RTCDebugVisualizer.plot_debug_summary( + tracker, + save_path=summary_path, + show=False, + ) + + # Plot correction heatmap + heatmap_path = os.path.join(self.cfg.output_dir, "correction_heatmap.png") + RTCDebugVisualizer.plot_correction_heatmap( + tracker, + save_path=heatmap_path, + show=False, + ) + + # Plot step-by-step comparison (last step) + step_path = os.path.join(self.cfg.output_dir, "step_comparison_last.png") + RTCDebugVisualizer.plot_step_by_step_comparison( + tracker, + step_idx=-1, + save_path=step_path, + show=False, + ) + + # Plot step-by-step comparison (first step) + step_path_first = os.path.join(self.cfg.output_dir, "step_comparison_first.png") + if self.policy.rtc_processor.get_tracker_length() > 0: + RTCDebugVisualizer.plot_step_by_step_comparison( + tracker, + step_idx=0, + save_path=step_path_first, + show=False, + ) + + logger.info(f"Debug visualizations saved to {self.cfg.output_dir}") + + +@parser.wrap() +def main(cfg: RTCEvalConfig): + """Main entry point for RTC evaluation.""" + # Set random seed for reproducibility + set_seed(cfg.seed) + + logger.info("=" * 80) + logger.info("RTC Dataset Evaluation") + logger.info(f"Config: {cfg}") + logger.info("=" * 80) + + evaluator = RTCEvaluator(cfg) + evaluator.run_evaluation() + + +if __name__ == "__main__": + main() diff --git a/examples/rtc/real_time_chunking_evaluate.py b/examples/rtc/real_time_chunking_evaluate.py new file mode 100644 index 000000000..8e1b3aa31 --- /dev/null +++ b/examples/rtc/real_time_chunking_evaluate.py @@ -0,0 +1,874 @@ +#!/usr/bin/env python + +""" +Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies. + +This script demonstrates: +1. Creating a robot/environment and policy (SmolVLA, Pi0, etc.) with RTC +2. Consuming actions from the policy while the robot/environment executes +3. Periodically requesting new action chunks in the background using threads +4. 
Managing action buffers and timing for real-time operation
+
+Usage:
+    # With real robot
+    python examples/rtc/real_time_chunking_evaluate.py --policy.path=lerobot/smolvla_base --robot.type=so100
+
+    # With simulation environment
+    python examples/rtc/real_time_chunking_evaluate.py --policy.path=lerobot/smolvla_base --env.type=pusht
+
+    # With config file
+    python examples/rtc/real_time_chunking_evaluate.py --config_path=path/to/config.json
+
+    # With policy compilation for faster inference (recommended for production)
+    python examples/rtc/real_time_chunking_evaluate.py --policy.path=lerobot/smolvla_base --robot.type=so100 --compile_policy=true
+
+    # With aggressive compilation for maximum speed
+    python examples/rtc/real_time_chunking_evaluate.py --policy.path=lerobot/smolvla_base --robot.type=so100 --compile_policy=true --compile_mode=max-autotune
+
+Performance Notes:
+    - torch.compile() is NOT supported on MPS (Apple Silicon) due to attention operation limitations
+    - For MPS optimization, reduce num_steps in the policy config (biggest speedup)
+    - CUDA devices will see 2-5x speedup with compilation enabled
+"""
+
+import logging
+import math
+import sys
+import time
+import traceback
+from dataclasses import dataclass, field
+from threading import Event, Lock, Thread
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig  # noqa: F401
+from lerobot.configs import parser
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import RTCAttentionSchedule
+from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
+from lerobot.envs.configs import EnvConfig  # noqa: F401
+from lerobot.envs.factory import make_env
+from lerobot.policies.factory import get_policy_class, make_pre_post_processors
+from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.policies.rtc.latency_tracker import LatencyTracker
+from lerobot.processor.factory import (
+    make_default_robot_action_processor,
+    make_default_robot_observation_processor,
+)
+from lerobot.rl.process import ProcessSignalHandler
+from lerobot.robots import (  # noqa: F401
+    Robot,
+    RobotConfig,
+    koch_follower,
+    so100_follower,
+    so101_follower,
+)
+from lerobot.robots.utils import make_robot_from_config
+from lerobot.utils.constants import OBS_IMAGES
+from lerobot.utils.hub import HubMixin
+from lerobot.utils.utils import init_logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def tensor_stats_str(tensor: Tensor | None, name: str = "tensor") -> str:
+    """Generate a readable statistics string for a tensor."""
+    if tensor is None:
+        return f"{name}: None"
+
+    stats = (
+        f"{name}:\n"
+        f"  shape={tuple(tensor.shape)}, dtype={tensor.dtype}, device={tensor.device}\n"
+        f"  min={tensor.min().item():.6f}, max={tensor.max().item():.6f}\n"
+        f"  mean={tensor.mean().item():.6f}, std={tensor.std().item():.6f}"
+    )
+    return stats
+
+
+def compare_tensors(tensor1: Tensor, tensor2: Tensor, name1: str = "tensor1", name2: str = "tensor2") -> str:
+    """Compare two tensors and return detailed difference statistics."""
+    if tensor1 is None or tensor2 is None:
+        return f"Cannot compare: {name1}={tensor1 is not None}, {name2}={tensor2 is not None}"
+
+    # Ensure same shape for comparison
+    if tensor1.shape != tensor2.shape:
+        return f"Shape mismatch: {name1}={tuple(tensor1.shape)} vs {name2}={tuple(tensor2.shape)}"
+
+    diff = tensor1 - tensor2
+    abs_diff = torch.abs(diff)
+
+    # Per-timestep statistics
+    if len(diff.shape) >= 2:
+        # Shape is (batch, time, action_dim) or (time, action_dim)
+        per_timestep_mean = abs_diff.mean(dim=-1)  # 
Average across action dimensions + + timestep_stats = "\n Per-timestep abs diff (averaged across action dims):\n" + if len(per_timestep_mean.shape) > 1: + # Has batch dimension + for batch_idx in range(per_timestep_mean.shape[0]): + timestep_stats += f" Batch {batch_idx}: [" + for t in range(min(10, per_timestep_mean.shape[1])): # Show first 10 timesteps + timestep_stats += f"{per_timestep_mean[batch_idx, t].item():.6f}, " + if per_timestep_mean.shape[1] > 10: + timestep_stats += "..." + timestep_stats += "]\n" + else: + timestep_stats += " [" + for t in range(min(10, len(per_timestep_mean))): + timestep_stats += f"{per_timestep_mean[t].item():.6f}, " + if len(per_timestep_mean) > 10: + timestep_stats += "..." + timestep_stats += "]\n" + else: + timestep_stats = "" + + result = ( + f"\nDifference: {name1} - {name2}:\n" + f" abs_diff: min={abs_diff.min().item():.6f}, max={abs_diff.max().item():.6f}\n" + f" abs_diff: mean={abs_diff.mean().item():.6f}, std={abs_diff.std().item():.6f}\n" + f" relative_diff: mean={abs_diff.mean().item() / (torch.abs(tensor2).mean().item() + 1e-8) * 100:.2f}%" + f"{timestep_stats}" + ) + + return result + + +class RobotWrapper: + def __init__(self, robot: Robot): + self.robot = robot + self.lock = Lock() + + def get_observation(self) -> dict[str, Tensor]: + with self.lock: + return self.robot.get_observation() + + def send_action(self, action: Tensor): + with self.lock: + self.robot.send_action(action) + + def observation_features(self) -> list[str]: + with self.lock: + return self.robot.observation_features + + def action_features(self) -> list[str]: + with self.lock: + return self.robot.action_features + + +class EnvWrapper: + """Wrapper for gym environments to provide same interface as RobotWrapper.""" + + def __init__(self, env, env_cfg: EnvConfig): + self.env = env + self.env_cfg = env_cfg + self.lock = Lock() + self._last_obs = None + self._episode_count = 0 + self._step_count = 0 + + # Initialize environment + obs, _ = self.env.reset() + self._last_obs = ( + obs[0] + if isinstance(obs, tuple) + or (hasattr(obs, "__getitem__") and len(obs) > 0 and not isinstance(obs, dict)) + else obs + ) + + # Cache feature names + self._observation_features = None + self._action_features = None + + def get_observation(self) -> dict[str, np.ndarray]: + """Get current observation from environment. + + Returns observations in the same format as robot.get_observation(): + a dict mapping feature names to numpy arrays. 
+ """ + with self.lock: + if self._last_obs is None: + # Reset environment on first observation + obs, _ = self.env.reset() + self._last_obs = ( + obs[0] + if isinstance(obs, tuple) + or (hasattr(obs, "__getitem__") and len(obs) > 0 and not isinstance(obs, dict)) + else obs + ) + + # VectorEnv returns observations as numpy arrays in a batch + # Extract first element if it's a vectorized observation + obs = self._last_obs + if isinstance(obs, dict): + # Handle dict observations (extract first element from batch if needed) + result = {} + for key, value in obs.items(): + if isinstance(value, np.ndarray) and len(value.shape) > 0 and value.shape[0] == 1: + # Remove batch dimension for single env + result[key] = value[0] + else: + result[key] = value + return result + else: + # Handle array observations - shouldn't happen with our configs but handle it + return {"observation": obs[0] if len(obs.shape) > 1 else obs} + + def send_action(self, action: dict): + """Execute action in environment and update observation.""" + with self.lock: + # Convert action dict to array based on action_features + action_list = [] + for feature_name in self.action_features(): + if feature_name in action: + action_list.append(action[feature_name]) + + action_array = np.array(action_list) + + # VectorEnv expects actions with batch dimension + action_batch = action_array.reshape(1, -1) + + # Step environment + obs, _reward, terminated, truncated, _info = self.env.step(action_batch) + + # Extract from batch + self._last_obs = ( + obs[0] + if isinstance(obs, tuple) + or (hasattr(obs, "__getitem__") and len(obs) > 0 and not isinstance(obs, dict)) + else obs + ) + self._step_count += 1 + + # Check if episode is done (handle vectorized env format) + is_done = terminated[0] if isinstance(terminated, (np.ndarray, list)) else terminated + is_truncated = truncated[0] if isinstance(truncated, (np.ndarray, list)) else truncated + + # Reset if episode is done + if is_done or is_truncated: + logger.info(f"Episode {self._episode_count} finished after {self._step_count} steps") + obs, _ = self.env.reset() + self._last_obs = ( + obs[0] + if isinstance(obs, tuple) + or (hasattr(obs, "__getitem__") and len(obs) > 0 and not isinstance(obs, dict)) + else obs + ) + self._episode_count += 1 + self._step_count = 0 + + def observation_features(self) -> list[str]: + """Get observation feature names from environment config.""" + if self._observation_features is not None: + return self._observation_features + + with self.lock: + features = [] + for feature_name in self.env_cfg.features: + if feature_name != "action": + # Use the mapped name from features_map + mapped_name = self.env_cfg.features_map.get(feature_name, feature_name) + features.append(mapped_name) + + self._observation_features = features + return features + + def action_features(self) -> list[str]: + """Get action feature names from environment config.""" + if self._action_features is not None: + return self._action_features + + with self.lock: + # Return action dimension names + action_dim = self.env_cfg.features["action"].shape[0] + self._action_features = [f"action_{i}" for i in range(action_dim)] + return self._action_features + + +class ActionQueue: + def __init__(self, cfg: RTCConfig): + self.queue = None # Processed actions for robot rollout + self.original_queue = None # Original actions for RTC + self.lock = Lock() + self.last_index = 0 + self.cfg = cfg + + def get(self) -> Tensor | None: + with self.lock: + if self.queue is None or self.last_index >= len(self.queue): + 
return None
+
+            action = self.queue[self.last_index]
+            self.last_index += 1
+            return action.clone()
+
+    def qsize(self) -> int:
+        # with self.lock:
+        if self.queue is None:
+            return 0
+        length = len(self.queue)
+
+        return length - self.last_index
+
+    def empty(self) -> bool:
+        # with self.lock:
+        if self.queue is None:
+            return True
+
+        length = len(self.queue)
+        return length - self.last_index <= 0
+
+    def get_action_index(self) -> int:
+        # with self.lock:
+        return self.last_index
+
+    def get_left_over(self) -> Tensor | None:
+        """Get the left-over ORIGINAL actions for RTC prev_chunk_left_over."""
+        with self.lock:
+            if self.original_queue is None:
+                return None
+            return self.original_queue[self.last_index :]
+
+    def merge(
+        self,
+        original_actions: Tensor,
+        processed_actions: Tensor,
+        real_delay: int,
+        action_index_before_inference: int | None = 0,
+    ):
+        with self.lock:
+            self._check_delays(real_delay, action_index_before_inference)
+
+            if self.cfg.enabled:
+                self._replace_actions_queue(original_actions, processed_actions, real_delay)
+                return
+
+            self._append_actions_queue(original_actions, processed_actions)
+
+    def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
+        self.original_queue = original_actions[real_delay:].clone()
+        self.queue = processed_actions[real_delay:].clone()
+
+        logger.info(f"original_actions shape: {self.original_queue.shape}")
+        logger.info(f"processed_actions shape: {self.queue.shape}")
+        logger.info(f"real_delay: {real_delay}")
+
+        self.last_index = 0
+
+    def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
+        if self.queue is None:
+            self.original_queue = original_actions.clone()
+            self.queue = processed_actions.clone()
+            return
+
+        self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
+        self.original_queue = self.original_queue[self.last_index :]
+
+        self.queue = torch.cat([self.queue, processed_actions.clone()])
+        self.queue = self.queue[self.last_index :]
+
+        self.last_index = 0
+
+    def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
+        if action_index_before_inference is None:
+            return
+
+        indexes_diff = self.last_index - action_index_before_inference
+        if indexes_diff != real_delay:
+            # Check that the action-index difference (the real delay measured from the
+            # action queue) matches the delay calculated from the inference latency.
+            logger.warning(
+                f"[ACTION_QUEUE] Indexes diff is not equal to real delay. Indexes diff: {indexes_diff}, real delay: {real_delay}"
+            )
+
+
+@dataclass
+class RTCDemoConfig(HubMixin):
+    """Configuration for RTC demo with action chunking policies."""
+
+    # Policy configuration
+    policy: PreTrainedConfig | None = None
+
+    # Robot configuration (mutually exclusive with env)
+    robot: RobotConfig | None = None
+
+    # Environment configuration (mutually exclusive with robot)
+    env: EnvConfig | None = None
+
+    # RTC configuration
+    rtc: RTCConfig = field(
+        default_factory=lambda: RTCConfig(
+            execution_horizon=10,
+            max_guidance_weight=1.0,
+            prefix_attention_schedule=RTCAttentionSchedule.EXP,
+        )
+    )
+
+    # Demo parameters
+    duration: float = 30.0  # Duration to run the demo (seconds)
+    fps: float = 10.0  # Action execution frequency (Hz)
+
+    # Compute device
+    device: str | None = None  # Device to run on (cuda, cpu, auto)
+
+    # Compilation options
+    compile_policy: bool = (
+        False  # Compile policy with torch.compile() for faster inference (not supported on MPS)
+    )
+    compile_mode: str = "default"  # Compilation mode: default, reduce-overhead, max-autotune
+
+    # Alternative optimization options (work on all devices including MPS)
+    use_channels_last: bool = False  # Use channels_last memory format for images (faster on some devices)
+    enable_cudnn_benchmark: bool = True  # Enable cuDNN benchmarking (CUDA only)
+
+    # The number of executed steps after which a new action chunk is requested.
+    # It should be higher than inference delay + execution horizon.
+    action_queue_size_to_get_new_actions: int = 30
+
+    # Task to execute
+    task: str = field(default="", metadata={"help": "Task to execute"})
+
+    # Debug options
+    verbose_rtc_comparison: bool = True  # Enable detailed RTC comparison output
+
+    def __post_init__(self):
+        # HACK: We parse the CLI args again here to get the pretrained path if there was one.
+        policy_path = parser.get_path_arg("policy")
+        if policy_path:
+            cli_overrides = parser.get_cli_overrides("policy")
+            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+            self.policy.pretrained_path = policy_path
+        else:
+            raise ValueError("Policy path is required")
+
+        # Validate that either robot or env is provided, but not both
+        if self.robot is None and self.env is None:
+            raise ValueError("Either robot or env configuration must be provided")
+        if self.robot is not None and self.env is not None:
+            raise ValueError("Cannot specify both robot and env configuration. Choose one.")
+
+    @classmethod
+    def __get_path_fields__(cls) -> list[str]:
+        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
+        return ["policy"]
+
+
+def is_image_key(k: str) -> bool:
+    return k.startswith(OBS_IMAGES)
+
+
+def get_actions(
+    policy,
+    robot: RobotWrapper,
+    robot_observation_processor,
+    action_queue: ActionQueue,
+    shutdown_event: Event,
+    cfg: RTCDemoConfig,
+):
+    """Thread function to request action chunks from the policy.
+
+    Args:
+        policy: The policy instance (SmolVLA, Pi0, etc.)
+        robot: The robot instance for getting observations
+        robot_observation_processor: Processor for raw robot observations
+        action_queue: Queue to put new action chunks into
+        shutdown_event: Event to signal shutdown
+        cfg: Demo configuration
+    """
+    try:
+        logger.info("[GET_ACTIONS] Starting get actions thread")
+
+        latency_tracker = LatencyTracker()  # Track latency of action chunk inference
+        fps = cfg.fps
+        time_per_chunk = 1.0 / fps
+
+        dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
+        policy_device = policy.config.device
+
+        preprocessor, postprocessor = make_pre_post_processors(
+            policy_cfg=cfg.policy,
+            pretrained_path=cfg.policy.pretrained_path,
+            preprocessor_overrides={
+                "device_processor": {"device": cfg.policy.device},
+            },
+        )
+
+        get_actions_threshold = cfg.action_queue_size_to_get_new_actions
+
+        if not cfg.rtc.enabled:
+            get_actions_threshold = 0
+
+        while not shutdown_event.is_set():
+            if action_queue.qsize() <= get_actions_threshold:
+                current_time = time.perf_counter()
+                action_index_before_inference = action_queue.get_action_index()
+                prev_actions = action_queue.get_left_over()
+
+                inference_latency = latency_tracker.max()
+                inference_delay = math.ceil(inference_latency / time_per_chunk)
+
+                obs = robot.get_observation()
+
+                # Apply robot observation processor
+                obs_processed = robot_observation_processor(obs)
+
+                obs_with_policy_features = build_dataset_frame(
+                    dataset_features, obs_processed, prefix="observation"
+                )
+
+                for name in obs_with_policy_features:
+                    obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
+                    if "image" in name:
+                        obs_with_policy_features[name] = (
+                            obs_with_policy_features[name].type(torch.float32) / 255
+                        )
+                        obs_with_policy_features[name] = (
+                            obs_with_policy_features[name].permute(2, 0, 1).contiguous()
+                        )
+                    obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
+                    obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
+
+                obs_with_policy_features["task"] = cfg.task
+
+                preprocessed_obs = preprocessor(obs_with_policy_features)
+
+                noise_size = (1, policy.config.chunk_size, policy.config.max_action_dim)
+                noise = policy.model.sample_noise(noise_size, policy_device)
+                noise_clone = noise.clone()
+
+                # Generate actions WITHOUT RTC for comparison (if verbose mode enabled)
+                if cfg.verbose_rtc_comparison:
+                    policy.config.rtc_config.enabled = False
+                    not_rtc_actions = policy.predict_action_chunk(
+                        preprocessed_obs,
+                        noise=noise,
+                        inference_delay=inference_delay,
+                        prev_chunk_left_over=prev_actions,
+                    )
+                    policy.config.rtc_config.enabled = True
+
+                # Generate actions WITH RTC
+                actions = policy.predict_action_chunk(
+                    preprocessed_obs,
+                    noise=noise_clone if cfg.verbose_rtc_comparison else noise,
+                    inference_delay=inference_delay,
+                    prev_chunk_left_over=prev_actions,
+                )
+
+                # Store original actions (before postprocessing) for RTC
+                original_actions = actions.squeeze(0).clone()
+
+                # Detailed comparison output (if verbose mode enabled)
+                if cfg.verbose_rtc_comparison:
+                    logger.info("=" * 80)
+                    logger.info("RTC ACTION COMPARISON")
+                    logger.info("=" * 80)
+
+                    # Print detailed statistics
+                    logger.info("\n" + tensor_stats_str(not_rtc_actions, "not_rtc_actions (without RTC)"))
+                    logger.info("\n" + tensor_stats_str(actions, "actions (with RTC)"))
+                    logger.info(
+                        "\n" + tensor_stats_str(prev_actions, "prev_actions (leftover from previous chunk)")
+                    )
+
+                    # Compare RTC vs non-RTC actions
+                    logger.info(
+                        compare_tensors(actions, not_rtc_actions, "actions (RTC)", "not_rtc_actions (no RTC)")
+                    )
+
+                    to_non_rtc_diff = actions - not_rtc_actions
+
+                    print("to_non_rtc_diff", to_non_rtc_diff)
+                    if prev_actions is not None:
+                        prev_padded = torch.zeros_like(actions)
+                        prev_padded[:, : prev_actions.shape[1], :] = prev_actions
+                        to_prev_diff = actions - prev_padded
+                        print("to_prev_diff", to_prev_diff)
+                    print("=" * 80)
+
+                postprocessed_actions = postprocessor(actions)
+
+                postprocessed_actions = postprocessed_actions.squeeze(0)
+
+                new_latency = time.perf_counter() - current_time
+                new_delay = math.ceil(new_latency / time_per_chunk)
+                latency_tracker.add(new_latency)
+
+                if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
+                    logger.warning(
+                        "[GET_ACTIONS] action_queue_size_to_get_new_actions is too small; it should be higher than inference delay + execution horizon."
+                    )
+
+                logger.debug(f"[GET_ACTIONS] new_delay: {new_delay}")
+                logger.debug(f"[GET_ACTIONS] original_actions shape: {original_actions.shape}")
+                logger.debug(f"[GET_ACTIONS] postprocessed_actions shape: {postprocessed_actions.shape}")
+                logger.debug(f"[GET_ACTIONS] action_index_before_inference: {action_index_before_inference}")
+
+                action_queue.merge(
+                    original_actions, postprocessed_actions, new_delay, action_index_before_inference
+                )
+            else:
+                # Small sleep to prevent busy waiting
+                time.sleep(0.1)
+
+        logger.info("[GET_ACTIONS] get actions thread shutting down")
+    except Exception as e:
+        logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
+        logger.error(traceback.format_exc())
+        sys.exit(1)
+
+
+def actor_control(
+    robot: RobotWrapper,
+    robot_action_processor,
+    action_queue: ActionQueue,
+    shutdown_event: Event,
+    cfg: RTCDemoConfig,
+):
+    """Thread function to execute actions on the robot.
+
+    Args:
+        robot: The robot instance
+        robot_action_processor: Processor for raw robot actions
+        action_queue: Queue to get actions from
+        shutdown_event: Event to signal shutdown
+        cfg: Demo configuration
+    """
+    try:
+        logger.info("[ACTOR] Starting actor thread")
+
+        action_count = 0
+        action_interval = 1.0 / cfg.fps
+
+        while not shutdown_event.is_set():
+            start_time = time.perf_counter()
+
+            # Try to get the next action from the queue (non-blocking)
+            action = action_queue.get()
+
+            if action is not None:
+                action = action.cpu()
+                action = {key: action[i].item() for i, key in enumerate(robot.action_features())}
+                action = robot_action_processor((action, None))
+                robot.send_action(action)
+
+                action_count += 1
+
+            # Sleep for the remainder of the control period; never pass a negative value
+            dt_s = time.perf_counter() - start_time
+            time.sleep(max(0.0, (action_interval - dt_s) - 0.001))
+
+        logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
+    except Exception as e:
+        logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
+        logger.error(traceback.format_exc())
+        sys.exit(1)
+
+
+def stop_by_duration(shutdown_event: Event, cfg: RTCDemoConfig):
+    """Stop the demo after the configured duration."""
+    time.sleep(cfg.duration)
+    shutdown_event.set()
+
+
+@parser.wrap()
+def demo_cli(cfg: RTCDemoConfig):
+    """Main entry point for RTC demo with draccus configuration."""
+
+    # Initialize logging
+    init_logging()
+
+    logger.info(f"Using device: {cfg.device}")
+
+    # Setup signal handler for graceful shutdown
+    signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
+    shutdown_event = signal_handler.shutdown_event
+
+    policy = None
+    robot = None
+    vec_env = None
+    get_actions_thread = None
+    actor_thread = None
+
+    policy_class = get_policy_class(cfg.policy.type)
+    policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
+
+    # Turn on RTC
+    policy.config.rtc_config = cfg.rtc
+
+    # Initialize the RTC processor explicitly: by default it is not created
+    # when RTC is disabled in the config
+    policy.init_rtc_processor(verbose=cfg.verbose_rtc_comparison)
+
+    assert policy.name in ["smolvla"], "Only smolvla is supported for RTC"
+
+    policy = policy.to(cfg.device)
+    policy.eval()
+
+    # Apply memory format optimizations
+    if cfg.use_channels_last:
+        logger.info("Converting model to channels_last memory format")
+        try:
+            # Convert vision encoder to channels_last for better performance
+            if hasattr(policy, "vision_encoder"):
+                policy.vision_encoder = policy.vision_encoder.to(memory_format=torch.channels_last)
+            logger.info("Successfully converted to channels_last format")
+        except Exception as e:
+            logger.warning(f"Failed to convert to channels_last: {e}")
+
+    # Enable cuDNN benchmarking for CUDA
+    if cfg.enable_cudnn_benchmark and cfg.device == "cuda":
+        torch.backends.cudnn.benchmark = True
+        logger.info("Enabled cuDNN benchmarking")
+
+    # Compile policy if requested
+    if cfg.compile_policy:
+        # Check if device is MPS - torch.compile has issues with MPS backend
+        if cfg.device == "mps":
+            logger.warning("torch.compile() is not stable with MPS backend (Apple Silicon)")
+            logger.warning("Skipping compilation. For better performance on MPS:")
+            logger.warning("  1. Use torch.float32 instead of bfloat16")
+            logger.warning("  2. Ensure model uses contiguous memory layouts")
+            logger.warning("  3. 
Consider using CUDA if available") + else: + logger.info(f"Compiling policy with mode: {cfg.compile_mode}") + logger.info("First inference will be slower due to compilation, subsequent calls will be faster") + + try: + # Compile the predict_action_chunk method + policy.predict_action_chunk = torch.compile( + policy.predict_action_chunk, + mode=cfg.compile_mode, + fullgraph=False, # Allow graph breaks for flexibility + backend="inductor", # Use inductor backend + ) + logger.info("Policy compiled successfully") + except Exception as e: + logger.warning(f"Failed to compile policy: {e}") + logger.warning("Continuing without compilation") + + # Create robot or environment + if cfg.robot is not None: + logger.info(f"Initializing robot: {cfg.robot.type}") + robot = make_robot_from_config(cfg.robot) + robot.connect() + agent_wrapper = RobotWrapper(robot) + else: + logger.info(f"Initializing environment: {cfg.env.type}") + # Create environment using make_env + env_dict = make_env(cfg.env, n_envs=1, use_async_envs=False) + + # Validate environment structure: should have exactly one suite + if len(env_dict) != 1: + raise ValueError( + f"Expected exactly one environment suite, but got {len(env_dict)}. " + f"Suites: {list(env_dict.keys())}" + ) + + # Extract the actual env from the dict structure {suite: {task_id: vec_env}} + suite_name = list(env_dict.keys())[0] + task_dict = env_dict[suite_name] + + # Validate task structure: should have exactly one task + if len(task_dict) != 1: + raise ValueError( + f"Expected exactly one task in suite '{suite_name}', but got {len(task_dict)}. " + f"Tasks: {list(task_dict.keys())}" + ) + + vec_env = task_dict[0] + logger.info(f"Created environment: suite='{suite_name}', task_id=0, num_envs={vec_env.num_envs}") + + # Validate that we have exactly 1 parallel environment + if vec_env.num_envs != 1: + raise ValueError( + f"Expected exactly 1 parallel environment, but got {vec_env.num_envs}. " + f"The EnvWrapper is designed for single environment instances." 
+ ) + + agent_wrapper = EnvWrapper(vec_env, cfg.env) + + # Create robot observation processor + robot_observation_processor = make_default_robot_observation_processor() + robot_action_processor = make_default_robot_action_processor() + + # Create action queue for communication between threads + action_queue = ActionQueue(cfg.rtc) + + # Start chunk requester thread + get_actions_thread = Thread( + target=get_actions, + args=(policy, agent_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg), + daemon=True, + name="GetActions", + ) + get_actions_thread.start() + logger.info("Started get actions thread") + + # Start action executor thread + actor_thread = Thread( + target=actor_control, + args=(agent_wrapper, robot_action_processor, action_queue, shutdown_event, cfg), + daemon=True, + name="Actor", + ) + actor_thread.start() + logger.info("Started actor thread") + + logger.info("Started stop by duration thread") + + # Main thread monitors for duration or shutdown + logger.info(f"Running demo for {cfg.duration} seconds...") + start_time = time.time() + + while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration: + time.sleep(10) + + # Log queue status periodically + if int(time.time() - start_time) % 5 == 0: + logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}") + + if time.time() - start_time > cfg.duration: + break + + logger.info("Demo duration reached or shutdown requested") + + # Signal shutdown + shutdown_event.set() + + # Wait for threads to finish + if get_actions_thread and get_actions_thread.is_alive(): + logger.info("Waiting for chunk requester thread to finish...") + get_actions_thread.join() + + if actor_thread and actor_thread.is_alive(): + logger.info("Waiting for action executor thread to finish...") + actor_thread.join() + + # Cleanup robot or environment + if cfg.robot is not None: + if robot: + robot.disconnect() + logger.info("Robot disconnected") + else: + # Close environment + if vec_env: + vec_env.close() + logger.info("Environment closed") + + logger.info("Cleanup completed") + + +if __name__ == "__main__": + demo_cli() + logging.info("RTC demo finished") diff --git a/examples/rtc/run_dataset_evaluation.sh b/examples/rtc/run_dataset_evaluation.sh new file mode 100755 index 000000000..81370682f --- /dev/null +++ b/examples/rtc/run_dataset_evaluation.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +# Example script to run RTC evaluation on dataset +# This shows different usage scenarios + +set -e # Exit on error + +POLICY_PATH="lerobot/smolvla_base" +DATASET="lerobot/pusht" +DEVICE="cuda" # Change to "cpu" or "mps" if needed + +echo "========================================" +echo "RTC Dataset Evaluation Examples" +echo "========================================" + +# Example 1: Quick evaluation (100 samples, every step) +echo -e "\n[Example 1] Quick evaluation - 100 samples, every step" +python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path="${POLICY_PATH}" \ + --dataset.repo_id="${DATASET}" \ + --num_iterations=100 \ + --skip_steps=1 \ + --device="${DEVICE}" \ + --output_path="results/rtc_eval_quick.json" + +# Example 2: Simulating realistic inference delay (every 3rd step) +echo -e "\n[Example 2] Realistic inference delay - 200 samples, every 3rd step" +python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path="${POLICY_PATH}" \ + --dataset.repo_id="${DATASET}" \ + --num_iterations=200 \ + --skip_steps=3 \ + --rtc.execution_horizon=10 \ + --device="${DEVICE}" \ + --output_path="results/rtc_eval_delay3.json" + +# 
Example 3: Higher inference delay (every 5th step) +echo -e "\n[Example 3] High inference delay - 200 samples, every 5th step" +python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path="${POLICY_PATH}" \ + --dataset.repo_id="${DATASET}" \ + --num_iterations=200 \ + --skip_steps=5 \ + --rtc.execution_horizon=12 \ + --device="${DEVICE}" \ + --output_path="results/rtc_eval_delay5.json" + +# Example 4: Testing different RTC configurations +echo -e "\n[Example 4] Different RTC config - LINEAR schedule" +python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path="${POLICY_PATH}" \ + --dataset.repo_id="${DATASET}" \ + --num_iterations=100 \ + --skip_steps=3 \ + --rtc.execution_horizon=8 \ + --rtc.prefix_attention_schedule=LINEAR \ + --rtc.max_guidance_weight=5.0 \ + --device="${DEVICE}" \ + --output_path="results/rtc_eval_linear.json" + +# Example 5: Verbose mode for debugging +echo -e "\n[Example 5] Verbose mode - 20 samples with detailed output" +python examples/rtc/evaluate_rtc_on_dataset.py \ + --policy.path="${POLICY_PATH}" \ + --dataset.repo_id="${DATASET}" \ + --num_iterations=20 \ + --skip_steps=3 \ + --device="${DEVICE}" \ + --verbose=true \ + --output_path="results/rtc_eval_verbose.json" + +echo -e "\n========================================" +echo "All evaluations completed!" +echo "Results saved in results/ directory" +echo "========================================" diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py index 11a1f8d74..18359ef05 100644 --- a/src/lerobot/configs/types.py +++ b/src/lerobot/configs/types.py @@ -43,3 +43,10 @@ class NormalizationMode(str, Enum): class PolicyFeature: type: FeatureType shape: tuple[int, ...] + + +class RTCAttentionSchedule(str, Enum): + ZEROS = "ZEROS" + ONES = "ONES" + LINEAR = "LINEAR" + EXP = "EXP" diff --git a/src/lerobot/policies/rtc/README.md b/src/lerobot/policies/rtc/README.md new file mode 100644 index 000000000..5d5708ae9 --- /dev/null +++ b/src/lerobot/policies/rtc/README.md @@ -0,0 +1,28 @@ +# Real-Time Chunking (RTC) Module + +This module implements Real-Time Chunking and related adaptive inference techniques for robotics policies in LeRobot. + +## Overview + +Real-Time Chunking (RTC) addresses the challenge of real-time inference in action chunking policies by treating chunk generation as an inpainting problem. It strategically handles overlapping timesteps between action chunks using prefix attention mechanisms. + +It is particularly effective for handling long-horizon inference in robotics policies. 
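+
+To make the prefix-attention idea concrete, below is a minimal, self-contained sketch of how the schedule types described in `examples/rtc/README.md` (ZEROS, ONES, LINEAR, EXP) could assign consistency weights over the overlap region. It is an illustration under stated assumptions, not the module's actual implementation, and the decay rate used for EXP is arbitrary:
+
+```python
+import torch
+
+
+def prefix_weights(schedule: str, inference_delay: int, execution_horizon: int) -> torch.Tensor:
+    """Illustrative consistency weights for the first `execution_horizon` steps of a chunk.
+
+    Steps before `inference_delay` are already being executed, so every schedule
+    gives them full weight; the schedules differ in how the weight decays between
+    `inference_delay` and `execution_horizon`.
+    """
+    t = torch.arange(execution_horizon, dtype=torch.float32)
+    if schedule == "ONES":
+        return torch.ones(execution_horizon)
+    if schedule == "ZEROS":
+        return (t < inference_delay).float()  # binary: full weight, then zero
+    # Fraction of the way through the decay region, clamped to [0, 1]
+    frac = ((t - inference_delay) / max(execution_horizon - inference_delay, 1)).clamp(0.0, 1.0)
+    if schedule == "LINEAR":
+        return 1.0 - frac
+    if schedule == "EXP":
+        return torch.exp(-5.0 * frac)  # decay rate chosen only for illustration
+    raise ValueError(f"unknown schedule: {schedule}")
+
+
+for name in ("ZEROS", "ONES", "LINEAR", "EXP"):
+    print(name, prefix_weights(name, inference_delay=3, execution_horizon=10))
+```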
+
+## Integration with Policies
+
+RTC can be integrated with any policy that supports flow matching for chunking:
+
+- **SmolVLA**: Vision-language-action model with RTC support
+- **Pi0**: Action prediction model with adaptive chunking
+
+## Original Implementation
+
+This implementation is based on Physical Intelligence's Kinetix RTC:
+
+- [Original RTC implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214)
+- [Kinetix GitHub Repository](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
+
+## References
+
+- [Real Time Chunking Paper](https://www.physicalintelligence.company/research/real_time_chunking)
+- [Physical Intelligence Kinetix](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
diff --git a/src/lerobot/policies/rtc/configuration_rtc.py b/src/lerobot/policies/rtc/configuration_rtc.py
new file mode 100644
index 000000000..7794c6b01
--- /dev/null
+++ b/src/lerobot/policies/rtc/configuration_rtc.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Real Time Chunking (RTC) configuration classes.
+
+Based on:
+- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
+"""
+
+from dataclasses import dataclass
+
+from lerobot.configs.types import RTCAttentionSchedule
+
+
+@dataclass
+class RTCConfig:
+    """Configuration for Real Time Chunking (RTC) inference.
+
+    RTC improves real-time inference by treating chunk generation as an inpainting problem,
+    strategically handling overlapping timesteps between action chunks using prefix attention.
+    """
+
+    # Infrastructure
+    enabled: bool = False
+
+    # Core RTC settings
+    # TODO: change the default schedule to EXP
+    prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
+    max_guidance_weight: float = 1.0
+    execution_horizon: int = 10
+
+    # Debug settings
+    debug: bool = False
+    debug_maxlen: int = 100
+
+    def __post_init__(self):
+        """Validate RTC configuration parameters."""
+        if self.max_guidance_weight <= 0:
+            raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
+        if self.debug_maxlen <= 0:
+            raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
diff --git a/src/lerobot/policies/rtc/debug_handler.py b/src/lerobot/policies/rtc/debug_handler.py
new file mode 100644
index 000000000..dd3040016
--- /dev/null
+++ b/src/lerobot/policies/rtc/debug_handler.py
@@ -0,0 +1,339 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Debug information handler for Real-Time Chunking (RTC).""" + +from dataclasses import dataclass, field +from typing import Any + +import torch +from torch import Tensor + + +@dataclass +class DebugStep: + """Container for debug information from a single denoising step. + + Attributes: + step_idx (int): Step index/counter. + x_t (Tensor | None): Current latent/state tensor. + v_t (Tensor | None): Velocity from denoiser. + x1_t (Tensor | None): Denoised prediction (x_t - time * v_t). + correction (Tensor | None): Correction gradient tensor. + err (Tensor | None): Weighted error term. + weights (Tensor | None): Prefix attention weights. + guidance_weight (float | Tensor | None): Applied guidance weight. + time (float | Tensor | None): Time parameter. + inference_delay (int | None): Inference delay parameter. + execution_horizon (int | None): Execution horizon parameter. + metadata (dict[str, Any]): Additional metadata. + """ + + step_idx: int = 0 + x_t: Tensor | None = None + v_t: Tensor | None = None + x1_t: Tensor | None = None + correction: Tensor | None = None + err: Tensor | None = None + weights: Tensor | None = None + guidance_weight: float | Tensor | None = None + time: float | Tensor | None = None + inference_delay: int | None = None + execution_horizon: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self, include_tensors: bool = False) -> dict[str, Any]: + """Convert debug step to dictionary. + + Args: + include_tensors (bool): If True, include tensor values. If False, only include + tensor statistics (shape, mean, std, min, max). + + Returns: + Dictionary representation of the debug step. + """ + result = { + "step_idx": self.step_idx, + "guidance_weight": ( + self.guidance_weight.item() + if isinstance(self.guidance_weight, Tensor) + else self.guidance_weight + ), + "time": self.time.item() if isinstance(self.time, Tensor) else self.time, + "inference_delay": self.inference_delay, + "execution_horizon": self.execution_horizon, + "metadata": self.metadata.copy(), + } + + # Add tensor information + tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"] + for field_name in tensor_fields: + tensor = getattr(self, field_name) + if tensor is not None: + if include_tensors: + result[field_name] = tensor.detach().cpu() + else: + result[f"{field_name}_stats"] = { + "shape": tuple(tensor.shape), + "mean": tensor.mean().item(), + "std": tensor.std().item(), + "min": tensor.min().item(), + "max": tensor.max().item(), + } + + return result + + +class Tracker: + """Collects and manages debug information for RTC processing. + + This tracker stores debug information from recent denoising steps in a dictionary, + using time as the key for efficient lookups and updates. + + Args: + enabled (bool): Whether debug collection is enabled. + maxlen (int | None): Optional sliding window size. If provided, only the + most recent ``maxlen`` debug steps are kept. If ``None``, keeps all. 
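+
+    Example:
+        A minimal illustration (the values are arbitrary)::
+
+            tracker = Tracker(enabled=True, maxlen=10)
+            tracker.track(time=0.9, guidance_weight=0.5)
+            assert len(tracker) == 1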
+ """ + + def __init__(self, enabled: bool = False, maxlen: int = 100): + self.enabled = enabled + self._steps = {} if enabled else None # Dictionary with time as key + self._maxlen = maxlen + self._step_counter = 0 + + def reset(self) -> None: + """Clear all recorded debug information.""" + if self.enabled and self._steps is not None: + self._steps.clear() + self._step_counter = 0 + + def track( + self, + time: float | Tensor, + x_t: Tensor | None = None, + v_t: Tensor | None = None, + x1_t: Tensor | None = None, + correction: Tensor | None = None, + err: Tensor | None = None, + weights: Tensor | None = None, + guidance_weight: float | Tensor | None = None, + inference_delay: int | None = None, + execution_horizon: int | None = None, + **metadata, + ) -> None: + """Track debug information for a denoising step at a given time. + + If a step with the given time already exists, it will be updated with the new data. + Otherwise, a new step will be created. Only non-None fields are updated/set. + + Args: + time (float | Tensor): Time parameter - used as the key to identify the step. + x_t (Tensor | None): Current latent/state tensor. + v_t (Tensor | None): Velocity from denoiser. + x1_t (Tensor | None): Denoised prediction. + correction (Tensor | None): Correction gradient tensor. + err (Tensor | None): Weighted error term. + weights (Tensor | None): Prefix attention weights. + guidance_weight (float | Tensor | None): Applied guidance weight. + inference_delay (int | None): Inference delay parameter. + execution_horizon (int | None): Execution horizon parameter. + **metadata: Additional metadata to store. + """ + if not self.enabled: + return + + # Convert time to float and round to avoid float precision issues + time_value = time.item() if isinstance(time, Tensor) else time + time_key = round(time_value, 6) # Use rounded time as dictionary key + + # Check if step with this time already exists + if time_key in self._steps: + # Update existing step with non-None fields + existing_step = self._steps[time_key] + if x_t is not None: + existing_step.x_t = x_t.detach().clone() + if v_t is not None: + existing_step.v_t = v_t.detach().clone() + if x1_t is not None: + existing_step.x1_t = x1_t.detach().clone() + if correction is not None: + existing_step.correction = correction.detach().clone() + if err is not None: + existing_step.err = err.detach().clone() + if weights is not None: + existing_step.weights = weights.detach().clone() + if guidance_weight is not None: + existing_step.guidance_weight = guidance_weight + if inference_delay is not None: + existing_step.inference_delay = inference_delay + if execution_horizon is not None: + existing_step.execution_horizon = execution_horizon + if metadata: + existing_step.metadata.update(metadata) + else: + # Create new step + step = DebugStep( + step_idx=self._step_counter, + x_t=x_t.detach().clone() if x_t is not None else None, + v_t=v_t.detach().clone() if v_t is not None else None, + x1_t=x1_t.detach().clone() if x1_t is not None else None, + correction=correction.detach().clone() if correction is not None else None, + err=err.detach().clone() if err is not None else None, + weights=weights.detach().clone() if weights is not None else None, + guidance_weight=guidance_weight, + time=time_value, + inference_delay=inference_delay, + execution_horizon=execution_horizon, + metadata=metadata, + ) + + # Add to dictionary + self._steps[time_key] = step + self._step_counter += 1 + + # Enforce maxlen if set + if self._maxlen is not None and len(self._steps) > 
self._maxlen: + # Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order) + oldest_key = next(iter(self._steps)) + del self._steps[oldest_key] + + def get_recent_steps(self, n: int = 1) -> list[DebugStep]: + """Get the n most recent debug steps. + + Args: + n (int): Number of recent steps to retrieve. + + Returns: + List of DebugStep objects (may be empty if disabled or no steps recorded). + """ + if not self.enabled or self._steps is None: + return [] + + # Get all values and return the last n + all_steps = list(self._steps.values()) + return all_steps[-n:] + + def get_all_steps(self) -> list[DebugStep]: + """Get all recorded debug steps. + + Returns: + List of all DebugStep objects (may be empty if disabled). + """ + if not self.enabled or self._steps is None: + return [] + + return list(self._steps.values()) + + def get_step_stats_summary(self) -> dict[str, Any]: + """Get summary statistics across all recorded steps. + + Returns: + Dictionary containing aggregate statistics. + """ + if not self.enabled or self._steps is None or len(self._steps) == 0: + return {"enabled": self.enabled, "total_steps": 0} + + # Aggregate statistics from dictionary values + corrections = [s.correction for s in self._steps.values() if s.correction is not None] + errors = [s.err for s in self._steps.values() if s.err is not None] + guidance_weights = [s.guidance_weight for s in self._steps.values() if s.guidance_weight is not None] + + summary = { + "enabled": self.enabled, + "total_steps": len(self._steps), + "step_counter": self._step_counter, + } + + if corrections: + correction_norms = torch.tensor([c.norm().item() for c in corrections]) + summary["correction_norms"] = { + "mean": correction_norms.mean().item(), + "std": correction_norms.std().item(), + "min": correction_norms.min().item(), + "max": correction_norms.max().item(), + } + + if errors: + error_norms = torch.tensor([e.norm().item() for e in errors]) + summary["error_norms"] = { + "mean": error_norms.mean().item(), + "std": error_norms.std().item(), + "min": error_norms.min().item(), + "max": error_norms.max().item(), + } + + if guidance_weights: + gw_tensor = torch.tensor([gw.item() if isinstance(gw, Tensor) else gw for gw in guidance_weights]) + summary["guidance_weights"] = { + "mean": gw_tensor.mean().item(), + "std": gw_tensor.std().item(), + "min": gw_tensor.min().item(), + "max": gw_tensor.max().item(), + } + + return summary + + def export_to_dict(self, include_tensors: bool = False) -> dict[str, Any]: + """Export all debug information to a dictionary. + + Args: + include_tensors (bool): If True, include full tensor values. If False, + only include tensor statistics. + + Returns: + Dictionary containing all debug information. + """ + if not self.enabled or self._steps is None: + return {"enabled": False, "steps": []} + + return { + "enabled": True, + "total_steps": len(self._steps), + "step_counter": self._step_counter, + "steps": [step.to_dict(include_tensors=include_tensors) for step in self._steps.values()], + } + + def __len__(self) -> int: + """Return the number of recorded debug steps.""" + if not self.enabled or self._steps is None: + return 0 + return len(self._steps) + + @staticmethod + def tensor_stats(tensor: Tensor, name: str = "tensor") -> str: + """Generate readable statistics string for a tensor. 
+ + Args: + tensor: Input tensor + name: Name to display + + Returns: + Formatted string with shape and statistics + """ + if tensor is None: + return f"{name}: None" + + stats = ( + f"{name}: shape={tuple(tensor.shape)}, " + f"dtype={tensor.dtype}, " + f"device={tensor.device}, " + f"min={tensor.min().item():.4f}, " + f"max={tensor.max().item():.4f}, " + f"mean={tensor.mean().item():.4f}, " + f"std={tensor.std().item():.4f}" + ) + return stats diff --git a/src/lerobot/policies/rtc/debug_visualizer.py b/src/lerobot/policies/rtc/debug_visualizer.py new file mode 100644 index 000000000..a9c5ee86c --- /dev/null +++ b/src/lerobot/policies/rtc/debug_visualizer.py @@ -0,0 +1,460 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visualization utilities for RTC debug information.""" + +import matplotlib.pyplot as plt +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from lerobot.policies.rtc.debug_handler import Tracker + + +class RTCDebugVisualizer: + """Visualizer for RTC debug information. + + This class provides methods to visualize debug information collected by the Tracker, + including corrections, errors, weights, and guidance weights over denoising steps. + """ + + @staticmethod + def plot_waypoints( + axes, + tensor, + start_from: int = 0, + color: str = "blue", + label: str = "", + alpha: float = 0.7, + linewidth: float = 2, + marker: str | None = None, + markersize: int = 4, + ): + """Plot trajectories across multiple dimensions. + + This function plots a tensor's values across time for multiple dimensions, + with each dimension plotted on a separate axis. + + Args: + axes: Array of matplotlib axes (one for each dimension). + tensor: The tensor to plot (can be torch.Tensor or numpy array). + Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims). + start_from: Starting index for the x-axis. + color: Color for the plot lines. + label: Label for the plot legend. + alpha: Transparency level for the plot. + linewidth: Width of the plot lines. + marker: Marker style for data points (e.g., 'o', 's', '^'). + markersize: Size of the markers. 
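+
+        Example:
+            An illustrative call (the tensor and axes are arbitrary)::
+
+                import matplotlib.pyplot as plt
+                import torch
+
+                fig, axs = plt.subplots(6, 1)
+                RTCDebugVisualizer.plot_waypoints(axs, torch.randn(50, 6), color="blue", label="x_t")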
+ """ + import numpy as np + import torch + + # Handle None tensor + if tensor is None: + return + + # Convert tensor to numpy if needed + tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor + + # Handle different tensor shapes + if tensor_np.ndim == 3: + # If batch dimension present, take first batch + tensor_np = tensor_np[0] + elif tensor_np.ndim == 1: + # If 1D, reshape to (time_steps, 1) + tensor_np = tensor_np.reshape(-1, 1) + + # Get dimensions + time_steps, num_dims = tensor_np.shape + + # Create x-axis indices + x_indices = np.arange(start_from, start_from + time_steps) + + # Plot each dimension on its corresponding axis + num_axes = len(axes) if hasattr(axes, "__len__") else 1 + for dim_idx in range(min(num_dims, num_axes)): + ax = axes[dim_idx] if hasattr(axes, "__len__") else axes + + # Plot the trajectory + if marker: + ax.plot( + x_indices, + tensor_np[:, dim_idx], + color=color, + label=label if dim_idx == 0 else "", # Only show label once + alpha=alpha, + linewidth=linewidth, + marker=marker, + markersize=markersize, + ) + else: + ax.plot( + x_indices, + tensor_np[:, dim_idx], + color=color, + label=label if dim_idx == 0 else "", # Only show label once + alpha=alpha, + linewidth=linewidth, + ) + + # Add grid and labels if not already present + if not ax.xaxis.get_label().get_text(): + ax.set_xlabel("Step", fontsize=10) + if not ax.yaxis.get_label().get_text(): + ax.set_ylabel(f"Dim {dim_idx}", fontsize=10) + ax.grid(True, alpha=0.3) + + # Add legend if label provided and this is the first dimension + if label and dim_idx == 0: + ax.legend(loc="best", fontsize=8) + + @staticmethod + def plot_debug_summary( + tracker: Tracker, + save_path: str | None = None, + show: bool = False, + figsize: tuple[int, int] = (16, 12), + ) -> Figure: + """Create a comprehensive summary plot of debug information. + + Args: + tracker (Tracker): Tracker with recorded steps. + save_path (str | None): Path to save the figure. If None, figure is not saved. + show (bool): Whether to display the figure. + figsize (tuple[int, int]): Figure size in inches (width, height). + + Returns: + Figure: The matplotlib figure object. 
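+
+        Example:
+            Illustrative; assumes ``tracker`` was populated during inference::
+
+                fig = RTCDebugVisualizer.plot_debug_summary(tracker, save_path="rtc_debug_summary.png")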
+ """ + if not tracker.enabled or len(tracker) == 0: + print("Tracker is disabled or has no recorded steps.") + return None + + steps = tracker.get_all_steps() + num_steps = len(steps) + + # Create figure with subplots + fig, axes = plt.subplots(3, 2, figsize=figsize) + fig.suptitle(f"RTC Debug Summary ({num_steps} steps)", fontsize=16, fontweight="bold") + + # Plot 1: Correction norms over steps + ax = axes[0, 0] + correction_norms = [step.correction.norm().item() for step in steps if step.correction is not None] + if correction_norms: + ax.plot(correction_norms, marker="o", linewidth=2, markersize=4) + ax.set_xlabel("Step Index", fontsize=12) + ax.set_ylabel("Correction Norm", fontsize=12) + ax.set_title("Correction Magnitude Over Steps", fontsize=13, fontweight="bold") + ax.grid(True, alpha=0.3) + + # Plot 2: Error norms over steps + ax = axes[0, 1] + error_norms = [step.err.norm().item() for step in steps if step.err is not None] + if error_norms: + ax.plot(error_norms, marker="o", linewidth=2, markersize=4, color="orange") + ax.set_xlabel("Step Index", fontsize=12) + ax.set_ylabel("Error Norm", fontsize=12) + ax.set_title("Error Magnitude Over Steps", fontsize=13, fontweight="bold") + ax.grid(True, alpha=0.3) + + # Plot 3: Guidance weights over steps + ax = axes[1, 0] + guidance_weights = [ + step.guidance_weight.item() if isinstance(step.guidance_weight, Tensor) else step.guidance_weight + for step in steps + if step.guidance_weight is not None + ] + if guidance_weights: + ax.plot(guidance_weights, marker="o", linewidth=2, markersize=4, color="green") + ax.set_xlabel("Step Index", fontsize=12) + ax.set_ylabel("Guidance Weight", fontsize=12) + ax.set_title("Guidance Weight Over Steps", fontsize=13, fontweight="bold") + ax.grid(True, alpha=0.3) + + # Plot 4: Time parameter over steps + ax = axes[1, 1] + times = [ + step.time.item() if isinstance(step.time, Tensor) else step.time + for step in steps + if step.time is not None + ] + if times: + ax.plot(times, marker="o", linewidth=2, markersize=4, color="purple") + ax.set_xlabel("Step Index", fontsize=12) + ax.set_ylabel("Time Parameter", fontsize=12) + ax.set_title("Time Parameter Over Steps", fontsize=13, fontweight="bold") + ax.grid(True, alpha=0.3) + + # Plot 5: Correction vs Error relationship + ax = axes[2, 0] + if correction_norms and error_norms: + ax.scatter(error_norms, correction_norms, alpha=0.6, s=50) + ax.set_xlabel("Error Norm", fontsize=12) + ax.set_ylabel("Correction Norm", fontsize=12) + ax.set_title("Correction vs Error", fontsize=13, fontweight="bold") + ax.grid(True, alpha=0.3) + + # Plot 6: Prefix attention weights visualization (last step) + ax = axes[2, 1] + last_step = steps[-1] + if last_step.weights is not None: + weights = last_step.weights.squeeze().cpu().numpy() + ax.plot(weights, linewidth=2, marker="o", markersize=4, color="red") + ax.set_xlabel("Time Index", fontsize=12) + ax.set_ylabel("Weight Value", fontsize=12) + ax.set_title("Prefix Attention Weights (Last Step)", fontsize=13, fontweight="bold") + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.1, 1.1) + + plt.tight_layout() + + if save_path: + fig.savefig(save_path, dpi=150, bbox_inches="tight") + print(f"Debug summary saved to {save_path}") + + if show: + plt.show() + else: + plt.close(fig) + + return fig + + @staticmethod + def plot_correction_heatmap( + tracker: Tracker, + save_path: str | None = None, + show: bool = False, + figsize: tuple[int, int] = (14, 8), + max_dims: int = 6, + ) -> Figure: + """Create a heatmap showing correction values across 
steps and action dimensions. + + Args: + tracker (Tracker): Tracker with recorded steps. + save_path (str | None): Path to save the figure. + show (bool): Whether to display the figure. + figsize (tuple[int, int]): Figure size in inches. + max_dims (int): Maximum number of action dimensions to visualize. + + Returns: + Figure: The matplotlib figure object. + """ + if not tracker.enabled or len(tracker) == 0: + print("Tracker is disabled or has no recorded steps.") + return None + + steps = tracker.get_all_steps() + + # Collect corrections across steps (shape: [num_steps, time, action_dim]) + corrections = [step.correction for step in steps if step.correction is not None] + if not corrections: + print("No corrections found in debug steps.") + return None + + # Stack corrections: [num_steps, time, action_dim] + # Take mean over time dimension and limit action dims + corrections_stacked = torch.stack(corrections) # [num_steps, batch, time, action_dim] + corrections_mean = corrections_stacked.mean(dim=(1, 2)) # [num_steps, action_dim] + + # Limit to max_dims + corrections_mean = corrections_mean[:, :max_dims].cpu().numpy() + + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(corrections_mean.T, aspect="auto", cmap="RdBu_r", interpolation="nearest") + + ax.set_xlabel("Step Index", fontsize=12) + ax.set_ylabel("Action Dimension", fontsize=12) + ax.set_title("Correction Values Heatmap (averaged over time)", fontsize=14, fontweight="bold") + + # Colorbar + cbar = plt.colorbar(im, ax=ax) + cbar.set_label("Correction Value", fontsize=12) + + plt.tight_layout() + + if save_path: + fig.savefig(save_path, dpi=150, bbox_inches="tight") + print(f"Correction heatmap saved to {save_path}") + + if show: + plt.show() + else: + plt.close(fig) + + return fig + + @staticmethod + def plot_step_by_step_comparison( + tracker: Tracker, + step_idx: int = -1, + save_path: str | None = None, + show: bool = False, + figsize: tuple[int, int] = (18, 10), + max_dims: int = 6, + ) -> Figure: + """Plot detailed comparison for a single denoising step. + + Args: + tracker (Tracker): Tracker with recorded steps. + step_idx (int): Step index to visualize (-1 for last step). + save_path (str | None): Path to save the figure. + show (bool): Whether to display the figure. + figsize (tuple[int, int]): Figure size in inches. + max_dims (int): Maximum number of action dimensions to visualize. + + Returns: + Figure: The matplotlib figure object. 
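+
+        Example:
+            Illustrative; assumes ``tracker`` was populated during inference::
+
+                fig = RTCDebugVisualizer.plot_step_by_step_comparison(tracker, step_idx=-1)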
+ """ + if not tracker.enabled or len(tracker) == 0: + print("Tracker is disabled or has no recorded steps.") + return None + + steps = tracker.get_all_steps() + step = steps[step_idx] + + fig, axes = plt.subplots(2, 3, figsize=figsize) + fig.suptitle( + f"Detailed Step Analysis (Step {step.step_idx})", + fontsize=16, + fontweight="bold", + ) + + # Get tensors and squeeze batch dimension + x_t = step.x_t.squeeze(0).cpu().numpy() if step.x_t is not None else None + v_t = step.v_t.squeeze(0).cpu().numpy() if step.v_t is not None else None + x1_t = step.x1_t.squeeze(0).cpu().numpy() if step.x1_t is not None else None + correction = step.correction.squeeze(0).cpu().numpy() if step.correction is not None else None + err = step.err.squeeze(0).cpu().numpy() if step.err is not None else None + weights = step.weights.squeeze().cpu().numpy() if step.weights is not None else None + + # Limit to max_dims + num_dims = min(max_dims, x_t.shape[1] if x_t is not None else 0) + + # Plot 1: x_t (current state) + ax = axes[0, 0] + if x_t is not None: + for dim in range(num_dims): + ax.plot(x_t[:, dim], label=f"Dim {dim}", alpha=0.7) + ax.set_title("x_t (Current State)", fontsize=12, fontweight="bold") + ax.set_xlabel("Time Index") + ax.set_ylabel("Value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Plot 2: v_t (velocity) + ax = axes[0, 1] + if v_t is not None: + for dim in range(num_dims): + ax.plot(v_t[:, dim], label=f"Dim {dim}", alpha=0.7) + ax.set_title("v_t (Velocity)", fontsize=12, fontweight="bold") + ax.set_xlabel("Time Index") + ax.set_ylabel("Value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Plot 3: x1_t (predicted state) + ax = axes[0, 2] + if x1_t is not None: + for dim in range(num_dims): + ax.plot(x1_t[:, dim], label=f"Dim {dim}", alpha=0.7) + ax.set_title("x1_t (Predicted State)", fontsize=12, fontweight="bold") + ax.set_xlabel("Time Index") + ax.set_ylabel("Value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Plot 4: correction + ax = axes[1, 0] + if correction is not None: + for dim in range(num_dims): + ax.plot(correction[:, dim], label=f"Dim {dim}", alpha=0.7) + ax.set_title("Correction", fontsize=12, fontweight="bold") + ax.set_xlabel("Time Index") + ax.set_ylabel("Value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Plot 5: error + ax = axes[1, 1] + if err is not None: + for dim in range(num_dims): + ax.plot(err[:, dim], label=f"Dim {dim}", alpha=0.7) + ax.set_title("Error (Weighted)", fontsize=12, fontweight="bold") + ax.set_xlabel("Time Index") + ax.set_ylabel("Value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Plot 6: prefix weights + ax = axes[1, 2] + if weights is not None: + ax.plot(weights, linewidth=2, marker="o", markersize=4, color="red") + ax.set_title("Prefix Attention Weights", fontsize=12, fontweight="bold") + ax.set_xlabel("Time Index") + ax.set_ylabel("Weight Value") + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.1, 1.1) + + plt.tight_layout() + + if save_path: + fig.savefig(save_path, dpi=150, bbox_inches="tight") + print(f"Step-by-step comparison saved to {save_path}") + + if show: + plt.show() + else: + plt.close(fig) + + return fig + + @staticmethod + def print_debug_statistics(tracker: Tracker) -> None: + """Print summary statistics from the tracker. + + Args: + tracker (Tracker): Tracker with recorded steps. 
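+
+        Example:
+            Illustrative; assumes an ``RTCProcessor`` created with ``debug=True``::
+
+                RTCDebugVisualizer.print_debug_statistics(processor.tracker)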
+        """
+        if not tracker.enabled:
+            print("Tracker is disabled.")
+            return
+
+        stats = tracker.get_step_stats_summary()
+
+        print("\n" + "=" * 60)
+        print("RTC Debug Statistics Summary")
+        print("=" * 60)
+        print(f"Enabled: {stats['enabled']}")
+        print(f"Total steps recorded: {stats['total_steps']}")
+        print(f"Step counter: {stats['step_counter']}")
+
+        if "correction_norms" in stats:
+            print("\nCorrection Norms:")
+            for key, value in stats["correction_norms"].items():
+                print(f"  {key}: {value:.6f}")
+
+        if "error_norms" in stats:
+            print("\nError Norms:")
+            for key, value in stats["error_norms"].items():
+                print(f"  {key}: {value:.6f}")
+
+        if "guidance_weights" in stats:
+            print("\nGuidance Weights:")
+            for key, value in stats["guidance_weights"].items():
+                print(f"  {key}: {value:.6f}")
+
+        print("=" * 60 + "\n")
diff --git a/src/lerobot/policies/rtc/latency_tracker.py b/src/lerobot/policies/rtc/latency_tracker.py
new file mode 100644
index 000000000..e402cf152
--- /dev/null
+++ b/src/lerobot/policies/rtc/latency_tracker.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Latency tracking utilities for Real-Time Chunking (RTC)."""
+
+from collections import deque
+
+import numpy as np
+
+
+class LatencyTracker:
+    """Tracks recent latencies and provides max/percentile queries.
+
+    Args:
+        maxlen (int | None): Optional sliding window size. If provided, only the
+            most recent ``maxlen`` latencies are kept. If ``None``, keeps all.
+    """
+
+    def __init__(self, maxlen: int | None = 100):
+        self._values = deque(maxlen=maxlen)
+        self.reset()
+
+    def reset(self) -> None:
+        """Clear all recorded latencies."""
+        self._values.clear()
+        self.max_latency = 0.0
+
+    def add(self, latency: float) -> None:
+        """Add a latency sample in seconds; negative samples are ignored."""
+        # Coerce to float and drop invalid negative values
+        val = float(latency)
+
+        if val < 0:
+            return
+        self._values.append(val)
+        self.max_latency = max(self.max_latency, val)
+
+    def __len__(self) -> int:
+        return len(self._values)
+
+    def max(self) -> float:
+        """Return the running maximum latency since ``reset()`` (0.0 if empty)."""
+        return self.max_latency
+
+    def percentile(self, q: float) -> float:
+        """Return the q-quantile (q in [0, 1]) of the recorded latencies (0.0 if empty)."""
+        if not self._values:
+            return 0.0
+        q = float(q)
+        if q <= 0.0:
+            return min(self._values)
+        if q >= 1.0:
+            return self.max_latency
+        vals = np.array(list(self._values), dtype=np.float32)
+        return float(np.quantile(vals, q))
+
+    def p95(self) -> float:
+        """Return the 95th percentile latency (0.0 if empty)."""
+        return self.percentile(0.95)
diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py
new file mode 100644
index 000000000..041ef7e1e
--- /dev/null
+++ b/src/lerobot/policies/rtc/modeling_rtc.py
@@ -0,0 +1,325 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Real-Time Chunking (RTC) implementation for LeRobot. + +Based on Physical Intelligence's Kinetix implementation: +https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214 +""" + +import logging +import math + +import torch +from torch import Tensor + +from lerobot.configs.types import RTCAttentionSchedule +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.debug_handler import Tracker + +logger = logging.getLogger(__name__) + + +class RTCProcessor: + """Real-Time Chunking processor for action chunking policies. + + This class implements RTC techniques including velocity calculation, + prefix attention, and adaptive chunk processing. + """ + + def __init__(self, rtc_config: RTCConfig): + self.rtc_config = rtc_config + + self.tracker = None + + if rtc_config.debug: + self.tracker = Tracker( + enabled=rtc_config.debug, + maxlen=rtc_config.debug_maxlen, + ) + + # ====================== Tracker Proxy Methods ====================== + def track_debug( + self, + time: float | Tensor, + x_t: Tensor | None = None, + v_t: Tensor | None = None, + x1_t: Tensor | None = None, + correction: Tensor | None = None, + err: Tensor | None = None, + weights: Tensor | None = None, + guidance_weight: float | Tensor | None = None, + inference_delay: int | None = None, + execution_horizon: int | None = None, + **metadata, + ) -> None: + """Proxy method to track debug information. + + If tracker is None or disabled, this method does nothing. + Otherwise, it forwards the call to tracker.track(). + """ + if self.tracker is not None: + self.tracker.track( + time=time, + x_t=x_t, + v_t=v_t, + x1_t=x1_t, + correction=correction, + err=err, + weights=weights, + guidance_weight=guidance_weight, + inference_delay=inference_delay, + execution_horizon=execution_horizon, + **metadata, + ) + + def get_tracker_stats(self) -> dict | None: + """Get tracker statistics summary. + + Returns None if tracker is disabled or None. + """ + if self.tracker is not None: + return self.tracker.get_step_stats_summary() + return None + + def get_all_debug_steps(self) -> list: + """Get all debug steps from tracker. + + Returns empty list if tracker is disabled or None. + """ + if self.tracker is not None: + return self.tracker.get_all_steps() + return [] + + def get_recent_debug_steps(self, n: int = 1) -> list: + """Get recent debug steps from tracker. + + Returns empty list if tracker is disabled or None. + """ + if self.tracker is not None: + return self.tracker.get_recent_steps(n) + return [] + + def is_debug_enabled(self) -> bool: + """Check if debug tracking is enabled. + + Returns True if tracker exists and is enabled. + """ + return self.tracker is not None and self.tracker.enabled + + def reset_tracker(self) -> None: + """Reset the tracker, clearing all recorded steps. + + Does nothing if tracker is None. 
+        """
+        if self.tracker is not None:
+            self.tracker.reset()
+
+    def get_tracker_length(self) -> int:
+        """Get the number of recorded debug steps.
+
+        Returns 0 if tracker is disabled or None.
+        """
+        if self.tracker is not None:
+            return len(self.tracker)
+        return 0
+
+    # ====================== End Tracker Proxy Methods ======================
+
+    def denoise_step(
+        self,
+        x_t,
+        prev_chunk_left_over,
+        inference_delay,
+        time,
+        original_denoise_step_partial,
+        execution_horizon=None,
+    ) -> Tensor:
+        """RTC guidance wrapper around an existing denoiser.
+
+        This method wraps an original denoising callable that only takes ``x_t`` and
+        returns a base denoised velocity ``v_t``. It then applies Real-Time Chunking
+        (RTC) prefix guidance using the leftover prefix from the previous chunk.
+
+        Args:
+            x_t (Tensor): Current latent/state to denoise. Shape ``(B, T, A)`` or ``(T, A)``.
+            prev_chunk_left_over (Tensor | None): Unexecuted prefix from the previous
+                chunk. Shape ``(B, T_prev, A)`` or ``(T_prev, A)``. If ``None``, no guidance
+                is applied and the method returns ``v_t`` from the original denoiser.
+            inference_delay (int): Number of timesteps from the prefix to use for guidance.
+            time (float | Tensor): Scalar in [0, 1] indicating normalized time. Must be
+                broadcastable with ``x_t``.
+            original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
+                computes the base denoised velocity given only ``x_t``.
+            execution_horizon (int | None): Horizon used to build prefix weights. If
+                ``None``, defaults to ``self.rtc_config.execution_horizon``.
+
+        Returns:
+            Tensor: Guided velocity with the same shape as ``v_t``.
+
+        Notes:
+            - If inputs are 2D, a batch dimension is temporarily added and removed at the end.
+            - If ``prev_chunk_left_over`` is shorter than the current chunk length ``T``, it is
+              right-padded with zeros to match ``T``.
+            - Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)``
+              and broadcast to ``(B, T, A)``.
+            - Guidance correction is computed via autograd using ``x1_t = x_t - time * v_t`` and
+              ``error = (prev_chunk_left_over - x1_t) * weights``.
+            - The final guidance weight is clamped by ``max_guidance_weight`` from the config.
+
+        Reference:
+            https://www.physicalintelligence.company/download/real_time_chunking.pdf
+        """
+
+        # In the original implementation time goes from 0 to 1, while in our
+        # implementation it goes from 1 to 0, so we invert it here.
+        tau = 1 - time
+
+        x_t = x_t.clone().detach()
+
+        if prev_chunk_left_over is None:
+            # No leftover prefix (first chunk): no guidance, just return the base velocity
+            v_t = original_denoise_step_partial(x_t)
+            return v_t
+
+        squeezed = False
+        if len(x_t.shape) < 3:
+            # Add batch dimension
+            x_t = x_t.unsqueeze(0)
+            squeezed = True
+
+        if len(prev_chunk_left_over.shape) < 3:
+            # Add batch dimension
+            prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
+
+        if execution_horizon is None:
+            execution_horizon = self.rtc_config.execution_horizon
+
+        # If the previous action chunk is too short, a long execution horizon makes no
+        # sense because there is nothing left to merge beyond it
+        if execution_horizon > prev_chunk_left_over.shape[1]:
+            execution_horizon = prev_chunk_left_over.shape[1]
+
+        batch_size = x_t.shape[0]
+        action_chunk_size = x_t.shape[1]
+        action_dim = x_t.shape[2]
+
+        if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
+            # Right-pad the leftover prefix with zeros so it matches x_t in shape and dtype
+            padded = torch.zeros(batch_size, action_chunk_size, action_dim, device=x_t.device, dtype=x_t.dtype)
+            padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
+            prev_chunk_left_over = padded
+
+        assert prev_chunk_left_over.shape == x_t.shape, (
+            "The padded previous chunk must be the same size as the input tensor"
+        )
+
+        weights = (
+            self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
+            .to(x_t.device)
+            .unsqueeze(0)
+            .unsqueeze(-1)
+        )
+
+        with torch.enable_grad():
+            v_t = original_denoise_step_partial(x_t)
+            x_t.requires_grad_(True)
+
+            x1_t = x_t - time * v_t  # noqa: N806
+            err = (prev_chunk_left_over - x1_t) * weights
+            grad_outputs = err.clone().detach()
+            correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
+
+            max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
+            squared_one_minus_tau = (1 - tau) ** 2
+            inv_r2 = (squared_one_minus_tau + tau**2) / (squared_one_minus_tau)
+            c = torch.nan_to_num((1 - tau) / tau, posinf=max_guidance_weight)
+            guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
+            guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
+
+            result = v_t - guidance_weight * correction
+
+        # Remove the batch dimension if it was added
+        if squeezed:
+            result = result.squeeze(0)
+            correction = correction.squeeze(0)
+            x1_t = x1_t.squeeze(0)
+            err = err.squeeze(0)
+
+        # Record debug information (all params except x_t which is recorded externally)
+        self.track_debug(
+            time=time,
+            v_t=v_t,
+            x1_t=x1_t,
+            correction=correction,
+            err=err,
+            weights=weights,
+            guidance_weight=guidance_weight,
+            inference_delay=inference_delay,
+            execution_horizon=execution_horizon,
+        )
+
+        return result
+
+    def get_prefix_weights(self, start, end, total):
+        start = min(start, end)
+
+        if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS:
+            weights = torch.zeros(total)
+            weights[:start] = 1.0
+        elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES:
+            weights = torch.ones(total)
+            weights[end:] = 0.0
+        elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR:
+            lin_weights = self._linweights(start, end, total)
+            weights = self._add_trailing_zeros(lin_weights, total, end)
+            weights =
self._add_leading_ones(weights, start, total) + elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP: + lin_weights = self._linweights(start, end, total) + lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1) + weights = self._add_trailing_zeros(lin_weights, total, end) + weights = self._add_leading_ones(weights, start, total) + + return weights + + def _linweights(self, start, end, total): + skip_steps_at_end = max(total - end, 0) + + linspace_steps = total - skip_steps_at_end - start + + if end <= start or linspace_steps <= 0: + return torch.tensor([]) + + return torch.linspace(1, 0, linspace_steps + 2)[1:-1] + + def _add_trailing_zeros(self, weights, total, end): + zeros_len = total - end + + if zeros_len <= 0: + return weights + + zeros = torch.zeros(zeros_len) + return torch.cat([weights, zeros]) + + def _add_leading_ones(self, weights, start, total): + ones_len = min(start, total) + + if ones_len <= 0: + return weights + + ones = torch.ones(ones_len) + return torch.cat([ones, weights]) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 6e54d3ea5..dd49a45f7 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -55,11 +55,15 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") import math from collections import deque +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer +from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel from lerobot.policies.utils import ( @@ -68,6 +72,9 @@ from lerobot.policies.utils import ( from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.utils.utils import get_safe_dtype +# Make plot_waypoints easily accessible +plot_waypoints = RTCDebugVisualizer.plot_waypoints + def create_sinusoidal_pos_embedding( time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" @@ -232,8 +239,8 @@ class SmolVLAPolicy(PreTrainedPolicy): super().__init__(config) config.validate_features() self.config = config - - self.model = VLAFlowMatching(config) + self.init_rtc_processor() + self.model = VLAFlowMatching(config, rtc_processor=self.rtc_processor) self.reset() def reset(self): @@ -242,10 +249,27 @@ class SmolVLAPolicy(PreTrainedPolicy): ACTION: deque(maxlen=self.config.n_action_steps), } + def init_rtc_processor(self, verbose: bool = False): + """Initialize RTC processor with optional verbose logging. 
+
+        Args:
+            verbose: Enable verbose debug logging in RTCProcessor (currently unused).
+        """
+        self.rtc_processor = None
+
+        if self.config.rtc_config is not None and self.config.rtc_config.enabled:
+            self.rtc_processor = RTCProcessor(self.config.rtc_config)
+
+        # init_rtc_processor may be called again after the model exists (e.g. to toggle
+        # RTC on a loaded policy); propagate the processor to the model in that case.
+        # During normal __init__ the model does not exist yet, hence the guarded access.
+        if getattr(self, "model", None) is not None:
+            self.model.rtc_processor = self.rtc_processor
+
     def get_optim_params(self) -> dict:
         return self.parameters()
 
-    def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         # TODO: Check if this for loop is needed.
         # Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
         # In the case of offline inference, we have the action in the batch
@@ -260,7 +284,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
         lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
         lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
 
-        actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
+        actions = self.model.sample_actions(
+            images, img_masks, lang_tokens, lang_masks, state, noise=noise, **kwargs
+        )
 
         # Unpad actions
         original_action_dim = self.config.action_feature.shape[0]
@@ -278,30 +304,33 @@ class SmolVLAPolicy(PreTrainedPolicy):
         return batch
 
     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         self.eval()
 
         batch = self._prepare_batch(batch)
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
 
-        actions = self._get_action_chunk(batch, noise)
+        actions = self._get_action_chunk(batch, noise, **kwargs)
 
         return actions
 
     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Select a single action given environment observations.
 
         This method wraps `select_actions` in order to return one action at a time for execution in the
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
+
+        assert not self._rtc_enabled(), (
+            "RTC is not supported in select_action; use predict_action_chunk instead"
+        )
+
         self.eval()
 
         batch = self._prepare_batch(batch)
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
 
-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
- if len(self._queues[ACTION]) == 0: + if self._check_get_actions_condition(): actions = self._get_action_chunk(batch, noise) # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue @@ -310,6 +339,12 @@ class SmolVLAPolicy(PreTrainedPolicy): return self._queues[ACTION].popleft() + def _check_get_actions_condition(self) -> bool: + return len(self._queues[ACTION]) == 0 + + def _rtc_enabled(self) -> bool: + return self.config.rtc_config is not None and self.config.rtc_config.enabled + def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: """Do a full training forward pass to compute the loss""" if self.config.adapt_to_pi_aloha: @@ -471,7 +506,7 @@ class VLAFlowMatching(nn.Module): └──────────────────────────────┘ """ - def __init__(self, config: SmolVLAConfig): + def __init__(self, config: SmolVLAConfig, rtc_processor: RTCProcessor | None = None): super().__init__() self.config = config @@ -485,7 +520,6 @@ class VLAFlowMatching(nn.Module): num_vlm_layers=self.config.num_vlm_layers, self_attn_every_n_layers=self.config.self_attn_every_n_layers, expert_width_multiplier=self.config.expert_width_multiplier, - device=self.config.device, ) self.state_proj = nn.Linear( self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size @@ -510,6 +544,12 @@ class VLAFlowMatching(nn.Module): self.add_image_special_tokens = self.config.add_image_special_tokens self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) self.prefix_length = self.config.prefix_length + self.rtc_processor = rtc_processor + + # For visualization of x_t during denoising + self.denoise_step_counter = 0 + self.viz_fig = None + self.viz_axs = None def set_requires_grad(self): for params in self.state_proj.parameters(): @@ -706,11 +746,25 @@ class VLAFlowMatching(nn.Module): losses = F.mse_loss(u_t, v_t, reduction="none") return losses - def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: - """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + def sample_actions( + self, images, img_masks, lang_tokens, lang_masks, state, noise=None, **kwargs + ) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors) + + Args: + viz_xt_axs: Optional matplotlib axes for plotting x_t trajectories (array of 6 axes) + viz_vt_axs: Optional matplotlib axes for plotting v_t trajectories (array of 6 axes) + viz_x1t_axs: Optional matplotlib axes for plotting x1_t predicted state and error (array of 6 axes) + When RTC is enabled, plots both x1_t (solid line) and error (orange dashed line) + """ bsize = state.shape[0] device = state.device + # Extract visualization axes from kwargs + viz_xt_axs = kwargs.pop("viz_xt_axs", None) + viz_vt_axs = kwargs.pop("viz_vt_axs", None) + viz_x1t_axs = kwargs.pop("viz_x1t_axs", None) + if noise is None: actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) noise = self.sample_noise(actions_shape, device) @@ -734,17 +788,167 @@ class VLAFlowMatching(nn.Module): x_t = noise time = torch.tensor(1.0, dtype=torch.float32, device=device) + correction = None + x1_t = None + error = None + use_provided_axes = viz_xt_axs is not None and viz_vt_axs is not None + while time >= -dt / 2: expanded_time = time.expand(bsize) - v_t = self.denoise_step( - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) + + # Define a closure function to properly capture 
expanded_time
+            # This avoids the lambda expression (E731) and loop variable binding (B023) issues
+            def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
+                return self.denoise_step(
+                    x_t=input_x_t,
+                    prefix_pad_masks=prefix_pad_masks,
+                    past_key_values=past_key_values,
+                    timestep=current_timestep,
+                )
+
+            if self.config.rtc_config is not None and self.config.rtc_config.enabled:
+                inference_delay = kwargs.get("inference_delay")
+                prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
+                execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon)
+
+                v_t = self.rtc_processor.denoise_step(
+                    x_t=x_t,
+                    prev_chunk_left_over=prev_chunk_left_over,
+                    inference_delay=inference_delay,
+                    time=time,
+                    original_denoise_step_partial=denoise_step_partial_call,
+                    execution_horizon=execution_horizon,
+                )
+            else:
+                v_t = denoise_step_partial_call(x_t)
+
             # Euler step
             x_t += dt * v_t
             time += dt
+
+            # Record x_t after the Euler step (the remaining tensors are recorded inside
+            # rtc_processor.denoise_step); track_debug is a no-op unless debug is enabled
+            if self.config.rtc_config is not None and self.config.rtc_config.enabled:
+                self.rtc_processor.track_debug(time=time, x_t=x_t)
+
+            # Visualize x_t using plot_waypoints - accumulate all denoise steps
+            # Use provided axes or create new ones
+            if not use_provided_axes:
+                if self.viz_fig is None:
+                    # Create figure once on first denoise step
+                    self.viz_fig, self.viz_axs = plt.subplots(6, 1, figsize=(12, 12))
+                    self.viz_v_fig, self.viz_v_axs = plt.subplots(6, 1, figsize=(12, 12))
+                xt_axs = self.viz_axs
+                vt_axs = self.viz_v_axs
+            else:
+                xt_axs = viz_xt_axs
+                vt_axs = viz_vt_axs
+
+            # Define colors for different denoise steps (using a colormap)
+            colors = plt.cm.viridis(np.linspace(0, 1, self.config.num_steps))
+            color = colors[self.denoise_step_counter % len(colors)]
+
+            # Plot x_t for this denoise step
+            plot_waypoints(xt_axs, x_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}")
+
+            # Plot v_t for this denoise step
+            plot_waypoints(vt_axs, v_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}")
+
+            if correction is not None:
+                plot_waypoints(
+                    vt_axs,
+                    correction,
+                    start_from=0,
+                    color="red",
+                    label=f"Step corr {self.denoise_step_counter}",
+                )
+
+            # Plot x1_t if axes provided and RTC is enabled
+            if viz_x1t_axs is not None and x1_t is not None:
+                plot_waypoints(
+                    viz_x1t_axs,
+                    x1_t,
+                    start_from=0,
+                    color=color,
+                    label=f"x1_t Step {self.denoise_step_counter}",
+                )
+
+                # Plot error on the same axes with different color
+                if error is not None:
+                    # Use orange color for error
+                    # Handle batch dimension if present
+                    error_chunk = error[0].cpu().numpy() if len(error.shape) == 3 else error.cpu().numpy()
+
+                    num_dims = min(error_chunk.shape[-1], 6)
+                    for j in range(num_dims):
+                        viz_x1t_axs[j].plot(
+                            np.arange(0, error_chunk.shape[0]),
+                            error_chunk[:, j],
+                            color="orange",
+                            linestyle="--",
+                            alpha=0.7,
+                            label=f"error Step {self.denoise_step_counter}",
+                        )
+
+            self.denoise_step_counter += 1
+
+        # Save visualization of x_t denoise steps (only if using internal figures)
+        if not use_provided_axes and self.viz_fig is not None:
+            plt.figure(self.viz_fig.number)
+
+            xt_name = "smolvla_x_t_denoise_steps.png"
+            v_name = "smolvla_v_denoise_steps.png"
+
+            if self.config.rtc_config is not None and self.config.rtc_config.enabled:
+                xt_name = "smolvla_x_t_with_rtc_denoise_steps.png"
+                v_name = "smolvla_v_with_rtc_denoise_steps.png"
+
+            prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
+
+            if
prev_chunk_left_over is not None: + plot_waypoints( + self.viz_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) + + plt.savefig(xt_name) + plt.close(self.viz_fig) + + # Reset for next inference + self.viz_fig = None + self.viz_axs = None + self.denoise_step_counter = 0 + + plt.figure(self.viz_v_fig.number) + plt.savefig(v_name) + plt.close(self.viz_v_fig) + + self.viz_v_fig = None + self.viz_v_axs = None + + # Plot ground truth on provided axes if available + if use_provided_axes: + prev_chunk_left_over = kwargs.get("prev_chunk_left_over") + if ( + prev_chunk_left_over is not None + and self.config.rtc_config is not None + and self.config.rtc_config.enabled + ): + plot_waypoints( + viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) + # Also plot ground truth on x1_t axes if provided + if viz_x1t_axs is not None: + plot_waypoints( + viz_x1t_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) + + # Reset counter when using provided axes (for next call) + if use_provided_axes: + self.denoise_step_counter = 0 + return x_t def denoise_step(