mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
Fix/quantiles script (#2064)
* refactor augment stats with quantiles script add parallelization for faster processing shift the quantile normalization between -1 1 * fix replay buffer tests * fix comment
This commit is contained in:
@@ -315,9 +315,8 @@ def _reshape_for_global_stats(
|
|||||||
if keepdims:
|
if keepdims:
|
||||||
target_shape = tuple(1 for _ in original_shape)
|
target_shape = tuple(1 for _ in original_shape)
|
||||||
return value.reshape(target_shape)
|
return value.reshape(target_shape)
|
||||||
elif not keepdims and value.ndim > 0 and value.size == 1:
|
# Keep at least 1-D arrays to satisfy validator
|
||||||
return value.item()
|
return np.atleast_1d(value)
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def _reshape_single_stat(
|
def _reshape_single_stat(
|
||||||
@@ -410,12 +409,6 @@ def _compute_basic_stats(
|
|||||||
"count": np.array([sample_count]),
|
"count": np.array([sample_count]),
|
||||||
}
|
}
|
||||||
|
|
||||||
# For single-element arrays with shape (1,1), convert to scalar arrays
|
|
||||||
if array.shape == (1, 1):
|
|
||||||
for key in stats:
|
|
||||||
if key != "count" and stats[key].size == 1:
|
|
||||||
stats[key] = np.array(stats[key].item())
|
|
||||||
|
|
||||||
for q in quantile_list_keys:
|
for q in quantile_list_keys:
|
||||||
stats[q] = stats["mean"].copy()
|
stats[q] = stats["mean"].copy()
|
||||||
|
|
||||||
@@ -470,12 +463,6 @@ def get_feature_stats(
|
|||||||
stats = running_stats.get_statistics()
|
stats = running_stats.get_statistics()
|
||||||
stats["count"] = np.array([sample_count])
|
stats["count"] = np.array([sample_count])
|
||||||
|
|
||||||
# For axis=None, the stats are computed as 1D arrays but should be 0-dimensional arrays
|
|
||||||
if axis is None and reshaped.shape[1] == 1:
|
|
||||||
for key in stats:
|
|
||||||
if key != "count" and stats[key].size == 1:
|
|
||||||
stats[key] = np.array(stats[key].item())
|
|
||||||
|
|
||||||
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
|
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|||||||
@@ -34,12 +34,15 @@ python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import concurrent.futures
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, compute_episode_stats
|
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import write_stats
|
from lerobot.datasets.utils import write_stats
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
@@ -67,52 +70,44 @@ def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[s
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def load_episode_data(dataset: LeRobotDataset, episode_idx: int) -> dict:
|
def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||||
"""Load episode data by accessing the underlying HuggingFace dataset.
|
"""Process a single episode and return its statistics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: The LeRobot dataset
|
dataset: The LeRobot dataset
|
||||||
episode_idx: Index of the episode to load
|
episode_idx: Index of the episode to process
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary containing episode data for each feature
|
Dictionary containing episode statistics
|
||||||
"""
|
"""
|
||||||
|
logging.info(f"Computing stats for episode {episode_idx}")
|
||||||
|
|
||||||
episode_info = dataset.meta.episodes[episode_idx]
|
start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
|
||||||
episode_length = episode_info["length"]
|
end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
|
||||||
|
|
||||||
start_idx = sum(dataset.meta.episodes[i]["length"] for i in range(episode_idx))
|
ep_stats = {}
|
||||||
end_idx = start_idx + episode_length
|
for key, data in dataset.hf_dataset[start_idx:end_idx].items():
|
||||||
|
if dataset.features[key]["dtype"] == "string":
|
||||||
episode_data = {}
|
|
||||||
|
|
||||||
episode_slice = dataset.hf_dataset.select(range(start_idx, end_idx))
|
|
||||||
|
|
||||||
for key, feature_info in dataset.features.items():
|
|
||||||
if feature_info["dtype"] == "string":
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if feature_info["dtype"] in ["image", "video"]:
|
data = torch.stack(data).cpu().numpy()
|
||||||
image_paths = []
|
if dataset.features[key]["dtype"] in ["image", "video"]:
|
||||||
for row in episode_slice:
|
axes_to_reduce = (0, 2, 3)
|
||||||
if key in row:
|
keepdims = True
|
||||||
relative_path = row[key]
|
|
||||||
if isinstance(relative_path, str):
|
|
||||||
absolute_path = str(dataset.meta.root / relative_path)
|
|
||||||
image_paths.append(absolute_path)
|
|
||||||
|
|
||||||
if image_paths:
|
|
||||||
episode_data[key] = image_paths
|
|
||||||
else:
|
else:
|
||||||
arrays = []
|
axes_to_reduce = 0
|
||||||
for row in episode_slice:
|
keepdims = data.ndim == 1
|
||||||
if key in row:
|
|
||||||
arrays.append(np.array(row[key]))
|
|
||||||
|
|
||||||
if arrays:
|
ep_stats[key] = get_feature_stats(
|
||||||
episode_data[key] = np.stack(arrays)
|
data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
|
||||||
|
)
|
||||||
|
|
||||||
return episode_data
|
if dataset.features[key]["dtype"] in ["image", "video"]:
|
||||||
|
ep_stats[key] = {
|
||||||
|
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return ep_stats
|
||||||
|
|
||||||
|
|
||||||
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
|
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
|
||||||
@@ -127,11 +122,25 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic
|
|||||||
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
|
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
|
||||||
|
|
||||||
episode_stats_list = []
|
episode_stats_list = []
|
||||||
|
max_workers = min(dataset.num_episodes, 8)
|
||||||
|
|
||||||
for episode_idx in range(dataset.num_episodes):
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
episode_data = load_episode_data(dataset, episode_idx)
|
future_to_episode = {
|
||||||
ep_stats = compute_episode_stats(episode_data, dataset.features)
|
executor.submit(process_single_episode, dataset, episode_idx): episode_idx
|
||||||
episode_stats_list.append(ep_stats)
|
for episode_idx in range(dataset.num_episodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
episode_results = {}
|
||||||
|
with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
|
||||||
|
for future in concurrent.futures.as_completed(future_to_episode):
|
||||||
|
episode_idx = future_to_episode[future]
|
||||||
|
ep_stats = future.result()
|
||||||
|
episode_results[episode_idx] = ep_stats
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
for episode_idx in range(dataset.num_episodes):
|
||||||
|
if episode_idx in episode_results:
|
||||||
|
episode_stats_list.append(episode_results[episode_idx])
|
||||||
|
|
||||||
if not episode_stats_list:
|
if not episode_stats_list:
|
||||||
raise ValueError("No episode data found for computing statistics")
|
raise ValueError("No episode data found for computing statistics")
|
||||||
@@ -143,12 +152,14 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic
|
|||||||
def augment_dataset_with_quantile_stats(
|
def augment_dataset_with_quantile_stats(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
|
overwrite: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Augment a dataset with quantile statistics if they are missing.
|
"""Augment a dataset with quantile statistics if they are missing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id: Repository ID of the dataset
|
repo_id: Repository ID of the dataset
|
||||||
root: Local root directory for the dataset
|
root: Local root directory for the dataset
|
||||||
|
overwrite: Overwrite existing quantile statistics if they already exist
|
||||||
"""
|
"""
|
||||||
logging.info(f"Loading dataset: {repo_id}")
|
logging.info(f"Loading dataset: {repo_id}")
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
@@ -156,7 +167,7 @@ def augment_dataset_with_quantile_stats(
|
|||||||
root=root,
|
root=root,
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_quantile_stats(dataset.meta.stats):
|
if not overwrite and has_quantile_stats(dataset.meta.stats):
|
||||||
logging.info("Dataset already contains quantile statistics. No action needed.")
|
logging.info("Dataset already contains quantile statistics. No action needed.")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -189,6 +200,11 @@ def main():
|
|||||||
type=str,
|
type=str,
|
||||||
help="Local root directory for the dataset",
|
help="Local root directory for the dataset",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--overwrite",
|
||||||
|
action="store_true",
|
||||||
|
help="Overwrite existing quantile statistics if they already exist",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
root = Path(args.root) if args.root else None
|
root = Path(args.root) if args.root else None
|
||||||
@@ -198,6 +214,7 @@ def main():
|
|||||||
augment_dataset_with_quantile_stats(
|
augment_dataset_with_quantile_stats(
|
||||||
repo_id=args.repo_id,
|
repo_id=args.repo_id,
|
||||||
root=root,
|
root=root,
|
||||||
|
overwrite=args.overwrite,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -288,8 +288,8 @@ class _NormalizationMixin:
|
|||||||
Normalization Modes:
|
Normalization Modes:
|
||||||
- MEAN_STD: Centers data around zero with unit variance.
|
- MEAN_STD: Centers data around zero with unit variance.
|
||||||
- MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
|
- MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
|
||||||
- QUANTILES: Scales data to [0, 1] range using 1st and 99th percentiles (q01/q99).
|
- QUANTILES: Scales data to [-1, 1] range using 1st and 99th percentiles (q01/q99).
|
||||||
- QUANTILE10: Scales data to [0, 1] range using 10th and 90th percentiles (q10/q90).
|
- QUANTILE10: Scales data to [-1, 1] range using 10th and 90th percentiles (q10/q90).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor: The input tensor to transform.
|
tensor: The input tensor to transform.
|
||||||
@@ -375,7 +375,7 @@ class _NormalizationMixin:
|
|||||||
)
|
)
|
||||||
if inverse:
|
if inverse:
|
||||||
return tensor * denom + q01
|
return tensor * denom + q01
|
||||||
return (tensor - q01) / denom
|
return 2.0 * (tensor - q01) / denom - 1.0
|
||||||
|
|
||||||
if norm_mode == NormalizationMode.QUANTILE10:
|
if norm_mode == NormalizationMode.QUANTILE10:
|
||||||
q10 = stats.get("q10", None)
|
q10 = stats.get("q10", None)
|
||||||
@@ -392,7 +392,7 @@ class _NormalizationMixin:
|
|||||||
)
|
)
|
||||||
if inverse:
|
if inverse:
|
||||||
return tensor * denom + q10
|
return tensor * denom + q10
|
||||||
return (tensor - q10) / denom
|
return 2.0 * (tensor - q10) / denom - 1.0
|
||||||
|
|
||||||
# If necessary stats are missing, return input unchanged.
|
# If necessary stats are missing, return input unchanged.
|
||||||
return tensor
|
return tensor
|
||||||
|
|||||||
Reference in New Issue
Block a user