mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
add script
This commit is contained in:
@@ -0,0 +1,278 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Script to find episodes with highest MSE between observation.state and action pairs.
|
||||
|
||||
This script:
|
||||
1. Downloads a LeRobot dataset (if needed, skipping videos)
|
||||
2. Computes MSE between observation.state and action for each frame
|
||||
3. Aggregates MSE per episode
|
||||
4. Returns the top 1% episodes with highest total MSE
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
|
||||
def compute_episode_mse(
|
||||
dataset: LeRobotDataset,
|
||||
state_key: str = "observation.state",
|
||||
action_key: str = "action",
|
||||
) -> dict[int, float]:
|
||||
"""
|
||||
Compute total MSE between state and action for each episode.
|
||||
|
||||
Args:
|
||||
dataset: LeRobotDataset to analyze
|
||||
state_key: Key for the observation state in the dataset
|
||||
action_key: Key for the action in the dataset
|
||||
|
||||
Returns:
|
||||
Dictionary mapping episode_index to total MSE for that episode
|
||||
"""
|
||||
episode_mse = {}
|
||||
|
||||
# Get all unique episode indices
|
||||
hf_dataset = dataset.hf_dataset
|
||||
|
||||
# Group frames by episode for efficient processing
|
||||
logging.info("Computing MSE for each episode...")
|
||||
|
||||
# Process all frames and accumulate MSE per episode
|
||||
for idx in tqdm(range(len(hf_dataset)), desc="Processing frames"):
|
||||
item = hf_dataset[idx]
|
||||
|
||||
ep_idx = item["episode_index"]
|
||||
if isinstance(ep_idx, torch.Tensor):
|
||||
ep_idx = ep_idx.item()
|
||||
|
||||
state = item[state_key]
|
||||
action = item[action_key]
|
||||
|
||||
if isinstance(state, torch.Tensor):
|
||||
state = state.numpy()
|
||||
if isinstance(action, torch.Tensor):
|
||||
action = action.numpy()
|
||||
|
||||
# Compute MSE for this frame (sum of squared differences across all dimensions)
|
||||
mse = np.mean((state - action) ** 2)
|
||||
|
||||
if ep_idx not in episode_mse:
|
||||
episode_mse[ep_idx] = 0.0
|
||||
episode_mse[ep_idx] += mse
|
||||
|
||||
return episode_mse
|
||||
|
||||
|
||||
def get_top_mse_episodes(
|
||||
episode_mse: dict[int, float],
|
||||
top_percent: float = 1.0,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get the top X% of episodes with highest total MSE.
|
||||
|
||||
Args:
|
||||
episode_mse: Dictionary mapping episode_index to total MSE
|
||||
top_percent: Percentage of episodes to return (default: 1%)
|
||||
|
||||
Returns:
|
||||
List of episode indices sorted by MSE (highest first)
|
||||
"""
|
||||
# Sort episodes by MSE in descending order
|
||||
sorted_episodes = sorted(episode_mse.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Calculate number of episodes to return
|
||||
num_episodes = len(sorted_episodes)
|
||||
num_top = max(1, int(np.ceil(num_episodes * top_percent / 100)))
|
||||
|
||||
# Extract top episode indices
|
||||
top_episodes = [ep_idx for ep_idx, _ in sorted_episodes[:num_top]]
|
||||
|
||||
return top_episodes
|
||||
|
||||
|
||||
def find_high_mse_episodes(
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
state_key: str = "observation.state",
|
||||
action_key: str = "action",
|
||||
top_percent: float = 1.0,
|
||||
force_download: bool = False,
|
||||
) -> tuple[list[int], dict[int, float]]:
|
||||
"""
|
||||
Find episodes with highest MSE between observation.state and action.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID
|
||||
root: Local directory for dataset storage (default: ~/.cache/huggingface/lerobot)
|
||||
state_key: Key for the observation state in the dataset
|
||||
action_key: Key for the action in the dataset
|
||||
top_percent: Percentage of episodes to return (default: 1%)
|
||||
force_download: Force re-download of the dataset
|
||||
|
||||
Returns:
|
||||
Tuple of (list of top episode indices, dict of all episode MSEs)
|
||||
"""
|
||||
logging.info(f"Loading dataset: {repo_id}")
|
||||
|
||||
# Load the dataset (skip video download since we only need state/action data)
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
download_videos=False,
|
||||
force_cache_sync=force_download,
|
||||
)
|
||||
|
||||
# Verify the dataset has the required features
|
||||
if state_key not in dataset.features:
|
||||
raise ValueError(f"Dataset does not contain '{state_key}' feature. "
|
||||
f"Available features: {list(dataset.features.keys())}")
|
||||
if action_key not in dataset.features:
|
||||
raise ValueError(f"Dataset does not contain '{action_key}' feature. "
|
||||
f"Available features: {list(dataset.features.keys())}")
|
||||
|
||||
# Check that state and action have the same shape
|
||||
state_shape = tuple(dataset.features[state_key]["shape"])
|
||||
action_shape = tuple(dataset.features[action_key]["shape"])
|
||||
if state_shape != action_shape:
|
||||
raise ValueError(f"State shape {state_shape} does not match action shape {action_shape}")
|
||||
|
||||
logging.info(f"Dataset loaded successfully:")
|
||||
logging.info(f" - Total episodes: {dataset.meta.total_episodes}")
|
||||
logging.info(f" - Total frames: {dataset.meta.total_frames}")
|
||||
logging.info(f" - State shape: {state_shape}")
|
||||
logging.info(f" - Action shape: {action_shape}")
|
||||
logging.info(f" - Feature names: {dataset.features[state_key].get('names', 'N/A')}")
|
||||
|
||||
# Compute MSE for each episode
|
||||
episode_mse = compute_episode_mse(dataset, state_key, action_key)
|
||||
|
||||
# Get top episodes
|
||||
top_episodes = get_top_mse_episodes(episode_mse, top_percent)
|
||||
|
||||
return top_episodes, episode_mse
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Find episodes with highest MSE between observation.state and action"
|
||||
)
|
||||
parser.add_argument(
|
||||
"repo_id",
|
||||
type=str,
|
||||
help="HuggingFace dataset repository ID (e.g., 'lerobot/aloha_sim_insertion_human')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory for dataset storage (default: ~/.cache/huggingface/lerobot)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--state-key",
|
||||
type=str,
|
||||
default="observation.state",
|
||||
help="Key for observation state feature (default: 'observation.state')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action-key",
|
||||
type=str,
|
||||
default="action",
|
||||
help="Key for action feature (default: 'action')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-percent",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Percentage of episodes to return (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-download",
|
||||
action="store_true",
|
||||
help="Force re-download of the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-all-mse",
|
||||
action="store_true",
|
||||
help="Show MSE values for all episodes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output file to save results (optional)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Find high MSE episodes
|
||||
top_episodes, all_mse = find_high_mse_episodes(
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
state_key=args.state_key,
|
||||
action_key=args.action_key,
|
||||
top_percent=args.top_percent,
|
||||
force_download=args.force_download,
|
||||
)
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print(f"TOP {args.top_percent}% EPISODES WITH HIGHEST MSE")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nTotal episodes analyzed: {len(all_mse)}")
|
||||
print(f"Number of top episodes (top {args.top_percent}%): {len(top_episodes)}")
|
||||
|
||||
print(f"\nTop {len(top_episodes)} episode(s) with highest MSE:")
|
||||
print("-" * 40)
|
||||
for i, ep_idx in enumerate(top_episodes, 1):
|
||||
print(f" {i:3d}. Episode {ep_idx:5d} - Total MSE: {all_mse[ep_idx]:.6f}")
|
||||
|
||||
# Statistics
|
||||
all_mse_values = list(all_mse.values())
|
||||
print(f"\nMSE Statistics:")
|
||||
print(f" - Mean MSE: {np.mean(all_mse_values):.6f}")
|
||||
print(f" - Std MSE: {np.std(all_mse_values):.6f}")
|
||||
print(f" - Min MSE: {np.min(all_mse_values):.6f}")
|
||||
print(f" - Max MSE: {np.max(all_mse_values):.6f}")
|
||||
print(f" - Median MSE: {np.median(all_mse_values):.6f}")
|
||||
|
||||
if args.show_all_mse:
|
||||
print(f"\nAll episodes sorted by MSE (descending):")
|
||||
print("-" * 40)
|
||||
sorted_episodes = sorted(all_mse.items(), key=lambda x: x[1], reverse=True)
|
||||
for ep_idx, mse in sorted_episodes:
|
||||
print(f" Episode {ep_idx:5d} - Total MSE: {mse:.6f}")
|
||||
|
||||
# Save results if output file specified
|
||||
if args.output:
|
||||
output_path = Path(args.output)
|
||||
with open(output_path, "w") as f:
|
||||
f.write(f"# High MSE Episodes Analysis\n")
|
||||
f.write(f"# Dataset: {args.repo_id}\n")
|
||||
f.write(f"# State key: {args.state_key}\n")
|
||||
f.write(f"# Action key: {args.action_key}\n")
|
||||
f.write(f"# Top percent: {args.top_percent}%\n\n")
|
||||
|
||||
f.write(f"Top {args.top_percent}% episodes:\n")
|
||||
for ep_idx in top_episodes:
|
||||
f.write(f"{ep_idx},{all_mse[ep_idx]:.6f}\n")
|
||||
|
||||
logging.info(f"Results saved to: {output_path}")
|
||||
|
||||
# Return the list for programmatic use
|
||||
return top_episodes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user