mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
502 lines
18 KiB
Python
502 lines
18 KiB
Python
#!/usr/bin/env python
|
|
|
|
"""
Evaluate Real-Time Chunking (RTC) performance on dataset samples.

This script takes two random samples from a dataset:
    - Predicts an action chunk for the first sample and uses it as the previous chunk
    - Generates new actions for the second sample with and without RTC

It compares action predictions with and without RTC on dataset samples,
measuring consistency and ground truth alignment. The comparison is written as
denoising plots (x_t, v_t, correction, x1_t) to the output directory.

Usage:
    # Basic usage
    uv run python examples/rtc/eval_dataset.py \
        --policy.path=helper2424/smolvla_check_rtc_last3 \
        --dataset.repo_id=helper2424/check_rtc \
        --rtc.execution_horizon=8 \
        --device=mps

    # With torch.compile for faster inference on MPS (PyTorch 2.0+)
    uv run python examples/rtc/eval_dataset.py \
        --policy.path=helper2424/smolvla_check_rtc_last3 \
        --dataset.repo_id=helper2424/check_rtc \
        --rtc.execution_horizon=8 \
        --device=mps \
        --use_torch_compile=true \
        --torch_compile_mode=max-autotune

    # With torch.compile for faster inference on CUDA (PyTorch 2.0+)
    uv run python examples/rtc/eval_dataset.py \
        --policy.path=helper2424/smolvla_check_rtc_last3 \
        --dataset.repo_id=helper2424/check_rtc \
        --rtc.execution_horizon=8 \
        --device=cuda \
        --use_torch_compile=true \
        --torch_compile_mode=reduce-overhead

    # With custom compile settings
    uv run python examples/rtc/eval_dataset.py \
        --policy.path=helper2424/smolvla_check_rtc_last3 \
        --dataset.repo_id=helper2424/check_rtc \
        --use_torch_compile=true \
        --torch_compile_backend=inductor \
        --torch_compile_mode=max-autotune
"""
|
|
|
|
import logging
|
|
import os
|
|
import random
|
|
from dataclasses import dataclass, field
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.default import DatasetConfig
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
from lerobot.configs.types import RTCAttentionSchedule
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
|
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
|
from lerobot.utils.hub import HubMixin
|
|
from lerobot.utils.utils import init_logging
|
|
|
|
|
|
def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and torch, plus
    the CUDA and MPS generators when those backends report as available.
    cuDNN is additionally switched to deterministic mode with benchmarking off.

    Args:
        seed: Value handed to each generator.
    """
    # Host-side generators: stdlib, NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Accelerator generators are only touched when the backend is usable.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
|
@dataclass
class RTCEvalConfig(HubMixin):
    """Settings for one RTC evaluation run.

    ``__post_init__`` resolves the policy configuration from the ``--policy.path``
    command-line argument and picks a device when none (or ``"auto"``) was given.
    """

    # Policy to evaluate; populated in __post_init__ from --policy.path.
    policy: PreTrainedConfig | None = None

    # Dataset the evaluation samples are drawn from.
    dataset: DatasetConfig = field(default_factory=DatasetConfig)

    # RTC settings used for the "with RTC" run; debug tracking is on by default
    # because the plots are built from the tracked denoising steps.
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            enabled=True,
            execution_horizon=20,
            max_guidance_weight=10.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
            debug=True,
            debug_maxlen=1000,
        )
    )

    # Compute device; None and "auto" both trigger auto-detection.
    device: str | None = field(default=None, metadata={"help": "Device to run on (cuda, cpu, mps, auto)"})

    # Where the comparison figures are written.
    output_dir: str = field(
        default="rtc_debug_output", metadata={"help": "Directory to save debug visualizations"}
    )

    # Seed passed to the random number generators.
    seed: int = field(default=42, metadata={"help": "Random seed for reproducibility"})

    # Forwarded as ``inference_delay`` to the RTC-enabled prediction.
    inference_delay: int = field(default=4, metadata={"help": "Inference delay for RTC"})

    # torch.compile switches.
    use_torch_compile: bool = field(
        default=False, metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"}
    )
    torch_compile_backend: str = field(
        default="inductor", metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"}
    )
    torch_compile_mode: str = field(
        default="default", metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"}
    )

    def __post_init__(self):
        """Load the policy config from the CLI path and settle on a device.

        Raises:
            ValueError: If no ``--policy.path`` argument was supplied.
        """
        policy_path = parser.get_path_arg("policy")
        if not policy_path:
            raise ValueError("Policy path is required (--policy.path)")

        self.policy = PreTrainedConfig.from_pretrained(
            policy_path, cli_overrides=parser.get_cli_overrides("policy")
        )
        self.policy.pretrained_path = policy_path

        # An explicit device choice is respected as-is.
        if self.device not in (None, "auto"):
            return

        # Preference order: CUDA, then MPS, then CPU.
        if torch.cuda.is_available():
            detected = "cuda"
        elif torch.backends.mps.is_available():
            detected = "mps"
        else:
            detected = "cpu"
        self.device = detected
        logging.info(f"Auto-detected device: {self.device}")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """Name the fields the parser may load from a path (``--policy.path=local/dir``)."""
        return ["policy"]
|
|
|
|
|
|
class RTCEvaluator:
    """Compare a policy's denoising behaviour with and without RTC on dataset samples.

    Construction loads the policy, attaches the RTC configuration to it, loads the
    dataset and builds the pre/post-processors. ``run_evaluation`` then draws two
    random samples, predicts a "previous chunk" from the first one, runs the policy
    on the second one with RTC disabled and enabled, and writes side-by-side plots
    of the tracked denoising steps to ``cfg.output_dir``.
    """

    def __init__(self, cfg: "RTCEvalConfig"):
        """Load the policy, dataset and processors described by ``cfg``.

        Args:
            cfg: Evaluation settings. ``cfg.rtc`` is modified in place: ``enabled``
                and ``debug`` are forced to True before it is attached to the policy.
        """
        self.cfg = cfg
        self.device = cfg.device

        # Load policy
        logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
        policy_class = get_policy_class(cfg.policy.type)
        self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
        self.policy = self.policy.to(self.device)
        self.policy.eval()

        # Configure RTC
        cfg.rtc.enabled = True
        cfg.rtc.debug = True  # Enable debug tracking for visualization
        self.policy.config.rtc_config = cfg.rtc
        self.policy.init_rtc_processor()

        # Apply torch.compile if enabled
        if cfg.use_torch_compile:
            self._apply_torch_compile()

        logging.info(f"Policy loaded: {self.policy.name}")
        logging.info(f"RTC enabled: {cfg.rtc.enabled}")
        logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")

        # Load dataset
        logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
        # Requests 50 action offsets spaced 1/30 s apart.
        # NOTE(review): presumably this assumes a 30 fps dataset and a 50-step chunk -- confirm.
        self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
        logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")

        # Create preprocessor/postprocessor
        self.preprocessor, self.postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            preprocessor_overrides={
                "device_processor": {"device": self.device},
            },
        )

    def _apply_torch_compile(self):
        """Apply torch.compile to the policy model for faster inference.

        Best effort: any failure is logged and the evaluator continues with the
        uncompiled policy.
        """
        try:
            # Check if torch.compile is available (PyTorch 2.0+)
            if not hasattr(torch, "compile"):
                logging.warning(
                    "torch.compile is not available. Requires PyTorch 2.0+. "
                    f"Current version: {torch.__version__}. Skipping compilation."
                )
                return

            logging.info("Applying torch.compile to policy model...")
            logging.info(f" Backend: {self.cfg.torch_compile_backend}")
            logging.info(f" Mode: {self.cfg.torch_compile_mode}")

            # Compile the policy's model (not the policy itself to preserve methods)
            if hasattr(self.policy, "model"):
                self.policy.model = torch.compile(
                    self.policy.model,
                    backend=self.cfg.torch_compile_backend,
                    mode=self.cfg.torch_compile_mode,
                )
                logging.info("✓ Successfully compiled policy.model")
            else:
                logging.warning(
                    "Policy does not have a 'model' attribute. "
                    "Attempting to compile entire policy (may not work for all policy types)."
                )
                self.policy = torch.compile(
                    self.policy,
                    backend=self.cfg.torch_compile_backend,
                    mode=self.cfg.torch_compile_mode,
                )
                logging.info("✓ Successfully compiled policy")

        except Exception as e:
            # Deliberately broad: compilation is optional, so never abort the run over it.
            logging.error(f"Failed to apply torch.compile: {e}")
            logging.warning("Continuing without torch.compile")

    def run_evaluation(self):
        """Run evaluation on two random dataset samples.

        Raises:
            ValueError: If the dataset holds fewer than two samples.
        """
        # Fail with a clear message instead of a bare StopIteration from next() below.
        if len(self.dataset) < 2:
            raise ValueError(
                f"RTC evaluation needs at least two dataset samples, got {len(self.dataset)}"
            )

        # Create output directory
        os.makedirs(self.cfg.output_dir, exist_ok=True)
        logging.info(f"Output directory: {self.cfg.output_dir}")

        logging.info("Starting RTC evaluation")
        logging.info(f"Inference delay: {self.cfg.inference_delay}")

        data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
        loader_iter = iter(data_loader)
        first_sample = next(loader_iter)
        second_sample = next(loader_iter)

        preprocessed_first_sample = self.preprocessor(first_sample)
        preprocessed_second_sample = self.preprocessor(second_sample)

        # Don't postprocess the previous chunk.
        # Inference only, so run it under no_grad like the two predictions below.
        # Keeps the first 25 steps and drops the leading axis (batch size is 1).
        with torch.no_grad():
            prev_chunk_left_over = self.policy.predict_action_chunk(
                preprocessed_first_sample,
            )[:, :25, :].squeeze(0)

        # Discard whatever the prediction above recorded so only the runs below are plotted.
        self.policy.rtc_processor.reset_tracker()

        logging.info("Resetting tracker")

        # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
        noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
        noise = self.policy.model.sample_noise(noise_size, self.device)
        # Separate copy for the second run in case the first run modifies its noise in place.
        noise_clone = noise.clone()

        # Generate actions WITHOUT RTC
        logging.info("Generating actions WITHOUT RTC")
        self.policy.config.rtc_config.enabled = False
        with torch.no_grad():
            _ = self.policy.predict_action_chunk(
                preprocessed_second_sample,
                noise=noise,
            )

        # NOTE(review): this reads the tracker directly while the RTC run below uses
        # get_all_debug_steps(); presumably equivalent -- confirm and unify.
        no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps()
        self.policy.rtc_processor.reset_tracker()

        # Generate actions WITH RTC
        logging.info("Generating actions WITH RTC")
        self.policy.config.rtc_config.enabled = True
        with torch.no_grad():
            _ = self.policy.predict_action_chunk(
                preprocessed_second_sample,
                noise=noise_clone,
                inference_delay=self.cfg.inference_delay,
                prev_chunk_left_over=prev_chunk_left_over,
                execution_horizon=self.cfg.rtc.execution_horizon,
            )

        rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()

        self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
        logging.info("Evaluation completed successfully")

    def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
        """Plot both runs side by side and save the four comparison figures.

        Args:
            rtc_tracked_steps: Debug steps recorded during the RTC-enabled run.
            no_rtc_tracked_steps: Debug steps recorded during the RTC-disabled run.
            prev_chunk_left_over: Previous action chunk, overlaid in red on the
                x_t and x1_t figures.
        """
        # Create side-by-side figures for denoising visualization
        fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
        fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
        fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)")
        fig_x1t, axs_x1t = self._create_figure(
            "x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
        )

        num_steps = self.policy.config.num_steps

        # Column 1 (right) holds the RTC run, column 0 (left) the run without RTC.
        for steps, col in ((rtc_tracked_steps, 1), (no_rtc_tracked_steps, 0)):
            self._plot_denoising_steps_from_tracker(
                steps,
                axs_xt[:, col],
                axs_vt[:, col],
                axs_corr[:, col],
                axs_x1t[:, col],
                num_steps,
            )

        # Overlay the previous chunk on the x_t and x1_t axes of both columns.
        for col in (1, 0):
            for axs in (axs_xt, axs_x1t):
                RTCDebugVisualizer.plot_waypoints(
                    axs[:, col], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
                )

        # Save denoising plots
        self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
        self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
        self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png"))
        self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))

    def _create_figure(self, title):
        """Create a 6x2 grid of subplots with a column header on the top row.

        Args:
            title: Figure-level title.

        Returns:
            The ``(fig, axs)`` pair from ``plt.subplots``.
        """
        fig, axs = plt.subplots(6, 2, figsize=(24, 12))
        fig.suptitle(title, fontsize=16)

        # Only the top row carries the column header; the other rows get an empty title.
        for col, header in enumerate(("No RTC (N/A)", "RTC")):
            for row, ax in enumerate(axs[:, col]):
                ax.set_title(header if row == 0 else "", fontsize=12)

        return fig, axs

    def _save_figure(self, fig, path):
        """Write ``fig`` to ``path`` at 150 dpi and close it.

        Args:
            fig: Matplotlib figure to save.
            path: Destination file path.
        """
        fig.tight_layout()
        fig.savefig(path, dpi=150)
        logging.info(f"Saved figure to {path}")
        plt.close(fig)

    def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps):
        """Plot denoising steps from tracker data.

        Args:
            tracked_steps: Debug steps to plot; each is read for its ``x_t``, ``v_t``,
                ``correction``, ``x1_t`` and ``err`` attributes. Empty or None plots nothing.
            xt_axs: Matplotlib axes for x_t plots
            vt_axs: Matplotlib axes for v_t plots
            corr_axs: Matplotlib axes for correction plots
            x1t_axs: Matplotlib axes for x1_t plots, or None to skip x1_t and error curves
            num_steps: Total number of denoising steps for colormap
        """
        logging.info("=" * 80)
        debug_steps = tracked_steps or []
        logging.info(f"Plotting {len(debug_steps)} steps")

        if not debug_steps:
            return

        # One colour per denoise step; at least one entry so the modulo below is always defined.
        colors = plt.cm.viridis(np.linspace(0, 1, max(num_steps, 1)))

        for step_idx, debug_step in enumerate(debug_steps):
            # Wrap around when more steps were tracked than colours exist.
            color = colors[step_idx % len(colors)]

            # Plot x_t
            if debug_step.x_t is not None:
                RTCDebugVisualizer.plot_waypoints(
                    xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
                )

            # Plot v_t
            if debug_step.v_t is not None:
                RTCDebugVisualizer.plot_waypoints(
                    vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
                )

            # Plot correction on separate axes
            if debug_step.correction is not None:
                RTCDebugVisualizer.plot_waypoints(
                    corr_axs,
                    debug_step.correction,
                    start_from=0,
                    color=color,
                    label=f"Step {step_idx}",
                )

            if x1t_axs is None:
                continue

            # Plot x1_t (predicted state)
            if debug_step.x1_t is not None:
                RTCDebugVisualizer.plot_waypoints(
                    x1t_axs,
                    debug_step.x1_t,
                    start_from=0,
                    color=color,
                    label=f"x1_t Step {step_idx}",
                )

            # Plot error in orange dashed
            if debug_step.err is not None:
                err = debug_step.err
                # A rank-3 tensor carries a leading axis; plot its first entry only.
                if len(err.shape) == 3:
                    err = err[0]
                error_chunk = err.detach().cpu().numpy()

                # Never index past the axes that were actually supplied.
                num_dims = min(error_chunk.shape[-1], len(x1t_axs))
                steps_axis = np.arange(0, error_chunk.shape[0])
                for j in range(num_dims):
                    x1t_axs[j].plot(
                        steps_axis,
                        error_chunk[:, j],
                        color="orange",
                        linestyle="--",
                        alpha=0.7,
                        label=f"error Step {step_idx}",
                    )

        # Recalculate axis limits after plotting to ensure proper scaling
        for axes in (xt_axs, vt_axs, corr_axs, x1t_axs):
            self._rescale_axes(axes)

    def _rescale_axes(self, axes):
        """Rescale axes to show all data with proper margins.

        Args:
            axes: Iterable of matplotlib axes to rescale; None is accepted and ignored.
        """
        if axes is None:
            return

        for ax in axes:
            ax.relim()
            ax.autoscale_view()

            # Add 10% margin to y-axis for better visualization
            y_low, y_high = ax.get_ylim()
            y_range = y_high - y_low
            if y_range > 0:  # Leave degenerate (flat) ranges untouched
                margin = y_range * 0.1
                ax.set_ylim(y_low - margin, y_high + margin)
|
|
|
|
|
|
@parser.wrap()
def main(cfg: RTCEvalConfig):
    """Entry point: seed the RNGs, set up logging and run the RTC evaluation.

    Args:
        cfg: Evaluation settings, supplied through the ``parser.wrap`` decorator.
    """
    # Seed before anything else so the evaluator is built and run under a fixed seed.
    set_seed(cfg.seed)
    init_logging()

    banner = "=" * 80
    for line in (banner, "RTC Dataset Evaluation", f"Config: {cfg}", banner):
        logging.info(line)

    RTCEvaluator(cfg).run_evaluation()


if __name__ == "__main__":
    main()
|