Files
lerobot/examples/rtc/eval_dataset.py
T
Eugene Mironov 2afe107583 Add Real-Time Chunking (RTC) support for flow matching models
Implement Real-Time Chunking (RTC) for action chunking policies using flow
matching denoising. RTC enables smooth action transitions between consecutive
chunks by using prefix guidance during denoising.

Key features:
- RTCProcessor class with denoise_step method for RTC guidance
- Tracker system for debug tracking using time-based dictionary storage
- RTCDebugVisualizer with comprehensive visualization utilities
- Integration with SmolVLA policy for flow matching models
- Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP)
- Configurable execution horizon and max guidance weight
- Example scripts for dataset evaluation and real-time control

Technical details:
- Uses autograd-based gradient computation for RTC corrections
- Time-based tracking eliminates duplicate step issues
- Proxy methods in RTCProcessor for cleaner API
- Full integration with LeRobot's policy and dataset systems

Files added/modified:
- src/lerobot/configs/types.py: Add RTCAttentionSchedule enum
- src/lerobot/policies/rtc/: Core RTC implementation
  - configuration_rtc.py: RTC configuration
  - modeling_rtc.py: RTCProcessor with denoise_step
  - debug_handler.py: Tracker for debug information
  - debug_visualizer.py: Visualization utilities
- src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration
- examples/rtc/: Example scripts and evaluation tools

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00

419 lines
15 KiB
Python

