mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Add Real-Time Chunking (RTC) support for flow matching models
Implement Real-Time Chunking (RTC) for action chunking policies using flow matching denoising. RTC enables smooth action transitions between consecutive chunks by using prefix guidance during denoising. Key features: - RTCProcessor class with denoise_step method for RTC guidance - Tracker system for debug tracking using time-based dictionary storage - RTCDebugVisualizer with comprehensive visualization utilities - Integration with SmolVLA policy for flow matching models - Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP) - Configurable execution horizon and max guidance weight - Example scripts for dataset evaluation and real-time control Technical details: - Uses autograd-based gradient computation for RTC corrections - Time-based tracking eliminates duplicate step issues - Proxy methods in RTCProcessor for cleaner API - Full integration with LeRobot's policy and dataset systems Files added/modified: - src/lerobot/configs/types.py: Add RTCAttentionSchedule enum - src/lerobot/policies/rtc/: Core RTC implementation - configuration_rtc.py: RTC configuration - modeling_rtc.py: RTCProcessor with denoise_step - debug_handler.py: Tracker for debug information - debug_visualizer.py: Visualization utilities - src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration - examples/rtc/: Example scripts and evaluation tools 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,281 @@
|
||||
# Real-Time Chunking (RTC) Examples
|
||||
|
||||
This directory contains examples and evaluation scripts for Real-Time Chunking (RTC), a technique for improving action chunking policies in real-time robot control.
|
||||
|
||||
## Overview
|
||||
|
||||
Real-Time Chunking addresses the challenge of maintaining consistency and reactivity when using action chunking policies with non-negligible inference latency. It uses a guidance technique during diffusion sampling to blend new action predictions with previously planned actions.
|
||||
|
||||
**Key Benefits:**
|
||||
|
||||
- Maintains consistency between consecutive action chunks
|
||||
- Reduces jitter and improves smoothness
|
||||
- Adapts to inference delays dynamically
|
||||
|
||||
**Reference:** [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
|
||||
|
||||
## Scripts
|
||||
|
||||
### 1. `real_time_chunking_evaluate.py`
|
||||
|
||||
Real-time evaluation on physical robots or simulation environments.
|
||||
|
||||
**Features:**
|
||||
|
||||
- Run policy with RTC on real robot or simulation
|
||||
- Compare RTC vs non-RTC actions in real-time
|
||||
- Multi-threaded action execution and inference
|
||||
- Support for torch.compile() optimization
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# With real robot
|
||||
uv run python examples/rtc/real_time_chunking_evaluate.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--robot.type=so100 \
|
||||
--task="pick up the cup"
|
||||
|
||||
# With simulation environment
|
||||
uv run python examples/rtc/real_time_chunking_evaluate.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--env.type=pusht \
|
||||
--duration=60.0
|
||||
|
||||
# Disable verbose comparison (faster)
|
||||
uv run python examples/rtc/real_time_chunking_evaluate.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--robot.type=so100 \
|
||||
--verbose_rtc_comparison=false
|
||||
|
||||
# With policy compilation (CUDA only, not MPS)
|
||||
uv run python examples/rtc/real_time_chunking_evaluate.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--robot.type=so100 \
|
||||
--compile_policy=true \
|
||||
--compile_mode=max-autotune
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
|
||||
- `--policy.path`: Path to pretrained policy
|
||||
- `--robot.type` or `--env.type`: Robot or environment to use
|
||||
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10)
|
||||
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0)
|
||||
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
|
||||
- `--verbose_rtc_comparison`: Enable detailed RTC comparison logging (default: true)
|
||||
- `--duration`: How long to run (seconds, default: 30.0)
|
||||
- `--fps`: Action execution frequency (Hz, default: 10.0)
|
||||
|
||||
### 2. `evaluate_rtc_on_dataset.py`
|
||||
|
||||
Offline evaluation on dataset samples to measure RTC effectiveness.
|
||||
|
||||
**Features:**
|
||||
|
||||
- Evaluate RTC on dataset without running robot
|
||||
- Compare RTC vs non-RTC predictions
|
||||
- Measure consistency and ground truth alignment
|
||||
- Simulate different inference delays
|
||||
- Save detailed metrics to JSON
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# Basic evaluation
|
||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--num_iterations=100
|
||||
|
||||
# Simulate inference delay (every 3rd step)
|
||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--num_iterations=200 \
|
||||
--skip_steps=3
|
||||
|
||||
# Custom RTC configuration
|
||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--num_iterations=100 \
|
||||
--rtc.execution_horizon=12 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR
|
||||
|
||||
# Save results to file
|
||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--num_iterations=100 \
|
||||
--output_path=results/rtc_evaluation.json
|
||||
|
||||
# Verbose mode with detailed logging
|
||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--num_iterations=50 \
|
||||
--verbose=true
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
|
||||
- `--policy.path`: Path to pretrained policy
|
||||
- `--dataset.repo_id`: Dataset to evaluate on
|
||||
- `--num_iterations`: Number of samples to evaluate (default: 100)
|
||||
- `--skip_steps`: Steps to skip between inferences, simulates inference delay (default: 1)
|
||||
- `--start_episode`: Episode to start from (default: 0)
|
||||
- `--output_path`: Path to save results JSON
|
||||
- `--verbose`: Enable detailed per-sample logging
|
||||
- `--device`: Device to use (cuda, cpu, mps, auto)
|
||||
|
||||
**Metrics Reported:**
|
||||
|
||||
- **RTC vs Ground Truth MSE**: How close RTC predictions are to actual actions
|
||||
- **No-RTC vs Ground Truth MSE**: Baseline without RTC
|
||||
- **RTC Improvement**: Absolute and relative improvement over baseline
|
||||
- **RTC Consistency**: How well RTC maintains consistency in prefix region
|
||||
- Prefix MSE
|
||||
- Mean/Max error in overlap region
|
||||
|
||||
### 3. `run_dataset_evaluation.sh`
|
||||
|
||||
Convenience script with multiple evaluation scenarios.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# Edit the script to set your policy and dataset
|
||||
# Then run all examples:
|
||||
./examples/rtc/run_dataset_evaluation.sh
|
||||
|
||||
# Or run individual examples from the script
|
||||
```
|
||||
|
||||
## Understanding RTC Parameters
|
||||
|
||||
### `execution_horizon`
|
||||
|
||||
Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity.
|
||||
|
||||
**Typical values:** 8-12 steps
|
||||
|
||||
### `max_guidance_weight`
|
||||
|
||||
Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions.
|
||||
|
||||
**Typical values:** 1.0-10.0
|
||||
|
||||
### `prefix_attention_schedule`
|
||||
|
||||
How to weight consistency across the overlap region:
|
||||
|
||||
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
|
||||
- `ONES`: Full weight across entire execution_horizon
|
||||
- `LINEAR`: Linear decay from inference_delay to execution_horizon
|
||||
- `EXP`: Exponential decay (recommended)
|
||||
|
||||
**Recommended:** `EXP`
|
||||
|
||||
### `skip_steps` (evaluation only)
|
||||
|
||||
Simulates inference delay by evaluating every N-th step. This helps understand how RTC performs with realistic delays.
|
||||
|
||||
**Example:** `skip_steps=3` means policy infers every 3 steps, simulating 3x action execution frequency vs inference frequency.
|
||||
|
||||
## Output Format (Dataset Evaluation)
|
||||
|
||||
When using `--output_path`, results are saved in JSON format:
|
||||
|
||||
```json
|
||||
{
|
||||
"summary": {
|
||||
"rtc_vs_ground_truth_mse": {
|
||||
"mean": 0.00123,
|
||||
"std": 0.00045,
|
||||
"min": 0.00012,
|
||||
"max": 0.00456
|
||||
},
|
||||
"improvement": {
|
||||
"absolute": 0.00034,
|
||||
"relative_percent": 12.5
|
||||
},
|
||||
...
|
||||
},
|
||||
"config": {
|
||||
"num_iterations": 100,
|
||||
"skip_steps": 3,
|
||||
"execution_horizon": 10,
|
||||
...
|
||||
},
|
||||
"detailed_results": [
|
||||
{
|
||||
"sample_idx": 0,
|
||||
"rtc_vs_ground_truth_mse": 0.00112,
|
||||
"no_rtc_vs_ground_truth_mse": 0.00145,
|
||||
...
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Tips
|
||||
|
||||
1. **Start with dataset evaluation** to understand RTC behavior before running on robot
|
||||
2. **Use verbose mode** for debugging unexpected behavior
|
||||
3. **Tune execution_horizon** based on your inference latency and action frequency
|
||||
4. **Monitor consistency metrics** - very low consistency might indicate execution_horizon is too small
|
||||
5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### High RTC vs No-RTC difference but no improvement
|
||||
|
||||
- Try reducing `max_guidance_weight`
|
||||
- Check if `execution_horizon` is too large
|
||||
|
||||
### Poor consistency metrics
|
||||
|
||||
- Increase `execution_horizon`
|
||||
- Check that `skip_steps` is not larger than your action chunk size
|
||||
- Verify episodes are being reset correctly
|
||||
|
||||
### RTC worse than No-RTC
|
||||
|
||||
- RTC may not help if inference is faster than action execution
|
||||
- Try different `prefix_attention_schedule`
|
||||
- Ensure `execution_horizon` matches your use case
|
||||
|
||||
## Examples Results
|
||||
|
||||
Example output from dataset evaluation:
|
||||
|
||||
```
|
||||
================================================================================
|
||||
EVALUATION SUMMARY
|
||||
================================================================================
|
||||
|
||||
Ground Truth Alignment:
|
||||
RTC MSE: 0.001234 ± 0.000456
|
||||
No-RTC MSE: 0.001567 ± 0.000512
|
||||
|
||||
RTC Improvement:
|
||||
Absolute: 0.000333
|
||||
Relative: 21.23%
|
||||
|
||||
RTC vs No-RTC Difference:
|
||||
MSE: 0.000112 ± 0.000034
|
||||
|
||||
RTC Consistency (Prefix Region):
|
||||
MSE: 0.000089 ± 0.000023
|
||||
Mean Error: 0.007654 ± 0.002341
|
||||
Max Error: 0.023456 ± 0.008765
|
||||
```
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
|
||||
- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
|
||||
- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
|
||||
@@ -0,0 +1,418 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Evaluate Real-Time Chunking (RTC) performance on dataset samples.
|
||||
|
||||
This script takes two random samples from a dataset:
|
||||
- Uses actions from the first sample as previous chunk
|
||||
- Generates new actions for the second sample with and without RTC
|
||||
|
||||
It compares action predictions with and without RTC on dataset samples,
|
||||
measuring consistency and ground truth alignment.
|
||||
|
||||
Usage:
|
||||
python eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""Set random seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
logger.info(f"Random seed set to: {seed}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCEvalConfig(HubMixin):
|
||||
"""Configuration for RTC evaluation."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Dataset configuration
|
||||
dataset: DatasetConfig = field(default_factory=DatasetConfig)
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=20,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=True,
|
||||
debug_maxlen=1000,
|
||||
)
|
||||
)
|
||||
|
||||
# Device configuration
|
||||
device: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Device to run on (cuda, cpu, mps, auto)"},
|
||||
)
|
||||
|
||||
# Output configuration
|
||||
output_dir: str = field(
|
||||
default="rtc_debug_output",
|
||||
metadata={"help": "Directory to save debug visualizations"},
|
||||
)
|
||||
verbose: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable verbose logging"},
|
||||
)
|
||||
enable_debug_viz: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Enable debug visualization"},
|
||||
)
|
||||
|
||||
# Seed configuration
|
||||
seed: int = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed for reproducibility"},
|
||||
)
|
||||
|
||||
inference_delay: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Inference delay for RTC"},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Parse policy path
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required (--policy.path)")
|
||||
|
||||
# Auto-detect device if not specified
|
||||
if self.device is None or self.device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
self.device = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
logger.info(f"Auto-detected device: {self.device}")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
class RTCEvaluator:
|
||||
"""Evaluator for RTC on dataset samples."""
|
||||
|
||||
def __init__(self, cfg: RTCEvalConfig):
|
||||
self.cfg = cfg
|
||||
self.device = cfg.device
|
||||
|
||||
# Load policy
|
||||
logger.info(f"Loading policy from {cfg.policy.pretrained_path}")
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self.policy.eval()
|
||||
|
||||
# Configure RTC
|
||||
cfg.rtc.enabled = True
|
||||
self.policy.config.rtc_config = cfg.rtc
|
||||
self.policy.init_rtc_processor(verbose=cfg.verbose)
|
||||
|
||||
logger.info(f"Policy loaded: {self.policy.name}")
|
||||
logger.info(f"RTC enabled: {cfg.rtc.enabled}")
|
||||
logger.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
||||
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
|
||||
logger.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
||||
|
||||
# Create preprocessor/postprocessor
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": self.device},
|
||||
},
|
||||
)
|
||||
|
||||
def run_evaluation(self):
|
||||
"""Run evaluation on two random dataset samples."""
|
||||
logger.info("Starting RTC evaluation")
|
||||
logger.info(f"Inference delay: {self.cfg.inference_delay}")
|
||||
|
||||
# Get two random samples from the dataset
|
||||
idx1, idx2 = random.sample(range(len(self.dataset)), 2)
|
||||
logger.info(f"Selected samples: {idx1}, {idx2}")
|
||||
|
||||
# Get first sample - use its actions as prev_chunk
|
||||
sample1 = self.dataset[idx1]
|
||||
for key, value in sample1.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
sample1[key] = value.unsqueeze(0).to(self.device)
|
||||
|
||||
preprocessed_sample1 = self.preprocessor(sample1)
|
||||
prev_chunk_left_over = preprocessed_sample1["action"][0, :, :25]
|
||||
logger.info(f"Using actions from sample {idx1} as previous chunk: shape={prev_chunk_left_over.shape}")
|
||||
|
||||
# Get second sample - generate actions for this one
|
||||
sample2 = self.dataset[idx2]
|
||||
for key, value in sample2.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
sample2[key] = value.unsqueeze(0).to(self.device)
|
||||
|
||||
preprocessed_sample2 = self.preprocessor(sample2)
|
||||
logger.info(f"Generating actions for sample {idx2}")
|
||||
|
||||
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
||||
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
|
||||
noise = self.policy.model.sample_noise(noise_size, self.device)
|
||||
noise_clone = noise.clone()
|
||||
|
||||
# Create side-by-side figures for denoising visualization
|
||||
fig_xt, axs_xt = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig_xt.suptitle("x_t Denoising: No RTC (left) vs RTC (right)", fontsize=16)
|
||||
|
||||
fig_vt, axs_vt = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig_vt.suptitle("v_t Denoising: No RTC (left) vs RTC (right)", fontsize=16)
|
||||
|
||||
fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)
|
||||
|
||||
# Generate actions WITHOUT RTC (plot on left column)
|
||||
logger.info("Generating actions WITHOUT RTC")
|
||||
self.policy.config.rtc_config.enabled = False
|
||||
with torch.no_grad():
|
||||
no_rtc_actions = self.policy.predict_action_chunk(
|
||||
preprocessed_sample2,
|
||||
noise=noise,
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
|
||||
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
|
||||
)
|
||||
|
||||
# Generate actions WITH RTC (plot on right column)
|
||||
logger.info("Generating actions WITH RTC")
|
||||
self.policy.config.rtc_config.enabled = True
|
||||
with torch.no_grad():
|
||||
rtc_actions = self.policy.predict_action_chunk(
|
||||
preprocessed_sample2,
|
||||
noise=noise_clone,
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||
viz_xt_axs=axs_xt[:, 1], # Right column for x_t
|
||||
viz_vt_axs=axs_vt[:, 1], # Right column for v_t
|
||||
viz_x1t_axs=axs_x1t[:, 1], # Right column for x1_t
|
||||
)
|
||||
|
||||
# Set titles for denoising plots
|
||||
for ax in axs_xt[:, 0]:
|
||||
ax.set_title("No RTC" if ax == axs_xt[0, 0] else "", fontsize=12)
|
||||
for ax in axs_xt[:, 1]:
|
||||
ax.set_title("RTC" if ax == axs_xt[0, 1] else "", fontsize=12)
|
||||
|
||||
for ax in axs_vt[:, 0]:
|
||||
ax.set_title("No RTC" if ax == axs_vt[0, 0] else "", fontsize=12)
|
||||
for ax in axs_vt[:, 1]:
|
||||
ax.set_title("RTC" if ax == axs_vt[0, 1] else "", fontsize=12)
|
||||
|
||||
for ax in axs_x1t[:, 0]:
|
||||
ax.set_title("No RTC (N/A)" if ax == axs_x1t[0, 0] else "", fontsize=12)
|
||||
for ax in axs_x1t[:, 1]:
|
||||
ax.set_title("RTC" if ax == axs_x1t[0, 1] else "", fontsize=12)
|
||||
|
||||
# Save denoising plots
|
||||
fig_xt.tight_layout()
|
||||
fig_xt.savefig("denoising_xt_comparison.png", dpi=150)
|
||||
logger.info("Saved x_t denoising comparison to denoising_xt_comparison.png")
|
||||
plt.close(fig_xt)
|
||||
|
||||
fig_vt.tight_layout()
|
||||
fig_vt.savefig("denoising_vt_comparison.png", dpi=150)
|
||||
logger.info("Saved v_t denoising comparison to denoising_vt_comparison.png")
|
||||
plt.close(fig_vt)
|
||||
|
||||
fig_x1t.tight_layout()
|
||||
fig_x1t.savefig("denoising_x1t_comparison.png", dpi=150)
|
||||
logger.info("Saved x1_t predicted state & error comparison to denoising_x1t_comparison.png")
|
||||
plt.close(fig_x1t)
|
||||
|
||||
# Create side-by-side comparison: No RTC (left) vs RTC (right)
|
||||
fig, axs = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig.suptitle("Final Action Comparison: No RTC (left) vs RTC (right)", fontsize=16)
|
||||
|
||||
# Plot on left column (No RTC)
|
||||
self._plot_actions(
|
||||
axs[:, 0],
|
||||
prev_chunk_left_over[0].cpu().numpy(),
|
||||
no_rtc_actions[0].cpu().numpy(),
|
||||
"No RTC",
|
||||
)
|
||||
|
||||
# Plot on right column (RTC)
|
||||
self._plot_actions(
|
||||
axs[:, 1],
|
||||
prev_chunk_left_over[0].cpu().numpy(),
|
||||
rtc_actions[0].detach().cpu().numpy(),
|
||||
"RTC",
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig("final_actions_comparison.png", dpi=150)
|
||||
logger.info("Saved final actions comparison to final_actions_comparison.png")
|
||||
plt.close(fig)
|
||||
|
||||
# Visualize debug information if enabled
|
||||
if self.cfg.enable_debug_viz and self.policy.rtc_processor is not None:
|
||||
self._visualize_debug_info()
|
||||
|
||||
logger.info("Evaluation completed successfully")
|
||||
|
||||
def _plot_actions(self, axs, prev_chunk, predicted_actions, title):
|
||||
"""Plot actions comparison on given axes."""
|
||||
# Ensure arrays are 2D
|
||||
if prev_chunk.ndim == 1:
|
||||
prev_chunk = prev_chunk.reshape(1, -1)
|
||||
if predicted_actions.ndim == 1:
|
||||
predicted_actions = predicted_actions.reshape(1, -1)
|
||||
|
||||
for j in range(min(prev_chunk.shape[-1], 6)): # Limit to 6 dimensions
|
||||
axs[j].plot(
|
||||
np.arange(prev_chunk.shape[0]),
|
||||
prev_chunk[:, j],
|
||||
color="green",
|
||||
label="Previous Chunk",
|
||||
)
|
||||
axs[j].plot(
|
||||
np.arange(predicted_actions.shape[0]),
|
||||
predicted_actions[:, j],
|
||||
color="red" if "RTC" in title else "blue",
|
||||
label=title,
|
||||
)
|
||||
axs[j].set_ylabel("Joint angle", fontsize=14)
|
||||
axs[j].grid()
|
||||
axs[j].legend(loc="upper right", fontsize=14)
|
||||
axs[j].set_title(title if j == 0 else "", fontsize=12)
|
||||
if j == 2:
|
||||
axs[j].set_xlabel("Step #", fontsize=16)
|
||||
|
||||
def _visualize_debug_info(self):
|
||||
"""Visualize debug information from the RTC processor."""
|
||||
import os
|
||||
|
||||
# Use proxy method to check if debug is enabled
|
||||
if not self.policy.rtc_processor.is_debug_enabled():
|
||||
logger.warning("Debug tracking is disabled. Skipping debug visualization.")
|
||||
return
|
||||
|
||||
# Get tracker length using proxy method
|
||||
if self.policy.rtc_processor.get_tracker_length() == 0:
|
||||
logger.warning("No debug steps recorded. Skipping debug visualization.")
|
||||
return
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
||||
logger.info(f"Saving debug visualizations to {self.cfg.output_dir}")
|
||||
|
||||
# Still need direct access to tracker for visualization functions
|
||||
# This is acceptable since RTCDebugVisualizer is part of the RTC package
|
||||
tracker = self.policy.rtc_processor.tracker
|
||||
|
||||
# Print statistics
|
||||
RTCDebugVisualizer.print_debug_statistics(tracker)
|
||||
|
||||
# Plot debug summary
|
||||
summary_path = os.path.join(self.cfg.output_dir, "debug_summary.png")
|
||||
RTCDebugVisualizer.plot_debug_summary(
|
||||
tracker,
|
||||
save_path=summary_path,
|
||||
show=False,
|
||||
)
|
||||
|
||||
# Plot correction heatmap
|
||||
heatmap_path = os.path.join(self.cfg.output_dir, "correction_heatmap.png")
|
||||
RTCDebugVisualizer.plot_correction_heatmap(
|
||||
tracker,
|
||||
save_path=heatmap_path,
|
||||
show=False,
|
||||
)
|
||||
|
||||
# Plot step-by-step comparison (last step)
|
||||
step_path = os.path.join(self.cfg.output_dir, "step_comparison_last.png")
|
||||
RTCDebugVisualizer.plot_step_by_step_comparison(
|
||||
tracker,
|
||||
step_idx=-1,
|
||||
save_path=step_path,
|
||||
show=False,
|
||||
)
|
||||
|
||||
# Plot step-by-step comparison (first step)
|
||||
step_path_first = os.path.join(self.cfg.output_dir, "step_comparison_first.png")
|
||||
if self.policy.rtc_processor.get_tracker_length() > 0:
|
||||
RTCDebugVisualizer.plot_step_by_step_comparison(
|
||||
tracker,
|
||||
step_idx=0,
|
||||
save_path=step_path_first,
|
||||
show=False,
|
||||
)
|
||||
|
||||
logger.info(f"Debug visualizations saved to {self.cfg.output_dir}")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: RTCEvalConfig):
|
||||
"""Main entry point for RTC evaluation."""
|
||||
# Set random seed for reproducibility
|
||||
set_seed(cfg.seed)
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("RTC Dataset Evaluation")
|
||||
logger.info(f"Config: {cfg}")
|
||||
logger.info("=" * 80)
|
||||
|
||||
evaluator = RTCEvaluator(cfg)
|
||||
evaluator.run_evaluation()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,874 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies.
|
||||
|
||||
This script demonstrates:
|
||||
1. Creating a robot/environment and policy (SmolVLA, Pi0, etc.) with RTC
|
||||
2. Consuming actions from the policy while the robot/environment executes
|
||||
3. Periodically requesting new action chunks in the background using threads
|
||||
4. Managing action buffers and timing for real-time operation
|
||||
|
||||
Usage:
|
||||
# With real robot
|
||||
python rtc_demo.py --policy.path=lerobot/smolvla_base --robot.type=so100
|
||||
|
||||
# With simulation environment
|
||||
python rtc_demo.py --policy.path=lerobot/smolvla_base --env.type=pusht
|
||||
|
||||
# With config file
|
||||
python rtc_demo.py --config_path=path/to/config.json
|
||||
|
||||
# With policy compilation for faster inference (recommended for production)
|
||||
python rtc_demo.py --policy.path=lerobot/smolvla_base --robot.type=so100 --compile_policy=true
|
||||
|
||||
# With aggressive compilation for maximum speed
|
||||
python rtc_demo.py --policy.path=lerobot/smolvla_base --robot.type=so100 --compile_policy=true --compile_mode=max-autotune
|
||||
|
||||
Performance Notes:
|
||||
- torch.compile() is NOT supported on MPS (Apple Silicon) due to attention operation limitations
|
||||
- For MPS optimization, reduce num_steps in the policy config (biggest speedup)
|
||||
- CUDA devices will see 2-5x speedup with compilation enabled
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.envs.configs import EnvConfig # noqa: F401
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tensor_stats_str(tensor: Tensor | None, name: str = "tensor") -> str:
|
||||
"""Generate readable statistics string for a tensor."""
|
||||
if tensor is None:
|
||||
return f"{name}: None"
|
||||
|
||||
stats = (
|
||||
f"{name}:\n"
|
||||
f" shape={tuple(tensor.shape)}, dtype={tensor.dtype}, device={tensor.device}\n"
|
||||
f" min={tensor.min().item():.6f}, max={tensor.max().item():.6f}\n"
|
||||
f" mean={tensor.mean().item():.6f}, std={tensor.std().item():.6f}"
|
||||
)
|
||||
return stats
|
||||
|
||||
|
||||
def compare_tensors(tensor1: Tensor, tensor2: Tensor, name1: str = "tensor1", name2: str = "tensor2") -> str:
|
||||
"""Compare two tensors and return detailed difference statistics."""
|
||||
if tensor1 is None or tensor2 is None:
|
||||
return f"Cannot compare: {name1}={tensor1 is not None}, {name2}={tensor2 is not None}"
|
||||
|
||||
# Ensure same shape for comparison
|
||||
if tensor1.shape != tensor2.shape:
|
||||
return f"Shape mismatch: {name1}={tuple(tensor1.shape)} vs {name2}={tuple(tensor2.shape)}"
|
||||
|
||||
diff = tensor1 - tensor2
|
||||
abs_diff = torch.abs(diff)
|
||||
|
||||
# Per-timestep statistics
|
||||
if len(diff.shape) >= 2:
|
||||
# Shape is (batch, time, action_dim) or (time, action_dim)
|
||||
per_timestep_mean = abs_diff.mean(dim=-1) # Average across action dimensions
|
||||
|
||||
timestep_stats = "\n Per-timestep abs diff (averaged across action dims):\n"
|
||||
if len(per_timestep_mean.shape) > 1:
|
||||
# Has batch dimension
|
||||
for batch_idx in range(per_timestep_mean.shape[0]):
|
||||
timestep_stats += f" Batch {batch_idx}: ["
|
||||
for t in range(min(10, per_timestep_mean.shape[1])): # Show first 10 timesteps
|
||||
timestep_stats += f"{per_timestep_mean[batch_idx, t].item():.6f}, "
|
||||
if per_timestep_mean.shape[1] > 10:
|
||||
timestep_stats += "..."
|
||||
timestep_stats += "]\n"
|
||||
else:
|
||||
timestep_stats += " ["
|
||||
for t in range(min(10, len(per_timestep_mean))):
|
||||
timestep_stats += f"{per_timestep_mean[t].item():.6f}, "
|
||||
if len(per_timestep_mean) > 10:
|
||||
timestep_stats += "..."
|
||||
timestep_stats += "]\n"
|
||||
else:
|
||||
timestep_stats = ""
|
||||
|
||||
result = (
|
||||
f"\nDifference: {name1} - {name2}:\n"
|
||||
f" abs_diff: min={abs_diff.min().item():.6f}, max={abs_diff.max().item():.6f}\n"
|
||||
f" abs_diff: mean={abs_diff.mean().item():.6f}, std={abs_diff.std().item():.6f}\n"
|
||||
f" relative_diff: mean={abs_diff.mean().item() / (torch.abs(tensor2).mean().item() + 1e-8) * 100:.2f}%"
|
||||
f"{timestep_stats}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
class EnvWrapper:
|
||||
"""Wrapper for gym environments to provide same interface as RobotWrapper."""
|
||||
|
||||
def __init__(self, env, env_cfg: EnvConfig):
|
||||
self.env = env
|
||||
self.env_cfg = env_cfg
|
||||
self.lock = Lock()
|
||||
self._last_obs = None
|
||||
self._episode_count = 0
|
||||
self._step_count = 0
|
||||
|
||||
# Initialize environment
|
||||
obs, _ = self.env.reset()
|
||||
self._last_obs = (
|
||||
obs[0]
|
||||
if isinstance(obs, tuple)
|
||||
or (hasattr(obs, "__getitem__") and len(obs) > 0 and not isinstance(obs, dict))
|
||||
else obs
|
||||
)
|
||||
|
||||
# Cache feature names
|
||||
self._observation_features = None
|
||||
self._action_features = None
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
"""Get current observation from environment.
|
||||
|
||||
Returns observations in the same format as robot.get_observation():
|
||||
a dict mapping feature names to numpy arrays.
|
||||
"""
|
||||
with self.lock:
|
||||
if self._last_obs is None:
|
||||
# Reset environment on first observation
|
||||
obs, _ = self.env.reset()
|
||||
self._last_obs = (
|
||||
obs[0]
|
||||
if isinstance(obs, tuple)
|
||||
or (hasattr(obs, "__getitem__") and len(obs) > 0 and not isinstance(obs, dict))
|
||||
else obs
|
||||
)
|
||||
|
||||
# VectorEnv returns observations as numpy arrays in a batch
|
||||
# Extract first element if it's a vectorized observation
|
||||
obs = self._last_obs
|
||||
if isinstance(obs, dict):
|
||||
# Handle dict observations (extract first element from batch if needed)
|
||||
result = {}
|
||||
for key, value in obs.items():
|
||||
if isinstance(value, np.ndarray) and len(value.shape) > 0 and value.shape[0] == 1:
|
||||
# Remove batch dimension for single env
|
||||
result[key] = value[0]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
else:
|
||||
# Handle array observations - shouldn't happen with our configs but handle it
|
||||
return {"observation": obs[0] if len(obs.shape) > 1 else obs}
|
||||
|
||||
def send_action(self, action: dict):
|
||||
"""Execute action in environment and update observation."""
|
||||
with self.lock:
|
||||
# Convert action dict to array based on action_features
|
||||
action_list = []
|
||||
for feature_name in self.action_features():
|
||||
if feature_name in action:
|
||||
action_list.append(action[feature_name])
|
||||
|
||||
action_array = np.array(action_list)
|
||||
|
||||
# VectorEnv expects actions with batch dimension
|
||||
action_batch = action_array.reshape(1, -1)
|
||||
|
||||
# Step environment
|
||||
obs, _reward, terminated, truncated, _info = self.env.step(action_batch)
|
||||
|
||||
# Extract from batch
|
||||
self._last_obs = (
|
||||
obs[0]
|
||||
if isinstance(obs, tuple)
|
||||
or (hasattr(obs, "__getitem__") and len(obs) > 0 and not isinstance(obs, dict))
|
||||
else obs
|
||||
)
|
||||
self._step_count += 1
|
||||
|
||||
# Check if episode is done (handle vectorized env format)
|
||||
is_done = terminated[0] if isinstance(terminated, (np.ndarray, list)) else terminated
|
||||
is_truncated = truncated[0] if isinstance(truncated, (np.ndarray, list)) else truncated
|
||||
|
||||
# Reset if episode is done
|
||||
if is_done or is_truncated:
|
||||
logger.info(f"Episode {self._episode_count} finished after {self._step_count} steps")
|
||||
obs, _ = self.env.reset()
|
||||
self._last_obs = (
|
||||
obs[0]
|
||||
if isinstance(obs, tuple)
|
||||
or (hasattr(obs, "__getitem__") and len(obs) > 0 and not isinstance(obs, dict))
|
||||
else obs
|
||||
)
|
||||
self._episode_count += 1
|
||||
self._step_count = 0
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
"""Get observation feature names from environment config."""
|
||||
if self._observation_features is not None:
|
||||
return self._observation_features
|
||||
|
||||
with self.lock:
|
||||
features = []
|
||||
for feature_name in self.env_cfg.features:
|
||||
if feature_name != "action":
|
||||
# Use the mapped name from features_map
|
||||
mapped_name = self.env_cfg.features_map.get(feature_name, feature_name)
|
||||
features.append(mapped_name)
|
||||
|
||||
self._observation_features = features
|
||||
return features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
"""Get action feature names from environment config."""
|
||||
if self._action_features is not None:
|
||||
return self._action_features
|
||||
|
||||
with self.lock:
|
||||
# Return action dimension names
|
||||
action_dim = self.env_cfg.features["action"].shape[0]
|
||||
self._action_features = [f"action_{i}" for i in range(action_dim)]
|
||||
return self._action_features
|
||||
|
||||
|
||||
class ActionQueue:
|
||||
def __init__(self, cfg: RTCConfig):
|
||||
self.queue = None # Processed actions for robot rollout
|
||||
self.original_queue = None # Original actions for RTC
|
||||
self.lock = Lock()
|
||||
self.last_index = 0
|
||||
self.cfg = cfg
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
with self.lock:
|
||||
if self.queue is None or self.last_index >= len(self.queue):
|
||||
return None
|
||||
|
||||
action = self.queue[self.last_index]
|
||||
self.last_index += 1
|
||||
return action.clone()
|
||||
|
||||
def qsize(self) -> int:
|
||||
# with self.lock:
|
||||
if self.queue is None:
|
||||
return 0
|
||||
length = len(self.queue)
|
||||
|
||||
return length - self.last_index
|
||||
|
||||
def empty(self) -> bool:
|
||||
# with self.lock:
|
||||
if self.queue is None:
|
||||
return True
|
||||
|
||||
length = len(self.queue)
|
||||
return length - self.last_index + 1 <= 0
|
||||
|
||||
def get_action_index(self) -> int:
|
||||
# with self.lock:
|
||||
return self.last_index
|
||||
|
||||
def get_left_over(self) -> Tensor:
|
||||
"""Get left over ORIGINAL actions for RTC prev_chunk_left_over."""
|
||||
with self.lock:
|
||||
if self.original_queue is None:
|
||||
return None
|
||||
return self.original_queue[self.last_index :]
|
||||
|
||||
def merge(
|
||||
self,
|
||||
original_actions: Tensor,
|
||||
processed_actions: Tensor,
|
||||
real_delay: int,
|
||||
action_index_before_inference: int | None = 0,
|
||||
):
|
||||
with self.lock:
|
||||
self._check_delays(real_delay, action_index_before_inference)
|
||||
|
||||
if self.cfg.enabled:
|
||||
self._replace_actions_queue(original_actions, processed_actions, real_delay)
|
||||
return
|
||||
|
||||
self._append_actions_queue(original_actions, processed_actions)
|
||||
|
||||
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
|
||||
self.original_queue = original_actions[real_delay:].clone()
|
||||
self.queue = processed_actions[real_delay:].clone()
|
||||
|
||||
logger.info(f"original_actions shape: {self.original_queue.shape}")
|
||||
logger.info(f"processed_actions shape: {self.queue.shape}")
|
||||
logger.info(f"real_delay: {real_delay}")
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
|
||||
if self.queue is None:
|
||||
self.original_queue = original_actions.clone()
|
||||
self.queue = processed_actions.clone()
|
||||
return
|
||||
|
||||
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
|
||||
self.original_queue = self.original_queue[self.last_index :]
|
||||
|
||||
self.queue = torch.cat([self.queue, processed_actions.clone()])
|
||||
self.queue = self.queue[self.last_index :]
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
|
||||
if action_index_before_inference is None:
|
||||
return
|
||||
|
||||
indexes_diff = self.last_index - action_index_before_inference
|
||||
if indexes_diff != real_delay:
|
||||
# Let's check that action index difference (real delay calculated based on action queue)
|
||||
# is the same as dealy calculated based on inference latency
|
||||
logger.warning(
|
||||
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. Indexes diff: {indexes_diff}, real delay: {real_delay}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration (mutually exclusive with env)
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# Environment configuration (mutually exclusive with robot)
|
||||
env: EnvConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Compilation options
|
||||
compile_policy: bool = (
|
||||
False # Compile policy with torch.compile() for faster inference (not supported on MPS)
|
||||
)
|
||||
compile_mode: str = "default" # Compilation mode: default, reduce-overhead, max-autotune
|
||||
|
||||
# Alternative optimization options (work on all devices including MPS)
|
||||
use_channels_last: bool = False # Use channels_last memory format for images (faster on some devices)
|
||||
enable_cudnn_benchmark: bool = True # Enable cuDNN benchmarking (CUDA only)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Debug options
|
||||
verbose_rtc_comparison: bool = True # Enable detailed RTC comparison output
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that either robot or env is provided, but not both
|
||||
if self.robot is None and self.env is None:
|
||||
raise ValueError("Either robot or env configuration must be provided")
|
||||
if self.robot is not None and self.env is not None:
|
||||
raise ValueError("Cannot specify both robot and env configuration. Choose one.")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
# for k, v in obs_with_policy_features.items():
|
||||
# if isinstance(v, np.ndarray):
|
||||
# obs_with_policy_features[k] = torch.from_numpy(v).to(policy_device)
|
||||
|
||||
# if is_image_key(k):
|
||||
# obs_with_policy_features[k] = obs_with_policy_features[k].type(torch.float32) / 255
|
||||
# obs_with_policy_features[k] = obs_with_policy_features[k].permute(2, 0, 1).unsqueeze(0)
|
||||
# elif isinstance(obs_with_policy_features[k], torch.Tensor):
|
||||
# obs_with_policy_features[k] = obs_with_policy_features[k].unsqueeze(0)
|
||||
|
||||
obs_with_policy_features["task"] = cfg.task
|
||||
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
noise_size = (1, policy.config.chunk_size, policy.config.max_action_dim)
|
||||
noise = policy.model.sample_noise(noise_size, policy_device)
|
||||
noise_clone = noise.clone()
|
||||
|
||||
# Generate actions WITHOUT RTC for comparison (if verbose mode enabled)
|
||||
if cfg.verbose_rtc_comparison:
|
||||
policy.config.rtc_config.enabled = False
|
||||
not_rtc_actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
noise=noise,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
noise=noise_clone if cfg.verbose_rtc_comparison else noise,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Store original actions (before postprocessing) for RTC
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
# Detailed comparison output (if verbose mode enabled)
|
||||
if cfg.verbose_rtc_comparison:
|
||||
logger.info("=" * 80)
|
||||
logger.info("RTC ACTION COMPARISON")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Print detailed statistics
|
||||
logger.info("\n" + tensor_stats_str(not_rtc_actions, "not_rtc_actions (without RTC)"))
|
||||
logger.info("\n" + tensor_stats_str(actions, "actions (with RTC)"))
|
||||
logger.info(
|
||||
"\n" + tensor_stats_str(prev_actions, "prev_actions (leftover from previous chunk)")
|
||||
)
|
||||
|
||||
# Compare RTC vs non-RTC actions
|
||||
logger.info(
|
||||
compare_tensors(actions, not_rtc_actions, "actions (RTC)", "not_rtc_actions (no RTC)")
|
||||
)
|
||||
|
||||
to_non_rtc_diff = actions - not_rtc_actions
|
||||
|
||||
print("to_non_rtc_diff", to_non_rtc_diff)
|
||||
if prev_actions is not None:
|
||||
prev_padded = torch.zeros_like(actions)
|
||||
prev_padded[:, : prev_actions.shape[1], :] = prev_actions
|
||||
to_prev_diff = actions - prev_padded
|
||||
print("to_prev_diff", to_prev_diff)
|
||||
print("=" * 80)
|
||||
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
logger.debug(f"[GET_ACTIONS] new_delay: {new_delay}")
|
||||
logger.debug(f"[GET_ACTIONS] original_actions shape: {original_actions.shape}")
|
||||
logger.debug(f"[GET_ACTIONS] postprocessed_actions shape: {postprocessed_actions.shape}")
|
||||
logger.debug(f"[GET_ACTIONS] action_index_before_inference: {action_index_before_inference}")
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Try to get an action from the queue with timeout
|
||||
action = action_queue.get()
|
||||
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action = robot_action_processor((action, None))
|
||||
robot.send_action(action)
|
||||
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
time.sleep((action_interval - dt_s) - 0.001)
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def stop_by_duration(shutdown_event: Event, cfg: RTCDemoConfig):
|
||||
"""Stop the demo by duration."""
|
||||
time.sleep(cfg.duration)
|
||||
shutdown_event.set()
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with draccus configuration."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
vec_env = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processort, as by default if RTC disabled in the config
|
||||
# The processor won't be created
|
||||
policy.init_rtc_processor(verbose=cfg.verbose_rtc_comparison)
|
||||
|
||||
assert policy.name in ["smolvla"], "Only smolvla are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply memory format optimizations
|
||||
if cfg.use_channels_last:
|
||||
logger.info("Converting model to channels_last memory format")
|
||||
try:
|
||||
# Convert vision encoder to channels_last for better performance
|
||||
if hasattr(policy, "vision_encoder"):
|
||||
policy.vision_encoder = policy.vision_encoder.to(memory_format=torch.channels_last)
|
||||
logger.info("Successfully converted to channels_last format")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert to channels_last: {e}")
|
||||
|
||||
# Enable cuDNN benchmarking for CUDA
|
||||
if cfg.enable_cudnn_benchmark and cfg.device == "cuda":
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger.info("Enabled cuDNN benchmarking")
|
||||
|
||||
# Compile policy if requested
|
||||
if cfg.compile_policy:
|
||||
# Check if device is MPS - torch.compile has issues with MPS backend
|
||||
if cfg.device == "mps":
|
||||
logger.warning("torch.compile() is not stable with MPS backend (Apple Silicon)")
|
||||
logger.warning("Skipping compilation. For better performance on MPS:")
|
||||
logger.warning(" 1. Use torch.float32 instead of bfloat16")
|
||||
logger.warning(" 2. Ensure model uses contiguous memory layouts")
|
||||
logger.warning(" 3. Consider using CUDA if available")
|
||||
else:
|
||||
logger.info(f"Compiling policy with mode: {cfg.compile_mode}")
|
||||
logger.info("First inference will be slower due to compilation, subsequent calls will be faster")
|
||||
|
||||
try:
|
||||
# Compile the predict_action_chunk method
|
||||
policy.predict_action_chunk = torch.compile(
|
||||
policy.predict_action_chunk,
|
||||
mode=cfg.compile_mode,
|
||||
fullgraph=False, # Allow graph breaks for flexibility
|
||||
backend="inductor", # Use inductor backend
|
||||
)
|
||||
logger.info("Policy compiled successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to compile policy: {e}")
|
||||
logger.warning("Continuing without compilation")
|
||||
|
||||
# Create robot or environment
|
||||
if cfg.robot is not None:
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
agent_wrapper = RobotWrapper(robot)
|
||||
else:
|
||||
logger.info(f"Initializing environment: {cfg.env.type}")
|
||||
# Create environment using make_env
|
||||
env_dict = make_env(cfg.env, n_envs=1, use_async_envs=False)
|
||||
|
||||
# Validate environment structure: should have exactly one suite
|
||||
if len(env_dict) != 1:
|
||||
raise ValueError(
|
||||
f"Expected exactly one environment suite, but got {len(env_dict)}. "
|
||||
f"Suites: {list(env_dict.keys())}"
|
||||
)
|
||||
|
||||
# Extract the actual env from the dict structure {suite: {task_id: vec_env}}
|
||||
suite_name = list(env_dict.keys())[0]
|
||||
task_dict = env_dict[suite_name]
|
||||
|
||||
# Validate task structure: should have exactly one task
|
||||
if len(task_dict) != 1:
|
||||
raise ValueError(
|
||||
f"Expected exactly one task in suite '{suite_name}', but got {len(task_dict)}. "
|
||||
f"Tasks: {list(task_dict.keys())}"
|
||||
)
|
||||
|
||||
vec_env = task_dict[0]
|
||||
logger.info(f"Created environment: suite='{suite_name}', task_id=0, num_envs={vec_env.num_envs}")
|
||||
|
||||
# Validate that we have exactly 1 parallel environment
|
||||
if vec_env.num_envs != 1:
|
||||
raise ValueError(
|
||||
f"Expected exactly 1 parallel environment, but got {vec_env.num_envs}. "
|
||||
f"The EnvWrapper is designed for single environment instances."
|
||||
)
|
||||
|
||||
agent_wrapper = EnvWrapper(vec_env, cfg.env)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, agent_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(agent_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot or environment
|
||||
if cfg.robot is not None:
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
else:
|
||||
# Close environment
|
||||
if vec_env:
|
||||
vec_env.close()
|
||||
logger.info("Environment closed")
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
Executable
+75
@@ -0,0 +1,75 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run RTC evaluation on dataset
|
||||
# This shows different usage scenarios
|
||||
|
||||
set -e # Exit on error
|
||||
|
||||
POLICY_PATH="lerobot/smolvla_base"
|
||||
DATASET="lerobot/pusht"
|
||||
DEVICE="cuda" # Change to "cpu" or "mps" if needed
|
||||
|
||||
echo "========================================"
|
||||
echo "RTC Dataset Evaluation Examples"
|
||||
echo "========================================"
|
||||
|
||||
# Example 1: Quick evaluation (100 samples, every step)
|
||||
echo -e "\n[Example 1] Quick evaluation - 100 samples, every step"
|
||||
python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path="${POLICY_PATH}" \
|
||||
--dataset.repo_id="${DATASET}" \
|
||||
--num_iterations=100 \
|
||||
--skip_steps=1 \
|
||||
--device="${DEVICE}" \
|
||||
--output_path="results/rtc_eval_quick.json"
|
||||
|
||||
# Example 2: Simulating realistic inference delay (every 3rd step)
|
||||
echo -e "\n[Example 2] Realistic inference delay - 200 samples, every 3rd step"
|
||||
python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path="${POLICY_PATH}" \
|
||||
--dataset.repo_id="${DATASET}" \
|
||||
--num_iterations=200 \
|
||||
--skip_steps=3 \
|
||||
--rtc.execution_horizon=10 \
|
||||
--device="${DEVICE}" \
|
||||
--output_path="results/rtc_eval_delay3.json"
|
||||
|
||||
# Example 3: Higher inference delay (every 5th step)
|
||||
echo -e "\n[Example 3] High inference delay - 200 samples, every 5th step"
|
||||
python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path="${POLICY_PATH}" \
|
||||
--dataset.repo_id="${DATASET}" \
|
||||
--num_iterations=200 \
|
||||
--skip_steps=5 \
|
||||
--rtc.execution_horizon=12 \
|
||||
--device="${DEVICE}" \
|
||||
--output_path="results/rtc_eval_delay5.json"
|
||||
|
||||
# Example 4: Testing different RTC configurations
|
||||
echo -e "\n[Example 4] Different RTC config - LINEAR schedule"
|
||||
python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path="${POLICY_PATH}" \
|
||||
--dataset.repo_id="${DATASET}" \
|
||||
--num_iterations=100 \
|
||||
--skip_steps=3 \
|
||||
--rtc.execution_horizon=8 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--device="${DEVICE}" \
|
||||
--output_path="results/rtc_eval_linear.json"
|
||||
|
||||
# Example 5: Verbose mode for debugging
|
||||
echo -e "\n[Example 5] Verbose mode - 20 samples with detailed output"
|
||||
python examples/rtc/evaluate_rtc_on_dataset.py \
|
||||
--policy.path="${POLICY_PATH}" \
|
||||
--dataset.repo_id="${DATASET}" \
|
||||
--num_iterations=20 \
|
||||
--skip_steps=3 \
|
||||
--device="${DEVICE}" \
|
||||
--verbose=true \
|
||||
--output_path="results/rtc_eval_verbose.json"
|
||||
|
||||
echo -e "\n========================================"
|
||||
echo "All evaluations completed!"
|
||||
echo "Results saved in results/ directory"
|
||||
echo "========================================"
|
||||
Reference in New Issue
Block a user