Fix/quantiles script (#2064)

* refactor augment stats with quantiles script
add parallelization for faster processing
shift the quantile normalization to [-1, 1]

* fix replay buffer tests

* fix comment
Michel Aractingi
2025-09-28 17:58:40 +02:00
committed by GitHub
parent 0e8f01b331
commit 57c6469b1f
3 changed files with 62 additions and 58 deletions
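Before the per-file diffs, a minimal sketch of the normalization change described in the commit message: values between the low and high quantiles now map onto [-1, 1] instead of [0, 1]. `quantile_normalize`, `q_low`, and `q_high` are hypothetical names; the formula is the one added in the diff below.

```python
import torch

def quantile_normalize(x: torch.Tensor, q_low: torch.Tensor, q_high: torch.Tensor) -> torch.Tensor:
    # Map [q_low, q_high] onto [-1, 1]; same form as the updated QUANTILES branch below.
    return 2.0 * (x - q_low) / (q_high - q_low) - 1.0

x = torch.tensor([0.0, 0.5, 1.0])
print(quantile_normalize(x, torch.tensor(0.0), torch.tensor(1.0)))  # tensor([-1., 0., 1.])
```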
+2 -15
@@ -315,9 +315,8 @@ def _reshape_for_global_stats(
     if keepdims:
         target_shape = tuple(1 for _ in original_shape)
         return value.reshape(target_shape)
-    elif not keepdims and value.ndim > 0 and value.size == 1:
-        return value.item()
-    return value
+    # Keep at least 1-D arrays to satisfy validator
+    return np.atleast_1d(value)
 
 
 def _reshape_single_stat(
@@ -410,12 +409,6 @@ def _compute_basic_stats(
"count": np.array([sample_count]),
}
# For single-element arrays with shape (1,1), convert to scalar arrays
if array.shape == (1, 1):
for key in stats:
if key != "count" and stats[key].size == 1:
stats[key] = np.array(stats[key].item())
for q in quantile_list_keys:
stats[q] = stats["mean"].copy()
@@ -470,12 +463,6 @@ def get_feature_stats(
     stats = running_stats.get_statistics()
     stats["count"] = np.array([sample_count])
 
-    # For axis=None, the stats are computed as 1D arrays but should be 0-dimensional arrays
-    if axis is None and reshaped.shape[1] == 1:
-        for key in stats:
-            if key != "count" and stats[key].size == 1:
-                stats[key] = np.array(stats[key].item())
-
     stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
 
     return stats
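The three hunks above drop the special cases that collapsed size-1 statistics into Python scalars or 0-d arrays; stats now stay at least 1-D. A small standalone sketch of the difference, assuming only numpy:

```python
import numpy as np

value = np.array([[3.14]]).mean(axis=None)  # a size-1 reduction result (0-d)

old_style = value.item()          # 3.14 -- a plain Python float (old behavior)
new_style = np.atleast_1d(value)  # array([3.14]) -- stays 1-D for the validator

print(type(old_style).__name__, new_style.shape)  # float (1,)
```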
+56 -39
@@ -34,12 +34,15 @@ python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
"""
import argparse
import concurrent.futures
import logging
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, compute_episode_stats
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import write_stats
from lerobot.utils.utils import init_logging
@@ -67,52 +70,44 @@ def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str]) -> bool:
     return False
 
 
-def load_episode_data(dataset: LeRobotDataset, episode_idx: int) -> dict:
-    """Load episode data by accessing the underlying HuggingFace dataset.
+def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
+    """Process a single episode and return its statistics.
 
     Args:
         dataset: The LeRobot dataset
-        episode_idx: Index of the episode to load
+        episode_idx: Index of the episode to process
 
     Returns:
-        Dictionary containing episode data for each feature
+        Dictionary containing episode statistics
     """
+    logging.info(f"Computing stats for episode {episode_idx}")
     episode_info = dataset.meta.episodes[episode_idx]
     episode_length = episode_info["length"]
-    start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
-    end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
+    start_idx = sum(dataset.meta.episodes[i]["length"] for i in range(episode_idx))
+    end_idx = start_idx + episode_length
 
-    episode_data = {}
-    episode_slice = dataset.hf_dataset.select(range(start_idx, end_idx))
-
-    for key, feature_info in dataset.features.items():
-        if feature_info["dtype"] == "string":
+    ep_stats = {}
+    for key, data in dataset.hf_dataset[start_idx:end_idx].items():
+        if dataset.features[key]["dtype"] == "string":
             continue
 
-        if feature_info["dtype"] in ["image", "video"]:
-            image_paths = []
-            for row in episode_slice:
-                if key in row:
-                    relative_path = row[key]
-                    if isinstance(relative_path, str):
-                        absolute_path = str(dataset.meta.root / relative_path)
-                        image_paths.append(absolute_path)
-            if image_paths:
-                episode_data[key] = image_paths
+        data = torch.stack(data).cpu().numpy()
+        if dataset.features[key]["dtype"] in ["image", "video"]:
+            axes_to_reduce = (0, 2, 3)
+            keepdims = True
         else:
-            arrays = []
-            for row in episode_slice:
-                if key in row:
-                    arrays.append(np.array(row[key]))
+            axes_to_reduce = 0
+            keepdims = data.ndim == 1
 
-            if arrays:
-                episode_data[key] = np.stack(arrays)
+        ep_stats[key] = get_feature_stats(
+            data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
+        )
 
-    return episode_data
+        if dataset.features[key]["dtype"] in ["image", "video"]:
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    return ep_stats
 
 
 def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
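For orientation, a self-contained numpy sketch of the reduction conventions used in `process_single_episode` above: image/video features reduce over the batch and spatial axes with `keepdims=True` (leaving per-channel stats), while vector features reduce over the batch axis only. `feature_stats` is a hypothetical stand-in for `get_feature_stats`, and the quantile values are an assumption, not necessarily those in `DEFAULT_QUANTILES`.

```python
import numpy as np

QUANTILES = [0.01, 0.10, 0.90, 0.99]  # assumed values for illustration

def feature_stats(data: np.ndarray, axis, keepdims: bool) -> dict[str, np.ndarray]:
    # Basic stats plus quantile keys like "q01"/"q99", reduced over the given axes.
    stats = {
        "mean": data.mean(axis=axis, keepdims=keepdims),
        "std": data.std(axis=axis, keepdims=keepdims),
        "min": data.min(axis=axis, keepdims=keepdims),
        "max": data.max(axis=axis, keepdims=keepdims),
        "count": np.array([data.shape[0]]),
    }
    for q in QUANTILES:
        stats[f"q{int(q * 100):02d}"] = np.quantile(data, q, axis=axis, keepdims=keepdims)
    return stats

frames = np.random.rand(10, 3, 64, 64)  # image-like: reduce batch + spatial axes
vector = np.random.rand(10, 7)          # state-like: reduce batch axis only

image_stats = feature_stats(frames, axis=(0, 2, 3), keepdims=True)
vector_stats = feature_stats(vector, axis=0, keepdims=False)
print(image_stats["q01"].shape, vector_stats["q01"].shape)  # (1, 3, 1, 1) (7,)
```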
@@ -127,11 +122,25 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
     logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
 
     episode_stats_list = []
+    max_workers = min(dataset.num_episodes, 8)
 
-    for episode_idx in range(dataset.num_episodes):
-        episode_data = load_episode_data(dataset, episode_idx)
-        ep_stats = compute_episode_stats(episode_data, dataset.features)
-        episode_stats_list.append(ep_stats)
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_episode = {
+            executor.submit(process_single_episode, dataset, episode_idx): episode_idx
+            for episode_idx in range(dataset.num_episodes)
+        }
+
+        episode_results = {}
+        with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
+            for future in concurrent.futures.as_completed(future_to_episode):
+                episode_idx = future_to_episode[future]
+                ep_stats = future.result()
+                episode_results[episode_idx] = ep_stats
+                pbar.update(1)
+
+    for episode_idx in range(dataset.num_episodes):
+        if episode_idx in episode_results:
+            episode_stats_list.append(episode_results[episode_idx])
 
     if not episode_stats_list:
         raise ValueError("No episode data found for computing statistics")
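The hunk above fans episodes out to a thread pool and collects results keyed by episode index, because `as_completed` yields futures in completion order, not submission order; the final loop then restores episode order. A self-contained sketch of the same pattern, with a dummy work function in place of the per-episode stats computation:

```python
import concurrent.futures

def process_item(idx: int) -> str:
    return f"stats-{idx}"  # dummy stand-in for per-episode work

num_items = 5
results: dict[int, str] = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=min(num_items, 8)) as executor:
    future_to_idx = {executor.submit(process_item, i): i for i in range(num_items)}
    for future in concurrent.futures.as_completed(future_to_idx):
        results[future_to_idx[future]] = future.result()  # arrives in arbitrary order

ordered = [results[i] for i in range(num_items)]  # restore submission (episode) order
print(ordered)
```

Threads rather than processes are a reasonable fit here on the assumption that per-episode work is dominated by dataset I/O and native numeric code that releases the GIL.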
@@ -143,12 +152,14 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
 def augment_dataset_with_quantile_stats(
     repo_id: str,
     root: str | Path | None = None,
+    overwrite: bool = False,
 ) -> None:
     """Augment a dataset with quantile statistics if they are missing.
 
     Args:
         repo_id: Repository ID of the dataset
         root: Local root directory for the dataset
+        overwrite: Overwrite existing quantile statistics if they already exist
     """
     logging.info(f"Loading dataset: {repo_id}")
     dataset = LeRobotDataset(
@@ -156,7 +167,7 @@ def augment_dataset_with_quantile_stats(
         root=root,
     )
 
-    if has_quantile_stats(dataset.meta.stats):
+    if not overwrite and has_quantile_stats(dataset.meta.stats):
         logging.info("Dataset already contains quantile statistics. No action needed.")
         return
 
@@ -189,6 +200,11 @@ def main():
         type=str,
         help="Local root directory for the dataset",
     )
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="Overwrite existing quantile statistics if they already exist",
+    )
 
     args = parser.parse_args()
     root = Path(args.root) if args.root else None
@@ -198,6 +214,7 @@ def main():
     augment_dataset_with_quantile_stats(
         repo_id=args.repo_id,
         root=root,
+        overwrite=args.overwrite,
     )
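A hypothetical usage sketch of the new flag, calling the function directly rather than through the CLI; the repo id is a placeholder, and the module path assumes it matches the script location shown in the docstring above:

```python
from lerobot.datasets.v30.augment_dataset_quantile_stats import augment_dataset_with_quantile_stats

# Force recomputation even if quantile stats already exist (the new --overwrite behavior).
augment_dataset_with_quantile_stats(repo_id="user/dataset", overwrite=True)
```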
+4 -4
@@ -288,8 +288,8 @@ class _NormalizationMixin:
         Normalization Modes:
             - MEAN_STD: Centers data around zero with unit variance.
             - MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
-            - QUANTILES: Scales data to [0, 1] range using 1st and 99th percentiles (q01/q99).
-            - QUANTILE10: Scales data to [0, 1] range using 10th and 90th percentiles (q10/q90).
+            - QUANTILES: Scales data to [-1, 1] range using 1st and 99th percentiles (q01/q99).
+            - QUANTILE10: Scales data to [-1, 1] range using 10th and 90th percentiles (q10/q90).
 
         Args:
             tensor: The input tensor to transform.
@@ -375,7 +375,7 @@ class _NormalizationMixin:
             )
             if inverse:
                 return tensor * denom + q01
-            return (tensor - q01) / denom
+            return 2.0 * (tensor - q01) / denom - 1.0
 
         if norm_mode == NormalizationMode.QUANTILE10:
             q10 = stats.get("q10", None)
@@ -392,7 +392,7 @@ class _NormalizationMixin:
             )
             if inverse:
                 return tensor * denom + q10
-            return (tensor - q10) / denom
+            return 2.0 * (tensor - q10) / denom - 1.0
 
         # If necessary stats are missing, return input unchanged.
         return tensor
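Finally, a round-trip check of the updated mapping with scalar stand-in stats: the forward formula is the one added above, and the inverse shown is its algebraic inverse, written out for illustration.

```python
import torch

q01, q99 = torch.tensor(2.0), torch.tensor(10.0)
denom = q99 - q01

x = torch.tensor([2.0, 6.0, 10.0])
normalized = 2.0 * (x - q01) / denom - 1.0  # forward mapping from the diff
print(normalized)                            # tensor([-1., 0., 1.])

# Algebraic inverse of the new forward mapping (for illustration):
restored = (normalized + 1.0) / 2.0 * denom + q01
print(torch.allclose(restored, x))           # True
```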