#!/usr/bin/env python
"""
Evaluate Real-Time Chunking (RTC) performance on dataset samples.

This script takes two random samples from a dataset:
- Uses actions from the first sample as previous chunk
- Generates new actions for the second sample with and without RTC

It compares action predictions with and without RTC on dataset samples,
measuring consistency and ground truth alignment.

Usage:

# Basic usage
uv run python examples/rtc/eval_dataset.py \
    --policy.path=helper2424/smolvla_check_rtc_last3 \
    --dataset.repo_id=helper2424/check_rtc \
    --rtc.execution_horizon=8 \
    --device=mps

# With torch.compile for faster inference (PyTorch 2.0+) on MPS
uv run python examples/rtc/eval_dataset.py \
    --policy.path=helper2424/smolvla_check_rtc_last3 \
    --dataset.repo_id=helper2424/check_rtc \
    --rtc.execution_horizon=8 \
    --device=mps \
    --use_torch_compile=true \
    --torch_compile_mode=max-autotune

# With torch.compile for faster inference (PyTorch 2.0+) on CUDA
uv run python examples/rtc/eval_dataset.py \
    --policy.path=helper2424/smolvla_check_rtc_last3 \
    --dataset.repo_id=helper2424/check_rtc \
    --rtc.execution_horizon=8 \
    --device=cuda \
    --use_torch_compile=true \
    --torch_compile_mode=reduce-overhead

# With custom compile settings
uv run python examples/rtc/eval_dataset.py \
    --policy.path=helper2424/smolvla_check_rtc_last3 \
    --dataset.repo_id=helper2424/check_rtc \
    --use_torch_compile=true \
    --torch_compile_backend=inductor \
    --torch_compile_mode=max-autotune
"""

import logging
import os
import random
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import numpy as np
import torch

from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging


