mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
443 lines
15 KiB
Python
443 lines
15 KiB
Python
#!/usr/bin/env python
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Benchmark action tokenization: reconstruction error, compression ratio, and timing.
|
|
|
|
Loads action chunks from a LeRobot dataset, encodes/decodes them with a trained action
|
|
tokenizer, and reports:
|
|
- Reconstruction: MAE, MSE, RMSE, max absolute error, per-dimension MAE
|
|
- Jerk: mean absolute jerk (original and reconstructed), jerk reconstruction MAE
|
|
- Compression: ratio (input size / mean tokens), token length stats
|
|
- Timing: mean encode/decode time per chunk
|
|
|
|
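Results are saved to <output-dir>/<timestamp>_<repo_id>_action_tokenizer_results.json, where <output-dir> defaults to outputs/action_tokenizer_benchmark and any "/" in the repo id is replaced by "_".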
|
|
|
|
Example:
|
|
|
|
```bash
|
|
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
|
--action-tokenizer-path=outputs/wavetoken \
|
|
--repo-id=lerobot/pusht \
|
|
--action-horizon=10 \
|
|
--max-episodes=50 \
|
|
--output-dir=outputs/action_tokenizer_benchmark
|
|
```
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
from lerobot.configs.types import NormalizationMode
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
|
|
|
# Reuse the tokenizer-training script's helpers so benchmark chunks are built and normalized exactly as in training.
|
|
from lerobot.scripts.lerobot_train_tokenizer import (
|
|
apply_normalization,
|
|
process_episode,
|
|
)
|
|
|
|
|
|
def _parse_dim_ranges(encoded_dims: str) -> list[tuple[int, int]]:
    """Parse a comma-separated ``"start:end"`` spec (e.g. ``"0:6,7:14"``) into half-open ranges.

    Raises:
        ValueError: If a range is malformed, has ``start < 0`` or ``end <= start``, or no range is given.
    """
    ranges = []
    for range_str in encoded_dims.split(","):
        range_str = range_str.strip()
        if not range_str:
            # Tolerate trailing or doubled commas such as "0:6,".
            continue
        try:
            start, end = map(int, range_str.split(":"))
        except ValueError as err:
            raise ValueError(f"Invalid encoded dim range {range_str!r}; expected 'start:end'.") from err
        if start < 0 or end <= start:
            raise ValueError(f"Invalid encoded dim range {range_str!r}; need 0 <= start < end.")
        ranges.append((start, end))
    if not ranges:
        raise ValueError(f"No encoded dim ranges found in {encoded_dims!r}.")
    return ranges


def _select_encoded_stats(action_stats: dict, encoded_dim_indices: np.ndarray) -> dict:
    """Restrict per-dimension action stats to the encoded dimensions.

    Only list/ndarray stats whose first axis is long enough to be indexed by every encoded
    dimension are kept; scalars, shorter arrays and other value types are skipped.
    """
    max_index = int(encoded_dim_indices.max())
    selected = {}
    for stat_name, stat_values in action_stats.items():
        if not isinstance(stat_values, (list, np.ndarray)):
            continue
        stat_array = np.asarray(stat_values)
        # ndim check avoids len()/indexing on 0-d arrays, which would raise.
        if stat_array.ndim >= 1 and stat_array.shape[0] > max_index:
            selected[stat_name] = stat_array[encoded_dim_indices]
    return selected


def load_action_chunks(
    repo_id: str,
    root: str | None,
    action_horizon: int,
    max_episodes: int | None,
    sample_fraction: float,
    encoded_dims: str,
    delta_dims: str | None,
    use_delta_transform: bool,
    state_key: str,
    normalization_mode: NormalizationMode,
):
    """Load and normalize action chunks from a LeRobot dataset (same pipeline as training).

    Args:
        repo_id: LeRobot dataset repo id.
        root: Optional local root directory for the dataset.
        action_horizon: Number of steps per action chunk.
        max_episodes: Use at most this many episodes (``None`` = all).
        sample_fraction: Fraction of chunks sampled per episode (forwarded to ``process_episode``).
        encoded_dims: Comma-separated half-open ranges of action dims to keep, e.g. ``"0:6,7:14"``.
        delta_dims: Optional comma-separated dims for the delta transform.
        use_delta_transform: Whether ``process_episode`` applies the delta transform.
        state_key: Dataset key of the state used by the delta transform.
        normalization_mode: Mode passed to ``apply_normalization``.

    Returns:
        Tuple ``(encoded_chunks, total_encoded_dims, action_horizon, repo_id)`` where
        ``encoded_chunks`` has shape ``[N, action_horizon, total_encoded_dims]``.

    Raises:
        ValueError: If ``encoded_dims`` is malformed, exceeds the dataset action dimension,
            or no chunks could be collected.
    """
    import warnings

    # Parse specs before loading the dataset so bad arguments fail fast.
    encoded_dim_ranges = _parse_dim_ranges(encoded_dims)
    total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)

    delta_dim_list = None
    if delta_dims is not None and delta_dims.strip():
        delta_dim_list = [int(d.strip()) for d in delta_dims.split(",") if d.strip()]

    dataset = LeRobotDataset(repo_id=repo_id, root=root)
    num_episodes = dataset.num_episodes
    if max_episodes is not None:
        num_episodes = min(max_episodes, num_episodes)

    all_chunks = []
    for ep_idx in range(num_episodes):
        # process_episode is called with one packed argument tuple, matching the training script.
        chunks = process_episode(
            (
                dataset,
                ep_idx,
                action_horizon,
                delta_dim_list,
                sample_fraction,
                state_key,
                use_delta_transform,
            )
        )
        if chunks is not None:
            all_chunks.append(chunks)

    if not all_chunks:
        raise ValueError("No action chunks collected. Check action_horizon and dataset.")

    all_chunks = np.concatenate(all_chunks, axis=0)

    # Slicing past the last dim silently truncates, which would make total_encoded_dims
    # disagree with the data actually fed to the tokenizer; reject such ranges instead.
    dataset_action_dim = all_chunks.shape[-1]
    too_wide = [(start, end) for start, end in encoded_dim_ranges if end > dataset_action_dim]
    if too_wide:
        raise ValueError(
            f"Encoded dim ranges {too_wide} exceed the dataset action dimension ({dataset_action_dim})."
        )

    # Keep only the encoded dimensions, in the order the ranges were given.
    encoded_chunks = np.concatenate(
        [all_chunks[:, :, start:end] for start, end in encoded_dim_ranges], axis=-1
    )
    encoded_dim_indices = np.concatenate([np.arange(start, end) for start, end in encoded_dim_ranges])

    # Normalize with the dataset's action stats, restricted to the encoded dims.
    norm_stats = dataset.meta.stats
    encoded_stats = {}
    if norm_stats is not None and ACTION in norm_stats:
        encoded_stats = _select_encoded_stats(norm_stats[ACTION], encoded_dim_indices)

    if not encoded_stats:
        warnings.warn(
            "No usable action normalization stats found; benchmarking unnormalized actions.",
            stacklevel=2,
        )
    else:
        try:
            encoded_chunks = apply_normalization(
                encoded_chunks, encoded_stats, normalization_mode, eps=1e-8
            )
        except ValueError as err:
            # Best-effort fallback kept from the original, but no longer silent: metrics
            # computed on unnormalized data are not comparable to normalized runs.
            warnings.warn(
                f"Could not apply {normalization_mode} normalization ({err}); "
                "benchmarking unnormalized actions.",
                stacklevel=2,
            )

    return encoded_chunks, total_encoded_dims, action_horizon, dataset.repo_id
|
|
|
|
|
|
def compute_reconstruction_metrics(original: np.ndarray, reconstructed: np.ndarray):
    """Compute reconstruction error metrics between original and reconstructed action chunks.

    Args:
        original: Ground-truth action chunks, shape ``[N, T, D]``.
        reconstructed: Decoded action chunks; must have exactly the same shape as ``original``.

    Returns:
        Dict with overall MAE, MSE, RMSE, max absolute error and a per-dimension MAE list
        (averaged over the N and T axes).

    Raises:
        ValueError: If the shapes differ (broadcasting would silently give wrong metrics)
            or the inputs are empty.
    """
    original = np.asarray(original)
    reconstructed = np.asarray(reconstructed)
    if original.shape != reconstructed.shape:
        raise ValueError(
            f"Shape mismatch: original {original.shape} vs reconstructed {reconstructed.shape}."
        )
    if original.size == 0:
        raise ValueError("Cannot compute reconstruction metrics on empty arrays.")

    diff = reconstructed - original
    abs_diff = np.abs(diff)

    mae = float(np.mean(abs_diff))
    mse = float(np.mean(diff**2))
    rmse = float(np.sqrt(mse))
    max_abs_err = float(np.max(abs_diff))

    # Per-dimension MAE (over N and T)
    per_dim_mae = np.mean(abs_diff, axis=(0, 1)).tolist()

    return {
        "reconstruction_mae": mae,
        "reconstruction_mse": mse,
        "reconstruction_rmse": rmse,
        "reconstruction_max_abs_error": max_abs_err,
        "per_dimension_mae": per_dim_mae,
    }
|
|
|
|
|
|
def compute_jerk_metrics(original: np.ndarray, reconstructed: np.ndarray) -> dict:
    """Summarize jerk (third discrete time-difference) of original vs. reconstructed actions.

    Args:
        original: Action chunks [N, T, D].
        reconstructed: Reconstructed action chunks [N, T, D].

    Returns:
        Dict with the mean absolute jerk of each input and the MAE between the two jerk
        signals, or an empty dict when the chunks have fewer than 4 time steps (a third
        difference needs at least 4 samples).
    """
    num_steps = original.shape[1]
    if num_steps < 4:
        return {}

    # Third-order difference along the time axis -> shape (N, T - 3, D).
    orig_jerk, recon_jerk = (np.diff(arr, n=3, axis=1) for arr in (original, reconstructed))

    return {
        "jerk_mae_original": float(np.abs(orig_jerk).mean()),
        "jerk_mae_reconstructed": float(np.abs(recon_jerk).mean()),
        "jerk_reconstruction_mae": float(np.abs(recon_jerk - orig_jerk).mean()),
    }
|
|
|
|
|
|
def run_benchmark(
    action_chunks: np.ndarray,
    action_horizon: int,
    action_dim: int,
    tokenizer_path: str,
    max_chunks_for_reconstruction: int | None = 500,
    seed: int = 42,
):
    """Encode/decode action chunks with a pretrained tokenizer and compute benchmark metrics.

    A random subset of at most ``max_chunks_for_reconstruction`` chunks (drawn with a fixed
    ``seed``) is used for *all* metrics — reconstruction, jerk, compression and timing.

    Args:
        action_chunks: Normalized action chunks, indexed along the first axis.
        action_horizon: Time steps per chunk (passed to ``processor.decode``).
        action_dim: Action dimensions per step (passed to ``processor.decode``).
        tokenizer_path: Path or hub id loadable with ``AutoProcessor.from_pretrained``.
        max_chunks_for_reconstruction: Upper bound on evaluated chunks; ``None`` uses all.
        seed: Seed for the subset sampling (default 42, as before).

    Returns:
        Dict of reconstruction, jerk, compression, token-length and timing metrics.

    Raises:
        ValueError: If ``action_chunks`` is empty or ``max_chunks_for_reconstruction`` is not positive.
    """
    from transformers import AutoProcessor

    # Validate before loading the (possibly remote) processor so bad inputs fail fast.
    n_chunks = len(action_chunks)
    if n_chunks == 0:
        raise ValueError("action_chunks is empty; nothing to benchmark.")
    if max_chunks_for_reconstruction is not None and max_chunks_for_reconstruction <= 0:
        raise ValueError(
            f"max_chunks_for_reconstruction must be positive or None, got {max_chunks_for_reconstruction}."
        )

    processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)

    sample_size = n_chunks
    if max_chunks_for_reconstruction is not None:
        sample_size = min(max_chunks_for_reconstruction, n_chunks)
    rng = np.random.RandomState(seed)
    indices = rng.choice(n_chunks, size=sample_size, replace=False)
    sample_chunks = action_chunks[indices]

    # Encode one chunk at a time (batch of size 1) so timings are per chunk.
    token_lengths = []
    encode_times = []
    all_tokens = []
    for chunk in sample_chunks:
        batch = chunk[np.newaxis]
        t0 = time.perf_counter()
        tokens = processor(batch)[0]
        encode_times.append(time.perf_counter() - t0)
        if isinstance(tokens, list):
            token_list = tokens
        elif hasattr(tokens, "tolist"):
            token_list = tokens.tolist()
        else:
            token_list = list(tokens)
        token_lengths.append(len(token_list))
        all_tokens.append(token_list)

    # Decode; time_horizon/action_dim are passed explicitly rather than relying on any
    # state the processor may keep from encoding.
    decoded_list = []
    decode_times = []
    for token_list in all_tokens:
        t0 = time.perf_counter()
        recon = processor.decode(
            [token_list],
            time_horizon=action_horizon,
            action_dim=action_dim,
        )
        decode_times.append(time.perf_counter() - t0)
        decoded_list.append(np.asarray(recon))
    decoded = np.concatenate(decoded_list, axis=0)

    # Reconstruction metrics
    metrics = compute_reconstruction_metrics(sample_chunks, decoded)

    # Jerk metrics (3rd derivative along time)
    metrics.update(compute_jerk_metrics(sample_chunks, decoded))

    # Compression: raw scalars per chunk divided by the mean number of tokens.
    token_lengths = np.asarray(token_lengths)
    mean_token_length = float(np.mean(token_lengths))
    input_size = action_horizon * action_dim
    # Guard the degenerate "no tokens at all" case instead of crashing after the full run.
    metrics["compression_ratio"] = input_size / mean_token_length if mean_token_length > 0 else float("inf")
    metrics["mean_token_length"] = mean_token_length
    metrics["std_token_length"] = float(np.std(token_lengths))
    metrics["min_token_length"] = int(np.min(token_lengths))
    metrics["max_token_length"] = int(np.max(token_lengths))
    metrics["p50_token_length"] = float(np.percentile(token_lengths, 50))
    metrics["p99_token_length"] = float(np.percentile(token_lengths, 99))

    # Timing (seconds per chunk)
    metrics["mean_encode_time_sec"] = float(np.mean(encode_times))
    metrics["mean_decode_time_sec"] = float(np.mean(decode_times))
    metrics["num_chunks_evaluated"] = sample_size
    metrics["total_chunks_available"] = n_chunks

    return metrics
|
|
|
|
|
|
def main(
    action_tokenizer_path: str,
    repo_id: str,
    root: str | None = None,
    action_horizon: int = 10,
    max_episodes: int | None = 100,
    sample_fraction: float = 0.2,
    encoded_dims: str = "0:6",
    delta_dims: str | None = None,
    use_delta_transform: bool = False,
    state_key: str = OBS_STATE,
    normalization_mode: str = "QUANTILES",
    max_chunks_for_reconstruction: int | None = 500,
    output_dir: str | None = None,
):
    """Run the action-tokenizer benchmark end to end and save the results as JSON.

    Loads and normalizes action chunks, benchmarks the tokenizer on them, writes
    ``{"config": ..., "metrics": ...}`` to
    ``<output_dir>/<timestamp>_<repo_id with "/" -> "_">_action_tokenizer_results.json``
    and prints a metric summary.

    Returns:
        The results dict that was written to disk.
    """
    if output_dir is None:
        output_dir = "outputs/action_tokenizer_benchmark"
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        norm_mode = NormalizationMode(normalization_mode)
    except ValueError:
        # Keep the lenient fallback, but make it visible; the saved config records the
        # mode actually used rather than the rejected input.
        print(
            f"Warning: unknown normalization mode {normalization_mode!r}; "
            f"falling back to {NormalizationMode.QUANTILES.value}."
        )
        norm_mode = NormalizationMode.QUANTILES

    print("Loading action chunks...")
    encoded_chunks, action_dim, horizon, _ = load_action_chunks(
        repo_id=repo_id,
        root=root,
        action_horizon=action_horizon,
        max_episodes=max_episodes,
        sample_fraction=sample_fraction,
        encoded_dims=encoded_dims,
        delta_dims=delta_dims,
        use_delta_transform=use_delta_transform,
        state_key=state_key,
        normalization_mode=norm_mode,
    )
    print(f"Loaded {len(encoded_chunks)} chunks, shape {encoded_chunks.shape} (H={horizon}, D={action_dim})")

    print("Running tokenizer benchmark...")
    metrics = run_benchmark(
        action_chunks=encoded_chunks,
        action_horizon=horizon,
        action_dim=action_dim,
        tokenizer_path=action_tokenizer_path,
        max_chunks_for_reconstruction=max_chunks_for_reconstruction,
    )

    # Attach config for reproducibility (includes every argument that affects the metrics).
    results = {
        "config": {
            "action_tokenizer_path": action_tokenizer_path,
            "repo_id": repo_id,
            "action_horizon": action_horizon,
            "max_episodes": max_episodes,
            "sample_fraction": sample_fraction,
            "encoded_dims": encoded_dims,
            "delta_dims": delta_dims,
            "use_delta_transform": use_delta_transform,
            "state_key": state_key,
            "normalization_mode": norm_mode.value,
            "max_chunks_for_reconstruction": max_chunks_for_reconstruction,
            "root": root,
        },
        "metrics": metrics,
    }

    timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    safe_repo = repo_id.replace("/", "_")
    out_file = output_path / f"{timestamp}_{safe_repo}_action_tokenizer_results.json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to {out_file}")
    print("Metrics:")
    for k, v in metrics.items():
        if isinstance(v, list):
            print(f"  {k}: (length {len(v)})")
        else:
            print(f"  {k}: {v}")

    return results
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; every flag's destination matches a ``main`` keyword argument."""
    parser = argparse.ArgumentParser(
        description="Benchmark action tokenization (reconstruction error, compression, timing)."
    )
    parser.add_argument(
        "--action-tokenizer-path",
        type=str,
        required=True,
        help="Path or HuggingFace repo id of the trained action tokenizer (e.g. outputs/wavetoken).",
    )
    parser.add_argument(
        "--repo-id", type=str, required=True, help="LeRobot dataset repo id (e.g. lerobot/pusht)."
    )
    parser.add_argument("--root", type=str, default=None, help="Root directory for LeRobot datasets.")

    # Chunking / sampling
    parser.add_argument(
        "--action-horizon", type=int, default=10, help="Number of future steps per action chunk."
    )
    parser.add_argument(
        "--max-episodes", type=int, default=None, help="Max episodes to use (default: all)."
    )
    parser.add_argument(
        "--sample-fraction", type=float, default=0.2, help="Fraction of chunks to sample per episode."
    )

    # Dimension selection / transforms
    parser.add_argument(
        "--encoded-dims", type=str, default="0:6", help="Dimension ranges to encode (e.g. 0:6,7:14)."
    )
    parser.add_argument(
        "--delta-dims", type=str, default=None, help="Comma-separated dimensions for delta transform."
    )
    parser.add_argument(
        "--use-delta-transform",
        action="store_true",
        help="Apply delta (relative) transform to specified dimensions.",
    )
    parser.add_argument(
        "--state-key", type=str, default=OBS_STATE, help="Dataset key for state (for delta transform)."
    )
    parser.add_argument(
        "--normalization-mode",
        type=str,
        default="QUANTILES",
        choices=[mode.value for mode in NormalizationMode],
        help="Normalization mode for actions.",
    )

    # Evaluation size / output
    parser.add_argument(
        "--max-chunks-for-reconstruction",
        type=int,
        default=500,
        help="Max chunks to use for reconstruction metrics (default: 500).",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/action_tokenizer_benchmark",
        help="Directory to save results JSON (default: outputs/action_tokenizer_benchmark).",
    )
    return parser


if __name__ == "__main__":
    cli_args = _build_arg_parser().parse_args()
    # argparse derives dests (dashes -> underscores) that are exactly main()'s parameter names.
    main(**vars(cli_args))
|