#!/usr/bin/env python
"""
Evaluate Real-Time Chunking (RTC) performance on dataset samples.
This script takes two random samples from a dataset:
- Uses actions from the first sample as previous chunk
- Generates new actions for the second sample with and without RTC
It compares action predictions with and without RTC on dataset samples,
measuring consistency and ground truth alignment.
Usage:
python eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps
"""
import logging
import random
from dataclasses import dataclass, field
import matplotlib.pyplot as plt
import numpy as np
import torch
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
from lerobot.utils.hub import HubMixin
# Root logging setup for the script: INFO level, with timestamp, logger name and
# level prepended to every message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function and class in this file.
logger = logging.getLogger(__name__)
def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    The CUDA and MPS generators are seeded only when the respective backend
    reports itself as available. cuDNN is switched to deterministic mode with
    benchmarking turned off, and the chosen seed is logged.

    Args:
        seed: Value passed to each seeding function.
    """
    # Host-side generators: Python, NumPy, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Accelerator generators, only for backends present on this machine.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)

    # cuDNN flags are set unconditionally.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    logger.info(f"Random seed set to: {seed}")
@dataclass
class RTCEvalConfig(HubMixin):
    """Command-line configuration for the RTC dataset evaluation script.

    The ``policy`` field is not filled from its default: ``__post_init__`` loads
    it from the path given on the command line as ``--policy.path`` and raises
    if that path is missing.
    """

    # Pretrained policy configuration; populated in __post_init__.
    policy: PreTrainedConfig | None = None

    # Dataset to draw the evaluation samples from.
    dataset: DatasetConfig = field(default_factory=DatasetConfig)

    # RTC settings handed to the policy; debug tracking is on by default here.
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            enabled=True,
            execution_horizon=20,
            max_guidance_weight=5.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
            debug=True,
            debug_maxlen=1000,
        )
    )

    # Target device; None or "auto" triggers auto-detection in __post_init__.
    device: str | None = field(
        metadata={"help": "Device to run on (cuda, cpu, mps, auto)"},
        default=None,
    )

    # Where debug visualizations are written.
    output_dir: str = field(
        metadata={"help": "Directory to save debug visualizations"},
        default="rtc_debug_output",
    )
    verbose: bool = field(
        metadata={"help": "Enable verbose logging"},
        default=False,
    )
    enable_debug_viz: bool = field(
        metadata={"help": "Enable debug visualization"},
        default=True,
    )

    # Seed passed to set_seed() by main().
    seed: int = field(
        metadata={"help": "Random seed for reproducibility"},
        default=42,
    )
    inference_delay: int = field(
        metadata={"help": "Inference delay for RTC"},
        default=4,
    )

    def __post_init__(self):
        """Load the policy config from ``--policy.path`` and resolve the device.

        Raises:
            ValueError: If no policy path was supplied on the command line.
        """
        policy_path = parser.get_path_arg("policy")
        if not policy_path:
            raise ValueError("Policy path is required (--policy.path)")

        # CLI overrides for the policy are applied on top of the loaded config.
        overrides = parser.get_cli_overrides("policy")
        self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=overrides)
        self.policy.pretrained_path = policy_path

        # Resolve the device only when the user did not pick one explicitly.
        if self.device in (None, "auto"):
            if torch.cuda.is_available():
                detected = "cuda"
            elif torch.backends.mps.is_available():
                detected = "mps"
            else:
                detected = "cpu"
            self.device = detected
            logger.info(f"Auto-detected device: {self.device}")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]
class RTCEvaluator:
    """Evaluator for RTC on dataset samples.

    Loads a pretrained policy (with its RTC processor), a dataset and the
    policy's pre/post-processors. ``run_evaluation`` then generates one action
    chunk with RTC disabled and one with RTC enabled from identical noise and
    saves side-by-side comparison plots.
    """

    def __init__(self, cfg: "RTCEvalConfig"):
        """Load the policy, dataset and processors described by ``cfg``.

        Args:
            cfg: Evaluation configuration. ``cfg.rtc.enabled`` is forced to
                ``True`` and ``cfg.rtc`` is installed as the policy's
                ``rtc_config`` before the RTC processor is initialised.
        """
        self.cfg = cfg
        self.device = cfg.device

        # Load policy
        logger.info(f"Loading policy from {cfg.policy.pretrained_path}")
        policy_class = get_policy_class(cfg.policy.type)
        self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
        self.policy = self.policy.to(self.device)
        self.policy.eval()

        # Configure RTC
        cfg.rtc.enabled = True
        self.policy.config.rtc_config = cfg.rtc
        self.policy.init_rtc_processor(verbose=cfg.verbose)
        logger.info(f"Policy loaded: {self.policy.name}")
        logger.info(f"RTC enabled: {cfg.rtc.enabled}")
        logger.info(f"Execution horizon: {cfg.rtc.execution_horizon}")

        # Load dataset
        logger.info(f"Loading dataset: {cfg.dataset.repo_id}")
        # NOTE(review): 50 action offsets spaced 1/30 apart are hard-coded here;
        # presumably a chunk size of 50 at 30 fps - confirm against the policy
        # and dataset actually being evaluated.
        self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
        logger.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")

        # Create preprocessor/postprocessor
        self.preprocessor, self.postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            preprocessor_overrides={
                "device_processor": {"device": self.device},
            },
        )

    def run_evaluation(self):
        """Run evaluation on two random dataset samples.

        The first sample's actions serve as the previous chunk; actions for the
        second sample are predicted twice (RTC off, RTC on) from the same noise.
        Writes ``denoising_xt_comparison.png``, ``denoising_vt_comparison.png``,
        ``denoising_x1t_comparison.png`` and ``final_actions_comparison.png`` to
        the current working directory, then the debug visualizations when
        enabled in the configuration.
        """
        logger.info("Starting RTC evaluation")
        logger.info(f"Inference delay: {self.cfg.inference_delay}")

        # Get two random samples from the dataset
        idx1, idx2 = random.sample(range(len(self.dataset)), 2)
        logger.info(f"Selected samples: {idx1}, {idx2}")

        # Get first sample - use its actions as prev_chunk
        preprocessed_sample1 = self._load_preprocessed_sample(idx1)
        # NOTE(review): index 0 drops the batch axis added while loading, and
        # `:25` slices the LAST axis. If the layout is (batch, time, action_dim)
        # this keeps every timestep and caps the action dims at 25 instead of
        # keeping the first 25 timesteps - confirm the intended slice.
        prev_chunk_left_over = preprocessed_sample1["action"][0, :, :25]
        logger.info(f"Using actions from sample {idx1} as previous chunk: shape={prev_chunk_left_over.shape}")

        # Get second sample - generate actions for this one
        preprocessed_sample2 = self._load_preprocessed_sample(idx2)
        logger.info(f"Generating actions for sample {idx2}")

        # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
        noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
        noise = self.policy.model.sample_noise(noise_size, self.device)
        noise_clone = noise.clone()

        # Create side-by-side figures for denoising visualization
        fig_xt, axs_xt = plt.subplots(6, 2, figsize=(24, 12))
        fig_xt.suptitle("x_t Denoising: No RTC (left) vs RTC (right)", fontsize=16)
        fig_vt, axs_vt = plt.subplots(6, 2, figsize=(24, 12))
        fig_vt.suptitle("v_t Denoising: No RTC (left) vs RTC (right)", fontsize=16)
        fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
        fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)

        # Generate actions WITHOUT RTC (plot on left column)
        logger.info("Generating actions WITHOUT RTC")
        self.policy.config.rtc_config.enabled = False
        with torch.no_grad():
            no_rtc_actions = self.policy.predict_action_chunk(
                preprocessed_sample2,
                noise=noise,
                inference_delay=self.cfg.inference_delay,
                prev_chunk_left_over=prev_chunk_left_over,
                viz_xt_axs=axs_xt[:, 0],  # Left column for x_t
                viz_vt_axs=axs_vt[:, 0],  # Left column for v_t
            )

        # Generate actions WITH RTC (plot on right column)
        logger.info("Generating actions WITH RTC")
        self.policy.config.rtc_config.enabled = True
        with torch.no_grad():
            rtc_actions = self.policy.predict_action_chunk(
                preprocessed_sample2,
                noise=noise_clone,
                inference_delay=self.cfg.inference_delay,
                prev_chunk_left_over=prev_chunk_left_over,
                execution_horizon=self.cfg.rtc.execution_horizon,
                viz_xt_axs=axs_xt[:, 1],  # Right column for x_t
                viz_vt_axs=axs_vt[:, 1],  # Right column for v_t
                viz_x1t_axs=axs_x1t[:, 1],  # Right column for x1_t
            )

        # Set titles for denoising plots
        self._set_column_titles(axs_xt, "No RTC", "RTC")
        self._set_column_titles(axs_vt, "No RTC", "RTC")
        self._set_column_titles(axs_x1t, "No RTC (N/A)", "RTC")

        # Save denoising plots
        self._save_and_close(fig_xt, "denoising_xt_comparison.png", "x_t denoising comparison")
        self._save_and_close(fig_vt, "denoising_vt_comparison.png", "v_t denoising comparison")
        self._save_and_close(
            fig_x1t, "denoising_x1t_comparison.png", "x1_t predicted state & error comparison"
        )

        # Create side-by-side comparison: No RTC (left) vs RTC (right)
        fig, axs = plt.subplots(6, 2, figsize=(24, 12))
        fig.suptitle("Final Action Comparison: No RTC (left) vs RTC (right)", fontsize=16)

        # The previous chunk already lost its batch axis when it was sliced, so
        # it must NOT be indexed with [0] again: that would keep a single
        # timestep and the "Previous Chunk" curve would collapse to one point.
        prev_chunk_np = self._to_2d_numpy(prev_chunk_left_over)

        # Plot on left column (No RTC)
        self._plot_actions(axs[:, 0], prev_chunk_np, self._to_2d_numpy(no_rtc_actions), "No RTC")
        # Plot on right column (RTC)
        self._plot_actions(axs[:, 1], prev_chunk_np, self._to_2d_numpy(rtc_actions), "RTC")

        # Save through the figure object (not pyplot's implicit current figure),
        # consistent with the three denoising figures above.
        self._save_and_close(fig, "final_actions_comparison.png", "final actions comparison")

        # Visualize debug information if enabled
        if self.cfg.enable_debug_viz and self.policy.rtc_processor is not None:
            self._visualize_debug_info()

        logger.info("Evaluation completed successfully")

    def _load_preprocessed_sample(self, idx):
        """Fetch dataset item ``idx``, batch its tensors on the device and preprocess it."""
        sample = self.dataset[idx]
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                # Add a leading batch axis and move the tensor to the target device.
                sample[key] = value.unsqueeze(0).to(self.device)
        return self.preprocessor(sample)

    @staticmethod
    def _to_2d_numpy(tensor):
        """Convert a tensor to NumPy, dropping a leading batch axis only if it has three axes."""
        array = tensor.detach().cpu().numpy()
        return array[0] if array.ndim == 3 else array

    @staticmethod
    def _set_column_titles(axs, left_title, right_title):
        """Title the top axis of each of the two columns and blank the titles below it."""
        for col, col_title in enumerate((left_title, right_title)):
            for row, ax in enumerate(axs[:, col]):
                ax.set_title(col_title if row == 0 else "", fontsize=12)

    @staticmethod
    def _save_and_close(fig, filename, description):
        """Tighten the layout, save ``fig`` to ``filename`` at 150 dpi, log it and close it."""
        fig.tight_layout()
        fig.savefig(filename, dpi=150)
        logger.info(f"Saved {description} to {filename}")
        plt.close(fig)

    def _plot_actions(self, axs, prev_chunk, predicted_actions, title):
        """Plot the previous chunk against predicted actions, one action dimension per axis.

        Args:
            axs: Indexable collection of axes; axis ``j`` receives dimension ``j``.
            prev_chunk: NumPy array of previous-chunk actions. A 1-D array is
                treated as a single row.
            predicted_actions: NumPy array of predicted actions. A 1-D array is
                treated as a single row.
            title: Legend label for the predicted curve; also used as the title
                of the first axis. RTC curves are drawn red, all others blue.
        """
        # Ensure arrays are 2D
        if prev_chunk.ndim == 1:
            prev_chunk = prev_chunk.reshape(1, -1)
        if predicted_actions.ndim == 1:
            predicted_actions = predicted_actions.reshape(1, -1)

        # A plain substring test would also match "No RTC" and colour the
        # baseline red as well, so titles starting with "no" are excluded.
        uses_rtc = "RTC" in title and not title.lstrip().lower().startswith("no")
        predicted_color = "red" if uses_rtc else "blue"

        # Limit to 6 dimensions, and never index past either array or the axes.
        num_dims = min(prev_chunk.shape[-1], predicted_actions.shape[-1], len(axs), 6)
        prev_steps = np.arange(prev_chunk.shape[0])
        predicted_steps = np.arange(predicted_actions.shape[0])

        for j in range(num_dims):
            ax = axs[j]
            ax.plot(prev_steps, prev_chunk[:, j], color="green", label="Previous Chunk")
            ax.plot(predicted_steps, predicted_actions[:, j], color=predicted_color, label=title)
            ax.set_ylabel("Joint angle", fontsize=14)
            ax.grid()
            ax.legend(loc="upper right", fontsize=14)
            ax.set_title(title if j == 0 else "", fontsize=12)
            # Label the x axis on the bottom plotted row only.
            if j == num_dims - 1:
                ax.set_xlabel("Step #", fontsize=16)

    def _visualize_debug_info(self):
        """Visualize debug information from the RTC processor.

        Skips with a warning when debug tracking is disabled or nothing was
        recorded. Otherwise prints tracker statistics and writes
        ``debug_summary.png``, ``correction_heatmap.png``,
        ``step_comparison_last.png`` and ``step_comparison_first.png`` into
        ``cfg.output_dir`` (created if missing).
        """
        import os

        rtc_processor = self.policy.rtc_processor

        # Use proxy method to check if debug is enabled
        if not rtc_processor.is_debug_enabled():
            logger.warning("Debug tracking is disabled. Skipping debug visualization.")
            return

        # Get tracker length using proxy method
        if rtc_processor.get_tracker_length() == 0:
            logger.warning("No debug steps recorded. Skipping debug visualization.")
            return

        # Create output directory
        output_dir = self.cfg.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving debug visualizations to {output_dir}")

        # Still need direct access to tracker for visualization functions
        # This is acceptable since RTCDebugVisualizer is part of the RTC package
        tracker = rtc_processor.tracker

        # Print statistics
        RTCDebugVisualizer.print_debug_statistics(tracker)

        # Plot debug summary
        RTCDebugVisualizer.plot_debug_summary(
            tracker,
            save_path=os.path.join(output_dir, "debug_summary.png"),
            show=False,
        )

        # Plot correction heatmap
        RTCDebugVisualizer.plot_correction_heatmap(
            tracker,
            save_path=os.path.join(output_dir, "correction_heatmap.png"),
            show=False,
        )

        # Plot step-by-step comparison (last step)
        RTCDebugVisualizer.plot_step_by_step_comparison(
            tracker,
            step_idx=-1,
            save_path=os.path.join(output_dir, "step_comparison_last.png"),
            show=False,
        )

        # Plot step-by-step comparison (first step). The tracker is known to be
        # non-empty here because the empty case returned early above.
        RTCDebugVisualizer.plot_step_by_step_comparison(
            tracker,
            step_idx=0,
            save_path=os.path.join(output_dir, "step_comparison_first.png"),
            show=False,
        )

        logger.info(f"Debug visualizations saved to {output_dir}")
@parser.wrap()
def main(cfg: RTCEvalConfig):
    """Script entry point: seed the RNGs, log the configuration, run the evaluation.

    Args:
        cfg: Evaluation configuration supplied by the ``parser.wrap`` decorator.
    """
    # Seed first so that everything downstream is reproducible.
    set_seed(cfg.seed)

    banner = "=" * 80
    logger.info(banner)
    logger.info("RTC Dataset Evaluation")
    logger.info(f"Config: {cfg}")
    logger.info(banner)

    RTCEvaluator(cfg).run_evaluation()


# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()