Fix/quantiles script (#2064)

* refactor augment stats with quantiles script
add parallelization for faster processing
shift the quantile normalization between -1 1

* fix replay buffer tests

* fix comment
This commit is contained in:
Michel Aractingi
2025-09-28 17:58:40 +02:00
committed by GitHub
parent 0e8f01b331
commit 57c6469b1f
3 changed files with 62 additions and 58 deletions
+2 -15
View File
@@ -315,9 +315,8 @@ def _reshape_for_global_stats(
if keepdims: if keepdims:
target_shape = tuple(1 for _ in original_shape) target_shape = tuple(1 for _ in original_shape)
return value.reshape(target_shape) return value.reshape(target_shape)
elif not keepdims and value.ndim > 0 and value.size == 1: # Keep at least 1-D arrays to satisfy validator
return value.item() return np.atleast_1d(value)
return value
def _reshape_single_stat( def _reshape_single_stat(
@@ -410,12 +409,6 @@ def _compute_basic_stats(
"count": np.array([sample_count]), "count": np.array([sample_count]),
} }
# For single-element arrays with shape (1,1), convert to scalar arrays
if array.shape == (1, 1):
for key in stats:
if key != "count" and stats[key].size == 1:
stats[key] = np.array(stats[key].item())
for q in quantile_list_keys: for q in quantile_list_keys:
stats[q] = stats["mean"].copy() stats[q] = stats["mean"].copy()
@@ -470,12 +463,6 @@ def get_feature_stats(
stats = running_stats.get_statistics() stats = running_stats.get_statistics()
stats["count"] = np.array([sample_count]) stats["count"] = np.array([sample_count])
# For axis=None, the stats are computed as 1D arrays but should be 0-dimensional arrays
if axis is None and reshaped.shape[1] == 1:
for key in stats:
if key != "count" and stats[key].size == 1:
stats[key] = np.array(stats[key].item())
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape) stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
return stats return stats
@@ -34,12 +34,15 @@ python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
""" """
import argparse import argparse
import concurrent.futures
import logging import logging
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, compute_episode_stats from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import write_stats from lerobot.datasets.utils import write_stats
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
@@ -67,52 +70,44 @@ def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[s
return False return False
def load_episode_data(dataset: LeRobotDataset, episode_idx: int) -> dict: def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
"""Load episode data by accessing the underlying HuggingFace dataset. """Process a single episode and return its statistics.
Args: Args:
dataset: The LeRobot dataset dataset: The LeRobot dataset
episode_idx: Index of the episode to load episode_idx: Index of the episode to process
Returns: Returns:
Dictionary containing episode data for each feature Dictionary containing episode statistics
""" """
logging.info(f"Computing stats for episode {episode_idx}")
episode_info = dataset.meta.episodes[episode_idx] start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
episode_length = episode_info["length"] end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
start_idx = sum(dataset.meta.episodes[i]["length"] for i in range(episode_idx)) ep_stats = {}
end_idx = start_idx + episode_length for key, data in dataset.hf_dataset[start_idx:end_idx].items():
if dataset.features[key]["dtype"] == "string":
episode_data = {}
episode_slice = dataset.hf_dataset.select(range(start_idx, end_idx))
for key, feature_info in dataset.features.items():
if feature_info["dtype"] == "string":
continue continue
if feature_info["dtype"] in ["image", "video"]: data = torch.stack(data).cpu().numpy()
image_paths = [] if dataset.features[key]["dtype"] in ["image", "video"]:
for row in episode_slice: axes_to_reduce = (0, 2, 3)
if key in row: keepdims = True
relative_path = row[key]
if isinstance(relative_path, str):
absolute_path = str(dataset.meta.root / relative_path)
image_paths.append(absolute_path)
if image_paths:
episode_data[key] = image_paths
else: else:
arrays = [] axes_to_reduce = 0
for row in episode_slice: keepdims = data.ndim == 1
if key in row:
arrays.append(np.array(row[key]))
if arrays: ep_stats[key] = get_feature_stats(
episode_data[key] = np.stack(arrays) data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
)
return episode_data if dataset.features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]: def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
@@ -127,11 +122,25 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes") logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
episode_stats_list = [] episode_stats_list = []
max_workers = min(dataset.num_episodes, 8)
for episode_idx in range(dataset.num_episodes): with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
episode_data = load_episode_data(dataset, episode_idx) future_to_episode = {
ep_stats = compute_episode_stats(episode_data, dataset.features) executor.submit(process_single_episode, dataset, episode_idx): episode_idx
episode_stats_list.append(ep_stats) for episode_idx in range(dataset.num_episodes)
}
episode_results = {}
with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
for future in concurrent.futures.as_completed(future_to_episode):
episode_idx = future_to_episode[future]
ep_stats = future.result()
episode_results[episode_idx] = ep_stats
pbar.update(1)
for episode_idx in range(dataset.num_episodes):
if episode_idx in episode_results:
episode_stats_list.append(episode_results[episode_idx])
if not episode_stats_list: if not episode_stats_list:
raise ValueError("No episode data found for computing statistics") raise ValueError("No episode data found for computing statistics")
@@ -143,12 +152,14 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic
def augment_dataset_with_quantile_stats( def augment_dataset_with_quantile_stats(
repo_id: str, repo_id: str,
root: str | Path | None = None, root: str | Path | None = None,
overwrite: bool = False,
) -> None: ) -> None:
"""Augment a dataset with quantile statistics if they are missing. """Augment a dataset with quantile statistics if they are missing.
Args: Args:
repo_id: Repository ID of the dataset repo_id: Repository ID of the dataset
root: Local root directory for the dataset root: Local root directory for the dataset
overwrite: Overwrite existing quantile statistics if they already exist
""" """
logging.info(f"Loading dataset: {repo_id}") logging.info(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset( dataset = LeRobotDataset(
@@ -156,7 +167,7 @@ def augment_dataset_with_quantile_stats(
root=root, root=root,
) )
if has_quantile_stats(dataset.meta.stats): if not overwrite and has_quantile_stats(dataset.meta.stats):
logging.info("Dataset already contains quantile statistics. No action needed.") logging.info("Dataset already contains quantile statistics. No action needed.")
return return
@@ -189,6 +200,11 @@ def main():
type=str, type=str,
help="Local root directory for the dataset", help="Local root directory for the dataset",
) )
parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite existing quantile statistics if they already exist",
)
args = parser.parse_args() args = parser.parse_args()
root = Path(args.root) if args.root else None root = Path(args.root) if args.root else None
@@ -198,6 +214,7 @@ def main():
augment_dataset_with_quantile_stats( augment_dataset_with_quantile_stats(
repo_id=args.repo_id, repo_id=args.repo_id,
root=root, root=root,
overwrite=args.overwrite,
) )
+4 -4
View File
@@ -288,8 +288,8 @@ class _NormalizationMixin:
Normalization Modes: Normalization Modes:
- MEAN_STD: Centers data around zero with unit variance. - MEAN_STD: Centers data around zero with unit variance.
- MIN_MAX: Scales data to [-1, 1] range using actual min/max values. - MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
- QUANTILES: Scales data to [0, 1] range using 1st and 99th percentiles (q01/q99). - QUANTILES: Scales data to [-1, 1] range using 1st and 99th percentiles (q01/q99).
- QUANTILE10: Scales data to [0, 1] range using 10th and 90th percentiles (q10/q90). - QUANTILE10: Scales data to [-1, 1] range using 10th and 90th percentiles (q10/q90).
Args: Args:
tensor: The input tensor to transform. tensor: The input tensor to transform.
@@ -375,7 +375,7 @@ class _NormalizationMixin:
) )
if inverse: if inverse:
return tensor * denom + q01 return tensor * denom + q01
return (tensor - q01) / denom return 2.0 * (tensor - q01) / denom - 1.0
if norm_mode == NormalizationMode.QUANTILE10: if norm_mode == NormalizationMode.QUANTILE10:
q10 = stats.get("q10", None) q10 = stats.get("q10", None)
@@ -392,7 +392,7 @@ class _NormalizationMixin:
) )
if inverse: if inverse:
return tensor * denom + q10 return tensor * denom + q10
return (tensor - q10) / denom return 2.0 * (tensor - q10) / denom - 1.0
# If necessary stats are missing, return input unchanged. # If necessary stats are missing, return input unchanged.
return tensor return tensor