mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
Fix/quantiles script (#2064)
* refactor augment stats with quantiles script add parallelization for faster processing shift the quantile normalization between -1 1 * fix replay buffer tests * fix comment
This commit is contained in:
@@ -315,9 +315,8 @@ def _reshape_for_global_stats(
|
||||
if keepdims:
|
||||
target_shape = tuple(1 for _ in original_shape)
|
||||
return value.reshape(target_shape)
|
||||
elif not keepdims and value.ndim > 0 and value.size == 1:
|
||||
return value.item()
|
||||
return value
|
||||
# Keep at least 1-D arrays to satisfy validator
|
||||
return np.atleast_1d(value)
|
||||
|
||||
|
||||
def _reshape_single_stat(
|
||||
@@ -410,12 +409,6 @@ def _compute_basic_stats(
|
||||
"count": np.array([sample_count]),
|
||||
}
|
||||
|
||||
# For single-element arrays with shape (1,1), convert to scalar arrays
|
||||
if array.shape == (1, 1):
|
||||
for key in stats:
|
||||
if key != "count" and stats[key].size == 1:
|
||||
stats[key] = np.array(stats[key].item())
|
||||
|
||||
for q in quantile_list_keys:
|
||||
stats[q] = stats["mean"].copy()
|
||||
|
||||
@@ -470,12 +463,6 @@ def get_feature_stats(
|
||||
stats = running_stats.get_statistics()
|
||||
stats["count"] = np.array([sample_count])
|
||||
|
||||
# For axis=None, the stats are computed as 1D arrays but should be 0-dimensional arrays
|
||||
if axis is None and reshaped.shape[1] == 1:
|
||||
for key in stats:
|
||||
if key != "count" and stats[key].size == 1:
|
||||
stats[key] = np.array(stats[key].item())
|
||||
|
||||
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
|
||||
return stats
|
||||
|
||||
|
||||
@@ -34,12 +34,15 @@ python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import write_stats
|
||||
from lerobot.utils.utils import init_logging
|
||||
@@ -67,52 +70,44 @@ def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[s
|
||||
return False
|
||||
|
||||
|
||||
def load_episode_data(dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
"""Load episode data by accessing the underlying HuggingFace dataset.
|
||||
def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
"""Process a single episode and return its statistics.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset
|
||||
episode_idx: Index of the episode to load
|
||||
episode_idx: Index of the episode to process
|
||||
|
||||
Returns:
|
||||
Dictionary containing episode data for each feature
|
||||
Dictionary containing episode statistics
|
||||
"""
|
||||
logging.info(f"Computing stats for episode {episode_idx}")
|
||||
|
||||
episode_info = dataset.meta.episodes[episode_idx]
|
||||
episode_length = episode_info["length"]
|
||||
start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
|
||||
end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
|
||||
|
||||
start_idx = sum(dataset.meta.episodes[i]["length"] for i in range(episode_idx))
|
||||
end_idx = start_idx + episode_length
|
||||
|
||||
episode_data = {}
|
||||
|
||||
episode_slice = dataset.hf_dataset.select(range(start_idx, end_idx))
|
||||
|
||||
for key, feature_info in dataset.features.items():
|
||||
if feature_info["dtype"] == "string":
|
||||
ep_stats = {}
|
||||
for key, data in dataset.hf_dataset[start_idx:end_idx].items():
|
||||
if dataset.features[key]["dtype"] == "string":
|
||||
continue
|
||||
|
||||
if feature_info["dtype"] in ["image", "video"]:
|
||||
image_paths = []
|
||||
for row in episode_slice:
|
||||
if key in row:
|
||||
relative_path = row[key]
|
||||
if isinstance(relative_path, str):
|
||||
absolute_path = str(dataset.meta.root / relative_path)
|
||||
image_paths.append(absolute_path)
|
||||
|
||||
if image_paths:
|
||||
episode_data[key] = image_paths
|
||||
data = torch.stack(data).cpu().numpy()
|
||||
if dataset.features[key]["dtype"] in ["image", "video"]:
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
else:
|
||||
arrays = []
|
||||
for row in episode_slice:
|
||||
if key in row:
|
||||
arrays.append(np.array(row[key]))
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
|
||||
if arrays:
|
||||
episode_data[key] = np.stack(arrays)
|
||||
ep_stats[key] = get_feature_stats(
|
||||
data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
|
||||
)
|
||||
|
||||
return episode_data
|
||||
if dataset.features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
|
||||
|
||||
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
|
||||
@@ -127,11 +122,25 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic
|
||||
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
|
||||
|
||||
episode_stats_list = []
|
||||
max_workers = min(dataset.num_episodes, 8)
|
||||
|
||||
for episode_idx in range(dataset.num_episodes):
|
||||
episode_data = load_episode_data(dataset, episode_idx)
|
||||
ep_stats = compute_episode_stats(episode_data, dataset.features)
|
||||
episode_stats_list.append(ep_stats)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_episode = {
|
||||
executor.submit(process_single_episode, dataset, episode_idx): episode_idx
|
||||
for episode_idx in range(dataset.num_episodes)
|
||||
}
|
||||
|
||||
episode_results = {}
|
||||
with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
|
||||
for future in concurrent.futures.as_completed(future_to_episode):
|
||||
episode_idx = future_to_episode[future]
|
||||
ep_stats = future.result()
|
||||
episode_results[episode_idx] = ep_stats
|
||||
pbar.update(1)
|
||||
|
||||
for episode_idx in range(dataset.num_episodes):
|
||||
if episode_idx in episode_results:
|
||||
episode_stats_list.append(episode_results[episode_idx])
|
||||
|
||||
if not episode_stats_list:
|
||||
raise ValueError("No episode data found for computing statistics")
|
||||
@@ -143,12 +152,14 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic
|
||||
def augment_dataset_with_quantile_stats(
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Augment a dataset with quantile statistics if they are missing.
|
||||
|
||||
Args:
|
||||
repo_id: Repository ID of the dataset
|
||||
root: Local root directory for the dataset
|
||||
overwrite: Overwrite existing quantile statistics if they already exist
|
||||
"""
|
||||
logging.info(f"Loading dataset: {repo_id}")
|
||||
dataset = LeRobotDataset(
|
||||
@@ -156,7 +167,7 @@ def augment_dataset_with_quantile_stats(
|
||||
root=root,
|
||||
)
|
||||
|
||||
if has_quantile_stats(dataset.meta.stats):
|
||||
if not overwrite and has_quantile_stats(dataset.meta.stats):
|
||||
logging.info("Dataset already contains quantile statistics. No action needed.")
|
||||
return
|
||||
|
||||
@@ -189,6 +200,11 @@ def main():
|
||||
type=str,
|
||||
help="Local root directory for the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="Overwrite existing quantile statistics if they already exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
root = Path(args.root) if args.root else None
|
||||
@@ -198,6 +214,7 @@ def main():
|
||||
augment_dataset_with_quantile_stats(
|
||||
repo_id=args.repo_id,
|
||||
root=root,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -288,8 +288,8 @@ class _NormalizationMixin:
|
||||
Normalization Modes:
|
||||
- MEAN_STD: Centers data around zero with unit variance.
|
||||
- MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
|
||||
- QUANTILES: Scales data to [0, 1] range using 1st and 99th percentiles (q01/q99).
|
||||
- QUANTILE10: Scales data to [0, 1] range using 10th and 90th percentiles (q10/q90).
|
||||
- QUANTILES: Scales data to [-1, 1] range using 1st and 99th percentiles (q01/q99).
|
||||
- QUANTILE10: Scales data to [-1, 1] range using 10th and 90th percentiles (q10/q90).
|
||||
|
||||
Args:
|
||||
tensor: The input tensor to transform.
|
||||
@@ -375,7 +375,7 @@ class _NormalizationMixin:
|
||||
)
|
||||
if inverse:
|
||||
return tensor * denom + q01
|
||||
return (tensor - q01) / denom
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
if norm_mode == NormalizationMode.QUANTILE10:
|
||||
q10 = stats.get("q10", None)
|
||||
@@ -392,7 +392,7 @@ class _NormalizationMixin:
|
||||
)
|
||||
if inverse:
|
||||
return tensor * denom + q10
|
||||
return (tensor - q10) / denom
|
||||
return 2.0 * (tensor - q10) / denom - 1.0
|
||||
|
||||
# If necessary stats are missing, return input unchanged.
|
||||
return tensor
|
||||
|
||||
Reference in New Issue
Block a user