def set_seed(seed: int) -> None:
    """Set random seed for reproducibility.

    Seeds Python's `random`, NumPy, and torch (CPU, all CUDA devices, and MPS
    when available), and forces deterministic cuDNN kernels. Note that
    `cudnn.benchmark = False` trades some speed for run-to-run determinism.

    Args:
        seed: Seed value applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@dataclass
class RTCEvalConfig(HubMixin):
    """Configuration for RTC evaluation.

    Parsed from the CLI via `lerobot.configs.parser`; the `policy` field is
    loaded from `--policy.path` in `__post_init__` (see `__get_path_fields__`).
    """

    # Policy configuration (populated from --policy.path in __post_init__).
    policy: PreTrainedConfig | None = None

    # Dataset configuration
    dataset: DatasetConfig = field(default_factory=DatasetConfig)

    # RTC configuration — debug=True is required so the processor tracks
    # per-denoising-step data that run_evaluation() later plots.
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            enabled=True,
            execution_horizon=20,
            max_guidance_weight=10.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
            debug=True,
            debug_maxlen=1000,
        )
    )

    # Device configuration
    device: str | None = field(
        default=None,
        metadata={"help": "Device to run on (cuda, cpu, mps, auto)"},
    )

    # Output configuration
    output_dir: str = field(
        default="rtc_debug_output",
        metadata={"help": "Directory to save debug visualizations"},
    )

    # Seed configuration
    seed: int = field(
        default=42,
        metadata={"help": "Random seed for reproducibility"},
    )

    # Number of denoising steps the RTC pass is told have already "elapsed"
    # before the new chunk is generated.
    inference_delay: int = field(
        default=4,
        metadata={"help": "Inference delay for RTC"},
    )

    # Torch compile configuration
    use_torch_compile: bool = field(
        default=False,
        metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
    )
    torch_compile_backend: str = field(
        default="inductor",
        metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
    )
    torch_compile_mode: str = field(
        default="default",
        metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
    )

    def __post_init__(self) -> None:
        """Load the policy config from --policy.path and resolve the device.

        Raises:
            ValueError: If no --policy.path argument was provided.
        """
        # Parse policy path
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        else:
            raise ValueError("Policy path is required (--policy.path)")

        # Auto-detect device if not specified: prefer CUDA, then Apple MPS,
        # falling back to CPU.
        if self.device is None or self.device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
            logging.info(f"Auto-detected device: {self.device}")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]


class RTCEvaluator:
    """Evaluator for RTC on dataset samples."""

    def __init__(self, cfg: RTCEvalConfig):
        """Load the policy, enable RTC debug tracking, and open the dataset.

        Args:
            cfg: Fully-resolved evaluation configuration (policy path and
                device are expected to be set by RTCEvalConfig.__post_init__).
        """
        self.cfg = cfg
        self.device = cfg.device

        # Load policy
        logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
        policy_class = get_policy_class(cfg.policy.type)
        self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
        self.policy = self.policy.to(self.device)
        self.policy.eval()

        # Configure RTC — debug tracking must be on so the processor records
        # per-step tensors for the visualizations produced later.
        cfg.rtc.enabled = True
        cfg.rtc.debug = True  # Enable debug tracking for visualization
        self.policy.config.rtc_config = cfg.rtc
        self.policy.init_rtc_processor()

        # Apply torch.compile if enabled
        if cfg.use_torch_compile:
            self._apply_torch_compile()

        logging.info(f"Policy loaded: {self.policy.name}")
        logging.info(f"RTC enabled: {cfg.rtc.enabled}")
        logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")

        # Load dataset. delta_timestamps asks for 50 future action steps at
        # timestamps k/30 — assumes a 30 fps dataset; TODO confirm against
        # the dataset's actual fps.
        logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
        self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
        logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")

        # Create preprocessor/postprocessor; the device processor is
        # overridden so preprocessed batches land on the evaluation device.
        self.preprocessor, self.postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            preprocessor_overrides={
                "device_processor": {"device": self.device},
            },
        )

    def _apply_torch_compile(self) -> None:
        """Apply torch.compile to the policy model for faster inference.

        Compiles `self.policy.model` when that attribute exists (so the
        policy object keeps its Python-level methods), otherwise compiles the
        whole policy. Any failure is logged and evaluation continues
        uncompiled — compilation is a best-effort optimization here.
        """
        try:
            # Check if torch.compile is available (PyTorch 2.0+)
            if not hasattr(torch, "compile"):
                logging.warning(
                    "torch.compile is not available. Requires PyTorch 2.0+. "
                    f"Current version: {torch.__version__}. Skipping compilation."
                )
                return

            logging.info("Applying torch.compile to policy model...")
            logging.info(f"  Backend: {self.cfg.torch_compile_backend}")
            logging.info(f"  Mode: {self.cfg.torch_compile_mode}")

            # Compile the policy's model (not the policy itself to preserve methods)
            if hasattr(self.policy, "model"):
                original_model = self.policy.model
                compiled_model = torch.compile(
                    original_model,
                    backend=self.cfg.torch_compile_backend,
                    mode=self.cfg.torch_compile_mode,
                )
                self.policy.model = compiled_model
                logging.info("✓ Successfully compiled policy.model")
            else:
                logging.warning(
                    "Policy does not have a 'model' attribute. "
                    "Attempting to compile entire policy (may not work for all policy types)."
                )
                self.policy = torch.compile(
                    self.policy,
                    backend=self.cfg.torch_compile_backend,
                    mode=self.cfg.torch_compile_mode,
                )
                logging.info("✓ Successfully compiled policy")

        except Exception as e:
            # Deliberate best-effort: never let compilation failures abort the eval.
            logging.error(f"Failed to apply torch.compile: {e}")
            logging.warning("Continuing without torch.compile")

    def run_evaluation(self) -> None:
        """Run evaluation on two random dataset samples.

        Pipeline:
        1. Draw two shuffled samples; predict a chunk for the first and keep
           its first 25 steps as the "previous chunk left over".
        2. Predict a chunk for the second sample twice with identical noise —
           once with RTC disabled, once enabled — collecting the debug steps
           tracked during each denoising pass.
        3. Plot both traces side by side via plot_tracked_data().
        """
        # Create output directory
        os.makedirs(self.cfg.output_dir, exist_ok=True)
        logging.info(f"Output directory: {self.cfg.output_dir}")

        logging.info("Starting RTC evaluation")
        logging.info(f"Inference delay: {self.cfg.inference_delay}")

        data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
        loader_iter = iter(data_loader)
        first_sample = next(loader_iter)
        second_sample = next(loader_iter)

        preprocessed_first_sample = self.preprocessor(first_sample)
        preprocessed_second_sample = self.preprocessor(second_sample)

        # Don't postprocess the previous chunk — it must stay in the model's
        # (normalized) action space to be fed back as prev_chunk_left_over.
        prev_chunk_left_over = self.policy.predict_action_chunk(
            preprocessed_first_sample,
        )[:, :25, :].squeeze(0)

        # Drop debug steps recorded during the first prediction above.
        self.policy.rtc_processor.reset_tracker()
        logging.info("Resetting tracker")

        # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
        noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
        noise = self.policy.model.sample_noise(noise_size, self.device)
        noise_clone = noise.clone()

        # Generate actions WITHOUT RTC
        logging.info("Generating actions WITHOUT RTC")
        self.policy.config.rtc_config.enabled = False
        with torch.no_grad():
            _ = self.policy.predict_action_chunk(
                preprocessed_second_sample,
                noise=noise,
            )
        no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps()
        self.policy.rtc_processor.reset_tracker()

        # Generate actions WITH RTC
        logging.info("Generating actions WITH RTC")
        self.policy.config.rtc_config.enabled = True
        with torch.no_grad():
            _ = self.policy.predict_action_chunk(
                preprocessed_second_sample,
                noise=noise_clone,
                inference_delay=self.cfg.inference_delay,
                prev_chunk_left_over=prev_chunk_left_over,
                execution_horizon=self.cfg.rtc.execution_horizon,
            )
        rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()

        self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)

        logging.info("Evaluation completed successfully")

    def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over) -> None:
        """Render side-by-side comparison figures and save them as PNGs.

        Four figures are produced (x_t, v_t, correction, x1_t+error), each
        with the no-RTC trace in the left column and the RTC trace in the
        right column, with the previous-chunk actions overlaid in red as the
        "Ground truth" reference.

        Args:
            rtc_tracked_steps: Debug steps recorded during the RTC pass.
            no_rtc_tracked_steps: Debug steps recorded during the plain pass.
            prev_chunk_left_over: Action chunk from the first sample, used as
                the red reference trajectory on the x_t and x1_t plots.
        """
        # Create side-by-side figures for denoising visualization
        fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
        fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
        fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)")
        fig_x1t, axs_x1t = self._create_figure(
            "x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
        )

        num_steps = self.policy.config.num_steps

        self._plot_denoising_steps_from_tracker(
            rtc_tracked_steps,
            axs_xt[:, 1],  # Right column for x_t
            axs_vt[:, 1],  # Right column for v_t
            axs_corr[:, 1],  # Right column for correction
            axs_x1t[:, 1],  # Right column for x1_t
            num_steps,
        )

        self._plot_denoising_steps_from_tracker(
            no_rtc_tracked_steps,
            axs_xt[:, 0],  # Left column for x_t
            axs_vt[:, 0],  # Left column for v_t
            axs_corr[:, 0],  # Left column for correction
            axs_x1t[:, 0],  # Left column for x1_t
            num_steps,
        )

        # Plot ground truth on x_t axes (RTC column)
        RTCDebugVisualizer.plot_waypoints(
            axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
        )

        # Plot ground truth on x1_t axes (RTC column)
        RTCDebugVisualizer.plot_waypoints(
            axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
        )

        # Plot ground truth on x_t and x1_t axes (no-RTC column)
        RTCDebugVisualizer.plot_waypoints(
            axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
        )
        RTCDebugVisualizer.plot_waypoints(
            axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
        )

        # Save denoising plots
        self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
        self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
        self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png"))
        self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))

    def _create_figure(self, title):
        """Create a 6x2 grid of axes (one row per action dim, two columns).

        Only the top row gets a column title ("No RTC (N/A)" / "RTC"); the
        remaining rows get an empty title so the columns stay labeled once.

        Args:
            title: Figure-level suptitle.

        Returns:
            Tuple of (figure, 6x2 array of axes).
        """
        fig, axs = plt.subplots(6, 2, figsize=(24, 12))
        fig.suptitle(title, fontsize=16)
        for ax in axs[:, 0]:
            ax.set_title("No RTC (N/A)" if ax == axs[0, 0] else "", fontsize=12)
        for ax in axs[:, 1]:
            ax.set_title("RTC" if ax == axs[0, 1] else "", fontsize=12)
        return fig, axs

    def _save_figure(self, fig, path) -> None:
        """Tighten layout, save `fig` to `path` at 150 dpi, and close it."""
        fig.tight_layout()
        fig.savefig(path, dpi=150)
        logging.info(f"Saved figure to {path}")
        plt.close(fig)

    def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps) -> None:
        """Plot denoising steps from tracker data.

        Each denoising step is drawn in a distinct viridis color; the error
        trace (when present) is drawn on the x1_t axes in dashed orange.
        Axes are rescaled afterwards to fit all plotted data.

        Args:
            tracked_steps: List of DebugStep objects containing debug steps
            xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
            vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
            corr_axs: Matplotlib axes for correction plots (array of 6 axes)
            x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
            num_steps: Total number of denoising steps for colormap
        """
        logging.info("=" * 80)
        logging.info(f"Plotting {len(tracked_steps)} steps")
        debug_steps = tracked_steps

        if not debug_steps:
            return

        # Define colors for different denoise steps (using a colormap)
        colors = plt.cm.viridis(np.linspace(0, 1, num_steps))

        for step_idx, debug_step in enumerate(debug_steps):
            # Modulo guards against more tracked steps than num_steps colors.
            color = colors[step_idx % len(colors)]

            # Plot x_t
            if debug_step.x_t is not None:
                RTCDebugVisualizer.plot_waypoints(
                    xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
                )

            # Plot v_t
            if debug_step.v_t is not None:
                RTCDebugVisualizer.plot_waypoints(
                    vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
                )

            # Plot correction on separate axes
            if debug_step.correction is not None:
                RTCDebugVisualizer.plot_waypoints(
                    corr_axs,
                    debug_step.correction,
                    start_from=0,
                    color=color,
                    label=f"Step {step_idx}",
                )

            # Plot x1_t (predicted state)
            if x1t_axs is not None and debug_step.x1_t is not None:
                RTCDebugVisualizer.plot_waypoints(
                    x1t_axs,
                    debug_step.x1_t,
                    start_from=0,
                    color=color,
                    label=f"x1_t Step {step_idx}",
                )

            # Plot error in orange dashed. err may be batched (3-D) or not —
            # take the first batch element when it is.
            if x1t_axs is not None and debug_step.err is not None:
                error_chunk = (
                    debug_step.err[0].cpu().numpy()
                    if len(debug_step.err.shape) == 3
                    else debug_step.err.cpu().numpy()
                )
                # Cap at 6 dims — the figures only have 6 rows of axes.
                num_dims = min(error_chunk.shape[-1], 6)
                for j in range(num_dims):
                    x1t_axs[j].plot(
                        np.arange(0, error_chunk.shape[0]),
                        error_chunk[:, j],
                        color="orange",
                        linestyle="--",
                        alpha=0.7,
                        label=f"error Step {step_idx}",
                    )

        # Recalculate axis limits after plotting to ensure proper scaling
        self._rescale_axes(xt_axs)
        self._rescale_axes(vt_axs)
        self._rescale_axes(corr_axs)
        self._rescale_axes(x1t_axs)

    def _rescale_axes(self, axes) -> None:
        """Rescale axes to show all data with proper margins.

        Args:
            axes: Array of matplotlib axes to rescale
        """
        for ax in axes:
            ax.relim()
            ax.autoscale_view()
            # Add 10% margin to y-axis for better visualization
            ylim = ax.get_ylim()
            y_range = ylim[1] - ylim[0]
            if y_range > 0:  # Avoid division by zero
                margin = y_range * 0.1
                ax.set_ylim(ylim[0] - margin, ylim[1] + margin)


@parser.wrap()
def main(cfg: RTCEvalConfig):
    """Main entry point for RTC evaluation."""
    # Set random seed for reproducibility
    set_seed(cfg.seed)

    init_logging()
    logging.info("=" * 80)
    logging.info("RTC Dataset Evaluation")
    logging.info(f"Config: {cfg}")
    logging.info("=" * 80)

    evaluator = RTCEvaluator(cfg)
    evaluator.run_evaluation()


if __name__ == "__main__":
    main()