#!/usr/bin/env python

"""
Script to find episodes with highest MSE between observation.state and action pairs.

This script:
1. Downloads a LeRobot dataset (if needed, skipping videos)
2. Computes MSE between observation.state and action for each frame
3. Aggregates MSE per episode
4. Returns the top 1% episodes with highest total MSE
"""

import argparse
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.utils.constants import HF_LEROBOT_HOME
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
def compute_episode_mse(
|
|
dataset: LeRobotDataset,
|
|
state_key: str = "observation.state",
|
|
action_key: str = "action",
|
|
) -> dict[int, float]:
|
|
"""
|
|
Compute total MSE between state and action for each episode.
|
|
|
|
Args:
|
|
dataset: LeRobotDataset to analyze
|
|
state_key: Key for the observation state in the dataset
|
|
action_key: Key for the action in the dataset
|
|
|
|
Returns:
|
|
Dictionary mapping episode_index to total MSE for that episode
|
|
"""
|
|
episode_mse = {}
|
|
|
|
# Get all unique episode indices
|
|
hf_dataset = dataset.hf_dataset
|
|
|
|
# Group frames by episode for efficient processing
|
|
logging.info("Computing MSE for each episode...")
|
|
|
|
# Process all frames and accumulate MSE per episode
|
|
for idx in tqdm(range(len(hf_dataset)), desc="Processing frames"):
|
|
item = hf_dataset[idx]
|
|
|
|
ep_idx = item["episode_index"]
|
|
if isinstance(ep_idx, torch.Tensor):
|
|
ep_idx = ep_idx.item()
|
|
|
|
state = item[state_key]
|
|
action = item[action_key]
|
|
|
|
if isinstance(state, torch.Tensor):
|
|
state = state.numpy()
|
|
if isinstance(action, torch.Tensor):
|
|
action = action.numpy()
|
|
|
|
# Compute MSE for this frame (sum of squared differences across all dimensions)
|
|
mse = np.mean((state - action) ** 2)
|
|
|
|
if ep_idx not in episode_mse:
|
|
episode_mse[ep_idx] = 0.0
|
|
episode_mse[ep_idx] += mse
|
|
|
|
return episode_mse
|
|
|
|
|
|
def get_top_mse_episodes(
|
|
episode_mse: dict[int, float],
|
|
top_percent: float = 1.0,
|
|
) -> list[int]:
|
|
"""
|
|
Get the top X% of episodes with highest total MSE.
|
|
|
|
Args:
|
|
episode_mse: Dictionary mapping episode_index to total MSE
|
|
top_percent: Percentage of episodes to return (default: 1%)
|
|
|
|
Returns:
|
|
List of episode indices sorted by MSE (highest first)
|
|
"""
|
|
# Sort episodes by MSE in descending order
|
|
sorted_episodes = sorted(episode_mse.items(), key=lambda x: x[1], reverse=True)
|
|
|
|
# Calculate number of episodes to return
|
|
num_episodes = len(sorted_episodes)
|
|
num_top = max(1, int(np.ceil(num_episodes * top_percent / 100)))
|
|
|
|
# Extract top episode indices
|
|
top_episodes = [ep_idx for ep_idx, _ in sorted_episodes[:num_top]]
|
|
|
|
return top_episodes
|
|
|
|
|
|
def find_high_mse_episodes(
|
|
repo_id: str,
|
|
root: str | Path | None = None,
|
|
state_key: str = "observation.state",
|
|
action_key: str = "action",
|
|
top_percent: float = 1.0,
|
|
force_download: bool = False,
|
|
) -> tuple[list[int], dict[int, float]]:
|
|
"""
|
|
Find episodes with highest MSE between observation.state and action.
|
|
|
|
Args:
|
|
repo_id: HuggingFace dataset repository ID
|
|
root: Local directory for dataset storage (default: ~/.cache/huggingface/lerobot)
|
|
state_key: Key for the observation state in the dataset
|
|
action_key: Key for the action in the dataset
|
|
top_percent: Percentage of episodes to return (default: 1%)
|
|
force_download: Force re-download of the dataset
|
|
|
|
Returns:
|
|
Tuple of (list of top episode indices, dict of all episode MSEs)
|
|
"""
|
|
logging.info(f"Loading dataset: {repo_id}")
|
|
|
|
# Load the dataset (skip video download since we only need state/action data)
|
|
dataset = LeRobotDataset(
|
|
repo_id=repo_id,
|
|
root=root,
|
|
download_videos=False,
|
|
force_cache_sync=force_download,
|
|
)
|
|
|
|
# Verify the dataset has the required features
|
|
if state_key not in dataset.features:
|
|
raise ValueError(f"Dataset does not contain '{state_key}' feature. "
|
|
f"Available features: {list(dataset.features.keys())}")
|
|
if action_key not in dataset.features:
|
|
raise ValueError(f"Dataset does not contain '{action_key}' feature. "
|
|
f"Available features: {list(dataset.features.keys())}")
|
|
|
|
# Check that state and action have the same shape
|
|
state_shape = tuple(dataset.features[state_key]["shape"])
|
|
action_shape = tuple(dataset.features[action_key]["shape"])
|
|
if state_shape != action_shape:
|
|
raise ValueError(f"State shape {state_shape} does not match action shape {action_shape}")
|
|
|
|
logging.info(f"Dataset loaded successfully:")
|
|
logging.info(f" - Total episodes: {dataset.meta.total_episodes}")
|
|
logging.info(f" - Total frames: {dataset.meta.total_frames}")
|
|
logging.info(f" - State shape: {state_shape}")
|
|
logging.info(f" - Action shape: {action_shape}")
|
|
logging.info(f" - Feature names: {dataset.features[state_key].get('names', 'N/A')}")
|
|
|
|
# Compute MSE for each episode
|
|
episode_mse = compute_episode_mse(dataset, state_key, action_key)
|
|
|
|
# Get top episodes
|
|
top_episodes = get_top_mse_episodes(episode_mse, top_percent)
|
|
|
|
return top_episodes, episode_mse
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description="Find episodes with highest MSE between observation.state and action"
|
|
)
|
|
parser.add_argument(
|
|
"repo_id",
|
|
type=str,
|
|
help="HuggingFace dataset repository ID (e.g., 'lerobot/aloha_sim_insertion_human')",
|
|
)
|
|
parser.add_argument(
|
|
"--root",
|
|
type=str,
|
|
default=None,
|
|
help="Local directory for dataset storage (default: ~/.cache/huggingface/lerobot)",
|
|
)
|
|
parser.add_argument(
|
|
"--state-key",
|
|
type=str,
|
|
default="observation.state",
|
|
help="Key for observation state feature (default: 'observation.state')",
|
|
)
|
|
parser.add_argument(
|
|
"--action-key",
|
|
type=str,
|
|
default="action",
|
|
help="Key for action feature (default: 'action')",
|
|
)
|
|
parser.add_argument(
|
|
"--top-percent",
|
|
type=float,
|
|
default=1.0,
|
|
help="Percentage of episodes to return (default: 1.0)",
|
|
)
|
|
parser.add_argument(
|
|
"--force-download",
|
|
action="store_true",
|
|
help="Force re-download of the dataset",
|
|
)
|
|
parser.add_argument(
|
|
"--show-all-mse",
|
|
action="store_true",
|
|
help="Show MSE values for all episodes",
|
|
)
|
|
parser.add_argument(
|
|
"--output",
|
|
type=str,
|
|
default=None,
|
|
help="Output file to save results (optional)",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Find high MSE episodes
|
|
top_episodes, all_mse = find_high_mse_episodes(
|
|
repo_id=args.repo_id,
|
|
root=args.root,
|
|
state_key=args.state_key,
|
|
action_key=args.action_key,
|
|
top_percent=args.top_percent,
|
|
force_download=args.force_download,
|
|
)
|
|
|
|
# Print results
|
|
print("\n" + "=" * 60)
|
|
print(f"TOP {args.top_percent}% EPISODES WITH HIGHEST MSE")
|
|
print("=" * 60)
|
|
|
|
print(f"\nTotal episodes analyzed: {len(all_mse)}")
|
|
print(f"Number of top episodes (top {args.top_percent}%): {len(top_episodes)}")
|
|
|
|
print(f"\nTop {len(top_episodes)} episode(s) with highest MSE:")
|
|
print("-" * 40)
|
|
for i, ep_idx in enumerate(top_episodes, 1):
|
|
print(f" {i:3d}. Episode {ep_idx:5d} - Total MSE: {all_mse[ep_idx]:.6f}")
|
|
|
|
# Statistics
|
|
all_mse_values = list(all_mse.values())
|
|
print(f"\nMSE Statistics:")
|
|
print(f" - Mean MSE: {np.mean(all_mse_values):.6f}")
|
|
print(f" - Std MSE: {np.std(all_mse_values):.6f}")
|
|
print(f" - Min MSE: {np.min(all_mse_values):.6f}")
|
|
print(f" - Max MSE: {np.max(all_mse_values):.6f}")
|
|
print(f" - Median MSE: {np.median(all_mse_values):.6f}")
|
|
|
|
if args.show_all_mse:
|
|
print(f"\nAll episodes sorted by MSE (descending):")
|
|
print("-" * 40)
|
|
sorted_episodes = sorted(all_mse.items(), key=lambda x: x[1], reverse=True)
|
|
for ep_idx, mse in sorted_episodes:
|
|
print(f" Episode {ep_idx:5d} - Total MSE: {mse:.6f}")
|
|
|
|
# Save results if output file specified
|
|
if args.output:
|
|
output_path = Path(args.output)
|
|
with open(output_path, "w") as f:
|
|
f.write(f"# High MSE Episodes Analysis\n")
|
|
f.write(f"# Dataset: {args.repo_id}\n")
|
|
f.write(f"# State key: {args.state_key}\n")
|
|
f.write(f"# Action key: {args.action_key}\n")
|
|
f.write(f"# Top percent: {args.top_percent}%\n\n")
|
|
|
|
f.write(f"Top {args.top_percent}% episodes:\n")
|
|
for ep_idx in top_episodes:
|
|
f.write(f"{ep_idx},{all_mse[ep_idx]:.6f}\n")
|
|
|
|
logging.info(f"Results saved to: {output_path}")
|
|
|
|
# Return the list for programmatic use
|
|
return top_episodes
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
|