diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py
index 9fee84487..fe6626689 100644
--- a/src/lerobot/datasets/compute_stats.py
+++ b/src/lerobot/datasets/compute_stats.py
@@ -315,9 +315,8 @@ def _reshape_for_global_stats(
     if keepdims:
         target_shape = tuple(1 for _ in original_shape)
         return value.reshape(target_shape)
-    elif not keepdims and value.ndim > 0 and value.size == 1:
-        return value.item()
-    return value
+    # Keep at least 1-D arrays to satisfy validator
+    return np.atleast_1d(value)
 
 
 def _reshape_single_stat(
@@ -410,12 +409,6 @@ def _compute_basic_stats(
         "count": np.array([sample_count]),
     }
 
-    # For single-element arrays with shape (1,1), convert to scalar arrays
-    if array.shape == (1, 1):
-        for key in stats:
-            if key != "count" and stats[key].size == 1:
-                stats[key] = np.array(stats[key].item())
-
     for q in quantile_list_keys:
         stats[q] = stats["mean"].copy()
 
@@ -470,12 +463,6 @@ def get_feature_stats(
     stats = running_stats.get_statistics()
     stats["count"] = np.array([sample_count])
 
-    # For axis=None, the stats are computed as 1D arrays but should be 0-dimensional arrays
-    if axis is None and reshaped.shape[1] == 1:
-        for key in stats:
-            if key != "count" and stats[key].size == 1:
-                stats[key] = np.array(stats[key].item())
-
     stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
 
     return stats
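
The first hunk changes the scalar-collapsing path: with `keepdims=False`, a single-element result now stays a 1-D `ndarray` instead of degrading to a Python scalar, so every stat still exposes `.shape` and `.dtype` downstream. A minimal standalone sketch of the new contract (not the library function itself):

```python
import numpy as np

def reshape_for_global_stats(value: np.ndarray, original_shape: tuple[int, ...], keepdims: bool):
    # Mirrors the patched _reshape_for_global_stats: keepdims reshapes to all-ones,
    # everything else is promoted to at least 1-D rather than collapsed to a scalar.
    if keepdims:
        return value.reshape(tuple(1 for _ in original_shape))
    return np.atleast_1d(value)

# A scalar mean computed over a (T, 1) feature:
out = reshape_for_global_stats(np.asarray(0.5), (10, 1), keepdims=False)
assert out.shape == (1,) and isinstance(out, np.ndarray)  # old code returned a bare float 0.5
```

This also makes the two deleted "convert to scalar arrays" blocks in `_compute_basic_stats` and `get_feature_stats` redundant, which is why they can go.
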
diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
index a5247a728..c2d38f017 100644
--- a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
+++ b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
@@ -34,12 +34,15 @@ python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
 """
 
 import argparse
+import concurrent.futures
 import logging
 from pathlib import Path
 
 import numpy as np
+import torch
+from tqdm import tqdm
 
-from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, compute_episode_stats
+from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.datasets.utils import write_stats
 from lerobot.utils.utils import init_logging
@@ -67,52 +70,44 @@ def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[s
     return False
 
 
-def load_episode_data(dataset: LeRobotDataset, episode_idx: int) -> dict:
-    """Load episode data by accessing the underlying HuggingFace dataset.
+def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
+    """Process a single episode and return its statistics.
 
     Args:
         dataset: The LeRobot dataset
-        episode_idx: Index of the episode to load
+        episode_idx: Index of the episode to process
 
     Returns:
-        Dictionary containing episode data for each feature
+        Dictionary containing episode statistics
     """
+    logging.info(f"Computing stats for episode {episode_idx}")
-    episode_info = dataset.meta.episodes[episode_idx]
-    episode_length = episode_info["length"]
+    start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
+    end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
-
-    start_idx = sum(dataset.meta.episodes[i]["length"] for i in range(episode_idx))
-    end_idx = start_idx + episode_length
-
-    episode_data = {}
-
-    episode_slice = dataset.hf_dataset.select(range(start_idx, end_idx))
-
-    for key, feature_info in dataset.features.items():
-        if feature_info["dtype"] == "string":
+    ep_stats = {}
+    for key, data in dataset.hf_dataset[start_idx:end_idx].items():
+        if dataset.features[key]["dtype"] == "string":
             continue
 
-        if feature_info["dtype"] in ["image", "video"]:
-            image_paths = []
-            for row in episode_slice:
-                if key in row:
-                    relative_path = row[key]
-                    if isinstance(relative_path, str):
-                        absolute_path = str(dataset.meta.root / relative_path)
-                        image_paths.append(absolute_path)
-
-            if image_paths:
-                episode_data[key] = image_paths
+        data = torch.stack(data).cpu().numpy()
+        if dataset.features[key]["dtype"] in ["image", "video"]:
+            axes_to_reduce = (0, 2, 3)
+            keepdims = True
         else:
-            arrays = []
-            for row in episode_slice:
-                if key in row:
-                    arrays.append(np.array(row[key]))
+            axes_to_reduce = 0
+            keepdims = data.ndim == 1
 
-            if arrays:
-                episode_data[key] = np.stack(arrays)
+        ep_stats[key] = get_feature_stats(
+            data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
+        )
 
-    return episode_data
+        if dataset.features[key]["dtype"] in ["image", "video"]:
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    return ep_stats
@@ -127,11 +122,25 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dic
     logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
 
     episode_stats_list = []
+    max_workers = min(dataset.num_episodes, 8)
 
-    for episode_idx in range(dataset.num_episodes):
-        episode_data = load_episode_data(dataset, episode_idx)
-        ep_stats = compute_episode_stats(episode_data, dataset.features)
-        episode_stats_list.append(ep_stats)
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_episode = {
+            executor.submit(process_single_episode, dataset, episode_idx): episode_idx
+            for episode_idx in range(dataset.num_episodes)
+        }
+
+        episode_results = {}
+        with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
+            for future in concurrent.futures.as_completed(future_to_episode):
+                episode_idx = future_to_episode[future]
+                ep_stats = future.result()
+                episode_results[episode_idx] = ep_stats
+                pbar.update(1)
+
+    for episode_idx in range(dataset.num_episodes):
+        if episode_idx in episode_results:
+            episode_stats_list.append(episode_results[episode_idx])
 
     if not episode_stats_list:
         raise ValueError("No episode data found for computing statistics")
@@ -143,12 +152,14 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dic
 def augment_dataset_with_quantile_stats(
     repo_id: str,
     root: str | Path | None = None,
+    overwrite: bool = False,
 ) -> None:
     """Augment a dataset with quantile statistics if they are missing.
 
     Args:
         repo_id: Repository ID of the dataset
         root: Local root directory for the dataset
+        overwrite: Overwrite quantile statistics even if they already exist
     """
     logging.info(f"Loading dataset: {repo_id}")
     dataset = LeRobotDataset(
@@ -156,7 +167,7 @@ def augment_dataset_with_quantile_stats(
         root=root,
     )
 
-    if has_quantile_stats(dataset.meta.stats):
+    if not overwrite and has_quantile_stats(dataset.meta.stats):
         logging.info("Dataset already contains quantile statistics. No action needed.")
         return
 
@@ -189,6 +200,11 @@
         type=str,
         help="Local root directory for the dataset",
     )
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="Overwrite quantile statistics even if they already exist",
+    )
 
     args = parser.parse_args()
     root = Path(args.root) if args.root else None
@@ -198,6 +214,7 @@
     augment_dataset_with_quantile_stats(
         repo_id=args.repo_id,
         root=root,
+        overwrite=args.overwrite,
     )
 
 
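
The per-episode loop is now fanned out to a thread pool; results arrive in completion order and are re-keyed by episode index, so aggregation stays deterministic. The same pattern in isolation (with a stand-in `compute` function, since the real worker needs a dataset):

```python
import concurrent.futures

from tqdm import tqdm

def compute(idx: int) -> dict:
    return {"episode": idx}  # stand-in for process_single_episode(dataset, idx)

num_episodes = 32
max_workers = min(num_episodes, 8)
results: dict[int, dict] = {}

with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    future_to_idx = {executor.submit(compute, i): i for i in range(num_episodes)}
    with tqdm(total=num_episodes, desc="Processing episodes") as pbar:
        for future in concurrent.futures.as_completed(future_to_idx):
            results[future_to_idx[future]] = future.result()  # re-raises worker exceptions
            pbar.update(1)

ordered = [results[i] for i in range(num_episodes)]  # restore episode order
```

`executor.map(compute, range(num_episodes))` would yield ordered results directly; the submit/`as_completed` form is only needed so the progress bar ticks as episodes finish rather than in index order.
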
diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py
index 8a2398a5d..110816e4a 100644
--- a/src/lerobot/processor/normalize_processor.py
+++ b/src/lerobot/processor/normalize_processor.py
@@ -288,8 +288,8 @@ class _NormalizationMixin:
     Normalization Modes:
         - MEAN_STD: Centers data around zero with unit variance.
        - MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
-        - QUANTILES: Scales data to [0, 1] range using 1st and 99th percentiles (q01/q99).
-        - QUANTILE10: Scales data to [0, 1] range using 10th and 90th percentiles (q10/q90).
+        - QUANTILES: Scales data to [-1, 1] range using 1st and 99th percentiles (q01/q99).
+        - QUANTILE10: Scales data to [-1, 1] range using 10th and 90th percentiles (q10/q90).
 
     Args:
         tensor: The input tensor to transform.
@@ -375,7 +375,7 @@ class _NormalizationMixin:
             )
             if inverse:
-                return tensor * denom + q01
-            return (tensor - q01) / denom
+                return (tensor + 1.0) / 2.0 * denom + q01
+            return 2.0 * (tensor - q01) / denom - 1.0
 
         if norm_mode == NormalizationMode.QUANTILE10:
             q10 = stats.get("q10", None)
@@ -392,7 +392,7 @@ class _NormalizationMixin:
             )
             if inverse:
-                return tensor * denom + q10
-            return (tensor - q10) / denom
+                return (tensor + 1.0) / 2.0 * denom + q10
+            return 2.0 * (tensor - q10) / denom - 1.0
 
         # If necessary stats are missing, return input unchanged.
         return tensor
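
For the normalize_processor change, a round-trip check of the new [-1, 1] quantile mapping makes the forward/inverse pairing explicit (a standalone sketch, not the processor API; `eps` guards a degenerate quantile range):

```python
import torch

def quantile_scale(x: torch.Tensor, q_low: torch.Tensor, q_high: torch.Tensor,
                   inverse: bool = False, eps: float = 1e-8) -> torch.Tensor:
    # Forward: map [q_low, q_high] onto [-1, 1]; inverse undoes the shift and scale.
    denom = (q_high - q_low).clamp_min(eps)
    if inverse:
        return (x + 1.0) / 2.0 * denom + q_low
    return 2.0 * (x - q_low) / denom - 1.0

x = torch.tensor([0.1, 0.5, 0.9])
q01, q99 = torch.tensor(0.1), torch.tensor(0.9)
y = quantile_scale(x, q01, q99)  # tensor([-1., 0., 1.])
assert torch.allclose(quantile_scale(y, q01, q99, inverse=True), x)
```

Note that once the forward pass targets [-1, 1], the inverse branch must add 1 and halve before rescaling by the quantile range, as reflected in the hunks above.
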