mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f4aef60ea4 |
@@ -1,26 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Simple script to check buffer naming in the transformed model."""
|
||||
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Load the model with strict=False to see what buffers we have
|
||||
print("Loading model...")
|
||||
policy = PI0Policy.from_pretrained("pepijn223/pi0_libero_lerobot", strict=False)
|
||||
|
||||
# Check what buffer keys exist
|
||||
state_dict = policy.state_dict()
|
||||
buffer_keys = [k for k in state_dict.keys() if "buffer" in k]
|
||||
normalize_keys = [k for k in state_dict.keys() if "normalize" in k]
|
||||
|
||||
print("\nAll buffer keys:")
|
||||
for key in buffer_keys:
|
||||
print(f" {key}")
|
||||
|
||||
print("\nAll normalize keys:")
|
||||
for key in normalize_keys:
|
||||
print(f" {key}")
|
||||
|
||||
print("\nAll keys (first 20):")
|
||||
for i, key in enumerate(state_dict.keys()):
|
||||
if i < 20:
|
||||
print(f" {key}")
|
||||
@@ -50,3 +50,7 @@
|
||||
- local: backwardcomp
|
||||
title: Backward compatibility
|
||||
title: "About"
|
||||
- sections:
|
||||
- local: datasets
|
||||
title: "The LeRobotDataset Format"
|
||||
-title: "Datasets"
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
# The LeRobotDataset Format
|
||||
|
||||
`LeRobotDataset` is a standardized dataset format designed to address the specific needs of robot learning research.
|
||||
In this, it provides a unified and convenient access to robotics data across modalities, including sensorimotor readings, multiple camera feeds and teleoperation status.
|
||||
`LeRobotDataset` also stores general information regarding the data collected, like the task being performed by the teleoperator, the kind of robot used and measurement details like the frames per second at which the recording of both image and robot state's streams are proceeding.
|
||||
|
||||
Therefore, `LeRobotDataset` provides a unified interface for handling multi-modal, time-series data, and it integrates seamlessly with the PyTorch and Hugging Face ecosystems.
|
||||
`LeRobotDataset` is designed to be easily extensible and customizable by users, and it already supports openly available data coming from a variety of embodiments, ranging from manipulator platforms like the SO-100 and ALOHA-2, to real-world humanoid data, simulation datasets and self-driving car datasets.
|
||||
This dataset format is built to be both efficient for training and flexible enough to accommodate the diverse data types encountered in robotics, while promoting reproducibility and ease of use for users.
|
||||
|
||||
## The Format's Design
|
||||
|
||||
A core design choice behind `LeRobotDataset` is separating the underlying data storage from the user-facing API.
|
||||
This allows for efficient serialization and storage while presenting the data in an intuitive, ready-to-use format.
|
||||
A dataset is always organized into three main components:
|
||||
|
||||
1. **Tabular Data**: Low-dimensional, high-frequency data such as joint states, and actions are stored in efficient [Apache Parquet](https://parquet.apache.org/) files, and typically offloaded to the more mature `datasets` library, providing fast, memory-mapped access.
|
||||
2. **Visual Data**: To handle large volumes of camera data, frames are concatenated and encoded into MP4 files. Frames from the same episode are always grouped together into the same video, and multiple videos are grouped together by camera. To reduce stress on the file system, groups of videos for the same camera view are also broke into multiple sub-directories, after a given threshold number.
|
||||
3. **Metadata**: A collection of JSON files which describes the dataset's structure in terms of its metadata, serving as the relational counterpart to both the tabular and visual dimensions of data. Metadata include the different feature schemas, frame rates, normalization statistics, and episode boundaries.
|
||||
|
||||
For scalability, and to support datasets with potentially millions of trajectories resulting in hundreads of millions or billions of individual camera frames, we merge data from different episodes into the same high-level structure.
|
||||
Concretely, this means that any given tabular collection and video will not typically contain information about one episode only, but rather a concatenation of the information available in multiple episodes.
|
||||
This keeps the pressure on the file system, both locally and on remote storage providers like Hugging Face, manageable, at the expense of leveraging more heavily the metadata part of the data, e.g. used to reconstruct information relative to at which position a given episode starts or ends.
|
||||
An example structure for a given `LeRobotDataset` would appear as follows:
|
||||
|
||||
```bash
|
||||
lerobot/svla_so101_pickplace
|
||||
├── data/
|
||||
│ └── chunk-000/
|
||||
│ ├── file_000000.parquet
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── episodes/
|
||||
│ │ ├── chunk-000/
|
||||
│ │ │ └── file_000000.parquet
|
||||
│ │ └── ...
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ └── tasks.jsonl
|
||||
└── videos/
|
||||
└── chunk-000/
|
||||
├── observation.images.wrist_camera/
|
||||
│ ├── file_000000.mp4
|
||||
│ └── ...
|
||||
└── ...
|
||||
```
|
||||
|
||||
- **`meta/info.json`**: This is the central metadata file. It contains the complete dataset schema, defining all features (e.g., `observation.state`, `action`), their shapes, and data types. It also stores crucial information like the dataset's frames-per-second (`fps`), codebase version, and the path templates used to locate data and video files.
|
||||
- **`meta/stats.json`**: This file stores aggregated statistics (mean, std, min, max) for each feature across the entire dataset. These are used for data normalization and are accessible via `dataset.meta.stats`.
|
||||
- **`meta/tasks.jsonl`**: Contains the mapping from natural language task descriptions to integer task indices, which are used for task-conditioned policy training.
|
||||
- **`meta/episodes/`**: This directory contains metadata about each individual episode, such as its length, corresponding task, and pointers to where its data is stored. For scalability, this information is stored in chunked Parquet files rather than a single large JSON file.
|
||||
- **`data/`**: Contains the core frame-by-frame tabular data in Parquet files. To improve performance and handle large datasets, data from **multiple episodes are concatenated into larger files**. These files are organized into chunked subdirectories to keep file sizes manageable. Therefore, a single file typically contains data for more than one episode.
|
||||
- **`videos/`**: Contains the MP4 video files for all visual observation streams. Similar to the `data/` directory, video footage from **multiple episodes is concatenated into single MP4 files**. This strategy significantly reduces the number of files in the dataset, which is more efficient for modern filesystems. The path structure (`/videos/<camera_key>/<chunk>/file_...mp4`) allows the data loader to locate the correct video file and then seek to the precise timestamp for a given frame.
|
||||
|
||||
## Code Example: Using `LeRobotDataset` with `torch.utils.data.DataLoader`
|
||||
|
||||
This section provides an overview of how to access datasets hosted on Hugging Face using the `LeRobotDataset` class.
|
||||
Every dataset on the Hugging Face Hub containing the three main pillars presented above (Tabular and Visual Data, as well as relational Metadata) can be assessed with a single line.
|
||||
Most reinforcement learning (RL) and behavioral cloning (BC) algorithms tend to operate on stack of observation and actions.
|
||||
For instance, RL algorithms typically use a history of previous observations `[o_{t-H}, ..., o_{t}]` to mitigate partial observability.
|
||||
BC cloning algorithms are instead typically trained to regress chunks of multiple actions rather than single controls.
|
||||
To accommodate for the specifics of robot learning training, `LeRobotDataset` provides a native windowing operation, whereby we can use the _seconds_ before and after any given observation using `delta_timestamps`.
|
||||
Non available frames is opportuninely padded, with a padding mask released to provide support in this.
|
||||
Notably, this all happens within the `LeRobotDataset` and is entitrely transparent to higher level wrappers such as `torch.utils.data.DataLoader`.
|
||||
|
||||
Conveniently, by using `LeRobotDataset` with a Pytorch `DataLoader` one can automatically collate the individual sample dictionaries from the dataset into a single dictionary of batched tensors.
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Load from the Hugging Face Hub (will be cached locally)
|
||||
dataset = LeRobotDataset("lerobot/svla_so101_pickplace")
|
||||
|
||||
# Get the 100th frame in the dataset by
|
||||
sample = dataset[100]
|
||||
print(sample)
|
||||
# The sample is a dictionary of tensors
|
||||
# {
|
||||
# 'observation.state': tensor([...]),
|
||||
# 'action': tensor([...]),
|
||||
# 'observation.images.wrist_camera': tensor([C, H, W]),
|
||||
# 'timestamp': tensor(1.234),
|
||||
# ...
|
||||
# }
|
||||
delta_timestamps = {
|
||||
"observation.images.wrist_camera": [-0.2, -0.1, 0.0] # 0.2, and 0.1 seconds *before* any observation
|
||||
}
|
||||
dataset = LeRobotDataset(
|
||||
"lerobot/svla_so101_pickplace",
|
||||
delta_timestamps=delta_timestamps
|
||||
)
|
||||
|
||||
# Accessing an index now returns a stack of frames for the specified key
|
||||
sample = dataset[100]
|
||||
|
||||
# The image tensor will now have a time dimension
|
||||
# 'observation.images.wrist_camera' has shape [T, C, H, W], where T=3
|
||||
print(sample['observation.images.wrist_camera'].shape)
|
||||
|
||||
batch_size=16
|
||||
# wrap the dataset in a DataLoader to use process it batches for training purposes
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
# 3. Iterate over the DataLoader in a training loop
|
||||
num_epochs = 1
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for batch in data_loader:
|
||||
# 'batch' is a dictionary where each value is a batch of tensors.
|
||||
# For example, batch['action'] will have a shape of [32, action_dim].
|
||||
|
||||
# If using delta_timestamps, a batched image tensor might have a
|
||||
# shape of [32, T, C, H, W].
|
||||
|
||||
# Move data to the appropriate device (e.g., GPU)
|
||||
observations = batch['observation.state'].to(device)
|
||||
actions = batch['action'].to(device)
|
||||
images = batch['observation.images.wrist_camera'].to(device)
|
||||
|
||||
# Next do amazing_model.forward(batch)
|
||||
...
|
||||
```
|
||||
|
||||
## Streaming
|
||||
|
||||
`LeRobotDataset` now also supports streaming mode.
|
||||
You can stream of data from a large dataset hosted on the Hugging Face Hub by just replacing the dataset definition with:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
# Streams frames from the Hugging Face Hub
|
||||
dataset = StreamingLeRobotDataset("lerobot/svla_so101_pickplace")
|
||||
```
|
||||
|
||||
Streaming datasets supports high-performance batch processing (ca. 80-100 it/s, varying on connectivity) and high levels of frames randomization: a key feature for behavioral cloning algorithms otherwise operating on highly non-i.i.d. data.
|
||||
-347
@@ -1,347 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Script for Pi0 pretrained policy inference and Hub upload."""
|
||||
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Set seed
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Pi0 policy inference and Hub upload")
|
||||
parser.add_argument(
|
||||
"--source-model-id",
|
||||
type=str,
|
||||
default="pepijn223/pi0_libero_lerobot",
|
||||
help="Source model repository ID on Hugging Face Hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id", type=str, default="pepijn223/libero", help="Dataset repository ID on Hugging Face Hub"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-model-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output model repository ID to upload to (e.g., 'your-username/pi0-libero-fixed')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run inference on"
|
||||
)
|
||||
parser.add_argument("--episode", type=int, default=0, help="Episode index to load from dataset")
|
||||
parser.add_argument(
|
||||
"--sample-idx", type=int, default=10, help="Sample index within episode to use for inference"
|
||||
)
|
||||
parser.add_argument("--private", action="store_true", help="Make the uploaded model private")
|
||||
parser.add_argument(
|
||||
"--commit-message", type=str, default=None, help="Custom commit message for the upload"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _inject_normalization_stats(policy: PI0Policy, dataset_meta: LeRobotDatasetMetadata, key_mapping: dict):
|
||||
"""Recreate normalization layers with proper stats from the dataset."""
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
# Convert numpy stats to the format expected by normalization layers and remap keys
|
||||
stats = {}
|
||||
for dataset_key, stat_dict in dataset_meta.stats.items():
|
||||
# Use mapped key if available, otherwise use original key
|
||||
policy_key = key_mapping.get(dataset_key, dataset_key)
|
||||
|
||||
stats[policy_key] = {
|
||||
stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
|
||||
for stat_type, stat_array in stat_dict.items()
|
||||
}
|
||||
|
||||
print(f"Available stats keys: {list(stats.keys())}")
|
||||
print(
|
||||
f"Policy expects keys: input={list(policy.config.input_features.keys())}, output={list(policy.config.output_features.keys())}"
|
||||
)
|
||||
|
||||
# Recreate normalization layers with proper stats
|
||||
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
# Replace the normalization layers on the policy
|
||||
policy.normalize_inputs = normalize_inputs
|
||||
policy.normalize_targets = normalize_targets
|
||||
policy.unnormalize_outputs = unnormalize_outputs
|
||||
|
||||
print("Normalization layers recreated with dataset stats.")
|
||||
|
||||
|
||||
def configure_policy_features(policy: PI0Policy, dataset: LeRobotDataset):
|
||||
"""Configure policy input and output features based on dataset metadata."""
|
||||
print(f"Dataset features: {list(dataset.meta.features.keys())}")
|
||||
|
||||
# Create a proper mapping from dataset keys to policy keys
|
||||
dataset_to_policy_mapping = {}
|
||||
|
||||
# Handle images
|
||||
if "image" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["image"] = "observation.images.image"
|
||||
if "wrist_image" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["wrist_image"] = "observation.images.image2"
|
||||
|
||||
# Handle state
|
||||
if "state" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["state"] = "observation.state"
|
||||
|
||||
# Handle actions
|
||||
if "actions" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["actions"] = "action"
|
||||
|
||||
print(f"Key mapping: {dataset_to_policy_mapping}")
|
||||
|
||||
# Clear existing input features and reconfigure with proper mapping
|
||||
policy.config.input_features = {}
|
||||
policy.config.output_features = {}
|
||||
|
||||
# Map visual features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key in ["image", "wrist_image"]:
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
# Convert HWC to CHW format and resize
|
||||
shape = (3, 224, 224) # Pi0 expects CHW format
|
||||
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.VISUAL, shape=shape)
|
||||
|
||||
# Map state features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key == "state":
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
shape = tuple(feature_info["shape"])
|
||||
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.STATE, shape=shape)
|
||||
|
||||
# Map action features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key == "actions":
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
shape = tuple(feature_info["shape"])
|
||||
policy.config.output_features[policy_key] = PolicyFeature(type=FeatureType.ACTION, shape=shape)
|
||||
|
||||
print(f"Policy input_features: {list(policy.config.input_features.keys())}")
|
||||
print(f"Policy output_features: {list(policy.config.output_features.keys())}")
|
||||
print(f"Policy image_features: {list(policy.config.image_features.keys())}")
|
||||
print(f"Policy action_feature: {policy.config.action_feature}")
|
||||
|
||||
return dataset_to_policy_mapping
|
||||
|
||||
|
||||
def fix_buffer_naming(policy: PI0Policy):
|
||||
"""Fix buffer naming issues in the loaded policy state dict."""
|
||||
print("Fixing normalization buffer naming issues...")
|
||||
|
||||
state_dict = policy.state_dict()
|
||||
corrected_state_dict = {}
|
||||
fixes_applied = 0
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
|
||||
# Fix buffer naming: buffer_observation_state_mean -> buffer_observation_state.mean
|
||||
if "buffer_observation_state_mean" in key:
|
||||
new_key = key.replace("buffer_observation_state_mean", "buffer_observation_state.mean")
|
||||
fixes_applied += 1
|
||||
print(f" Fixed: {key} -> {new_key}")
|
||||
elif "buffer_observation_state_std" in key:
|
||||
new_key = key.replace("buffer_observation_state_std", "buffer_observation_state.std")
|
||||
fixes_applied += 1
|
||||
print(f" Fixed: {key} -> {new_key}")
|
||||
# Remove image buffers that aren't expected (they cause conflicts)
|
||||
elif "buffer_observation_image_mean" in key or "buffer_observation_image_std" in key:
|
||||
print(f" Removed unexpected buffer: {key}")
|
||||
continue # Skip this buffer
|
||||
|
||||
corrected_state_dict[new_key] = value
|
||||
|
||||
# Add missing action buffers with dummy values (will be replaced by dataset stats)
|
||||
missing_buffers = [
|
||||
"normalize_targets.buffer_action.mean",
|
||||
"normalize_targets.buffer_action.std",
|
||||
"unnormalize_outputs.buffer_action.mean",
|
||||
"unnormalize_outputs.buffer_action.std",
|
||||
]
|
||||
|
||||
for buffer_key in missing_buffers:
|
||||
if buffer_key not in corrected_state_dict:
|
||||
# Use dummy values - these will be overwritten by proper dataset stats later
|
||||
if "mean" in buffer_key:
|
||||
corrected_state_dict[buffer_key] = torch.zeros(8) # Assume 8-dim action
|
||||
else: # std
|
||||
corrected_state_dict[buffer_key] = torch.ones(8) # Assume 8-dim action
|
||||
fixes_applied += 1
|
||||
print(f" Added missing buffer: {buffer_key}")
|
||||
|
||||
print(f"Applied {fixes_applied} buffer fixes")
|
||||
|
||||
# Load the corrected state dict back into the policy
|
||||
policy.load_state_dict(corrected_state_dict)
|
||||
return policy
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the Pi0 inference and upload."""
|
||||
args = parse_args()
|
||||
|
||||
# Load pretrained Pi0 model directly from Hugging Face Hub
|
||||
print(f"Loading pretrained Pi0 model from {args.source_model_id}...")
|
||||
|
||||
# Load with strict=False to allow missing/unexpected keys, then fix them manually
|
||||
policy = PI0Policy.from_pretrained(args.source_model_id, strict=False)
|
||||
policy = fix_buffer_naming(policy)
|
||||
policy.eval()
|
||||
policy.to(args.device)
|
||||
|
||||
# Load dataset and get a sample
|
||||
print(f"Loading dataset: {args.dataset_id}")
|
||||
dataset = LeRobotDataset(args.dataset_id, episodes=[args.episode])
|
||||
meta: LeRobotDatasetMetadata = dataset.meta
|
||||
sample = dataset[args.sample_idx]
|
||||
|
||||
# Configure policy features
|
||||
key_mapping = configure_policy_features(policy, dataset)
|
||||
|
||||
# Inject normalization stats with proper key mapping
|
||||
_inject_normalization_stats(policy, meta, key_mapping)
|
||||
|
||||
# Prepare batch for PI0 (handle temporal dimensions)
|
||||
batch = {}
|
||||
|
||||
# Map dataset sample keys to policy keys
|
||||
reverse_mapping = {v: k for k, v in key_mapping.items()}
|
||||
|
||||
for policy_key in policy.config.input_features:
|
||||
# Find the corresponding dataset key
|
||||
dataset_key = reverse_mapping.get(policy_key, policy_key)
|
||||
|
||||
if dataset_key in sample:
|
||||
data = sample[dataset_key]
|
||||
|
||||
# Handle image data: convert from HWC to CHW and normalize
|
||||
if policy_key.startswith("observation.images."):
|
||||
if data.dim() == 3 and data.shape[-1] == 3: # HWC format
|
||||
data = data.permute(2, 0, 1) # Convert to CHW
|
||||
# Normalize to [0, 1] range if needed
|
||||
if data.dtype == torch.uint8:
|
||||
data = data.float() / 255.0
|
||||
# Resize to expected size if needed
|
||||
if data.shape[-2:] != (224, 224):
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
data = F.interpolate(
|
||||
data.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
|
||||
)[0]
|
||||
|
||||
# Remove temporal dimension if present
|
||||
if data.dim() > len(policy.config.input_features[policy_key].shape):
|
||||
data = data[0]
|
||||
|
||||
batch[policy_key] = data.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# Debug: print what's in the sample
|
||||
print(f"Sample keys: {list(sample.keys())}")
|
||||
print(f"Batch keys prepared: {list(batch.keys())}")
|
||||
|
||||
# Pi0 requires task description - add a default if not available
|
||||
if "task" in sample:
|
||||
batch["task"] = [sample["task"]] # Keep as list of strings
|
||||
else:
|
||||
print("No task in sample, using default task description")
|
||||
batch["task"] = ["Complete the manipulation task"]
|
||||
|
||||
print(f"Task: {batch['task'][0]}")
|
||||
print(f"Final batch keys: {list(batch.keys())}")
|
||||
|
||||
# Run inference
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"Predicted action shape: {action.shape}")
|
||||
print(f"Predicted action: {action.tolist()}")
|
||||
|
||||
print("✅ Pi0 pretrained inference completed successfully!")
|
||||
|
||||
# Upload to Hugging Face Hub
|
||||
print(f"\n📤 Uploading model to Hugging Face Hub: {args.output_model_id}")
|
||||
|
||||
# Create commit message
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
commit_message = (
|
||||
args.commit_message
|
||||
or f"Pi0 model with injected normalization stats from {args.dataset_id} - {timestamp}"
|
||||
)
|
||||
|
||||
# Update model configuration with dataset info
|
||||
policy.config.push_to_hub = True
|
||||
policy.config.repo_id = args.output_model_id
|
||||
policy.config.private = args.private
|
||||
|
||||
# Add metadata about the adaptation
|
||||
adaptation_info = {
|
||||
"source_model": args.source_model_id,
|
||||
"dataset_used": args.dataset_id,
|
||||
"adaptation_date": timestamp,
|
||||
"stats_injected": True,
|
||||
"key_mapping": key_mapping,
|
||||
"inference_test_passed": True,
|
||||
"sample_action_shape": list(action.shape),
|
||||
}
|
||||
|
||||
try:
|
||||
# Push to hub
|
||||
policy.push_to_hub(
|
||||
repo_id=args.output_model_id,
|
||||
private=args.private,
|
||||
commit_message=commit_message,
|
||||
create_pr=False,
|
||||
)
|
||||
|
||||
# Also save the adaptation info as a separate file
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# Create a temporary file with adaptation info
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(adaptation_info, f, indent=2)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
api.upload_file(
|
||||
path_or_fileobj=temp_path,
|
||||
path_in_repo="adaptation_info.json",
|
||||
repo_id=args.output_model_id,
|
||||
commit_message=f"Add adaptation metadata - {timestamp}",
|
||||
)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
print(f"✅ Model successfully uploaded to: https://huggingface.co/{args.output_model_id}")
|
||||
print("📋 Adaptation info:")
|
||||
for key, value in adaptation_info.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error uploading to Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
-704
@@ -1,704 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download # noqa: E402
|
||||
from safetensors.torch import load_file # noqa: E402
|
||||
from transformers.model_debugging_utils import model_addition_debugger_context
|
||||
|
||||
from lerobot.configs.policies import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
RANDOM_SEED = 42 # Set to fixed value for reproducible results
|
||||
|
||||
|
||||
def set_all_seeds(seed=42):
|
||||
"""Set all random seeds for reproducible results."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
torch.use_deterministic_algorithms(True)
|
||||
print(f"All random seeds set to {seed} for reproducible results (deterministic mode enabled)")
|
||||
|
||||
|
||||
# Set seeds at the start
|
||||
set_all_seeds(RANDOM_SEED)
|
||||
|
||||
config_model_path = "lerobot/pi0" # Use config from official model
|
||||
official_model_path = "lerobot/pi0" # Official model
|
||||
custom_model_path = "pepijn223/pi0_base_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
|
||||
device = "mps"
|
||||
|
||||
USE_FULL_TENSORS = True
|
||||
SAVE_TENSORS_TO_DISK = False
|
||||
|
||||
# Model transformation and upload settings
|
||||
SAVE_TRANSFORMED_MODEL = True # Set to True to save the transformed model
|
||||
UPLOAD_TO_HUB = True # Set to True to upload to HuggingFace Hub
|
||||
TRANSFORMED_MODEL_NAME = "pepijn223/pi0_base_fp32_lerobot_format" # Target repo name
|
||||
COMMIT_MESSAGE = "Add transformed PI0 model with correct key format for lerobot"
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
|
||||
os.makedirs(debug_path, exist_ok=True)
|
||||
print(f"Model debugging enabled - outputs will be saved to: {debug_path}")
|
||||
|
||||
# Download and load the config manually to avoid draccus parsing issues
|
||||
config_file = hf_hub_download(repo_id=config_model_path, filename="config.json")
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Remove the 'type' field that causes draccus issues
|
||||
if "type" in config_dict:
|
||||
config_dict.pop("type")
|
||||
print("Removed 'type' field from config")
|
||||
|
||||
# Create shared PI0Config
|
||||
print("Creating shared PI0Config...")
|
||||
shared_config = PI0Config(**config_dict)
|
||||
|
||||
|
||||
def load_policy_with_weights(
|
||||
model_path: str, config: PI0Config, model_name: str, apply_transformations: bool = False
|
||||
):
|
||||
"""Load a policy with specified weights but shared config."""
|
||||
print(f"\n=== Loading {model_name} from {model_path} ===")
|
||||
|
||||
# Set deterministic seed before creating the policy to ensure identical initialization
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
np.random.seed(RANDOM_SEED)
|
||||
random.seed(RANDOM_SEED)
|
||||
|
||||
policy = PI0Policy(config)
|
||||
|
||||
# Download and load weights
|
||||
model_file = hf_hub_download(repo_id=model_path, filename="model.safetensors")
|
||||
print(f"Downloaded {model_name} weights to: {model_file}")
|
||||
|
||||
# Load state dict and apply transformations
|
||||
print(f"Investigating safetensors file: {model_file}")
|
||||
|
||||
# First, check what's in the metadata
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
all_keys_in_file = f.keys()
|
||||
print(f" Total keys in safetensors file: {len(list(all_keys_in_file))}")
|
||||
|
||||
# Check for embed_tokens in the file keys
|
||||
embed_keys_in_file = [k for k in f.keys() if "embed_tokens" in k]
|
||||
print(f" embed_tokens keys in safetensors: {embed_keys_in_file}")
|
||||
|
||||
if metadata:
|
||||
print(f" Metadata exists: {list(metadata.keys()) if metadata else 'None'}")
|
||||
except Exception as e:
|
||||
print(f" Could not inspect safetensors file directly: {e}")
|
||||
|
||||
# Now load normally and see what we get
|
||||
state_dict = load_file(model_file)
|
||||
print(f" Keys loaded by load_file(): {len(state_dict)} keys")
|
||||
|
||||
# Check for embed_tokens in loaded state_dict
|
||||
loaded_embed_keys = [k for k in state_dict.keys() if "embed_tokens" in k]
|
||||
print(f" embed_tokens keys in loaded state_dict: {loaded_embed_keys}")
|
||||
|
||||
# Check if we need to add "model." prefix (for custom models that don't have it)
|
||||
sample_key = next(iter(state_dict.keys()))
|
||||
if not sample_key.startswith("model."):
|
||||
print(f"Adding 'model.' prefix to all keys (detected format: {sample_key})")
|
||||
state_dict = {f"model.{k}": v for k, v in state_dict.items()}
|
||||
|
||||
# IMPORTANT: Call PI0Policy._transform_state_dict_keys AFTER adding model. prefix
|
||||
# This ensures tied weights logic can find the correct key pattern
|
||||
transformed_state_dict = PI0Policy._transform_state_dict_keys(state_dict)
|
||||
|
||||
# Apply specific PaliGemma key transformations only for custom models
|
||||
if apply_transformations:
|
||||
print("Applying custom model key transformations...")
|
||||
|
||||
# First, let's debug what keys we actually have
|
||||
all_keys = list(transformed_state_dict.keys())
|
||||
sample_keys = all_keys[:10]
|
||||
print(f"Sample keys to transform: {sample_keys}")
|
||||
|
||||
# Look for specific keys we need to transform and missing keys
|
||||
embed_tokens_keys = [k for k in all_keys if "embed_tokens" in k]
|
||||
embedding_keys = [k for k in all_keys if "embed" in k]
|
||||
lm_head_keys = [k for k in all_keys if "lm_head" in k]
|
||||
paligemma_keys = [
|
||||
k for k in all_keys if "paligemma_with_expert.paligemma" in k and "gemma_expert" not in k
|
||||
]
|
||||
language_model_keys = [k for k in all_keys if "language_model" in k]
|
||||
|
||||
print(f"Found embed_tokens keys: {embed_tokens_keys}")
|
||||
print(f"Found any embedding keys: {embedding_keys}")
|
||||
print(f"Found lm_head keys: {lm_head_keys}")
|
||||
print(
|
||||
f"Found paligemma keys (non-expert): {paligemma_keys[:5]}{'...' if len(paligemma_keys) > 5 else ''}"
|
||||
)
|
||||
print(
|
||||
f"Found language_model keys: {language_model_keys[:5]}{'...' if len(language_model_keys) > 5 else ''}"
|
||||
)
|
||||
print(f"Total keys in model: {len(all_keys)}")
|
||||
|
||||
# Check if the embed_tokens is in gemma_expert instead
|
||||
gemma_expert_embed = [k for k in all_keys if "gemma_expert" in k and "embed_tokens" in k]
|
||||
print(f"Found gemma_expert embed_tokens keys: {gemma_expert_embed}")
|
||||
|
||||
# Check what we're missing and what we actually have
|
||||
expected_embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
if expected_embed_key not in all_keys:
|
||||
print(f" Missing expected embed_tokens key: {expected_embed_key}")
|
||||
|
||||
# Let's see what keys we actually have for debugging
|
||||
print("Debugging: Looking for any embedding-related keys...")
|
||||
all_embed_related = [k for k in all_keys if "embed" in k.lower()]
|
||||
print(f"Keys containing 'embed': {all_embed_related}")
|
||||
|
||||
# Look for any keys that might contain embeddings
|
||||
potential_embed_keys = [
|
||||
k for k in all_keys if any(word in k for word in ["embed", "embedding", "token"])
|
||||
]
|
||||
print(f" Potential embedding keys: {potential_embed_keys}")
|
||||
|
||||
# Try to find a suitable replacement
|
||||
if gemma_expert_embed:
|
||||
print(f" Will try to copy from: {gemma_expert_embed[0]}")
|
||||
else:
|
||||
print(" No gemma_expert embed_tokens found either!")
|
||||
|
||||
# Check if there's an embed_tokens in the gemma_expert that we missed
|
||||
gemma_keys = [k for k in all_keys if "gemma_expert" in k]
|
||||
print(f" First 10 gemma_expert keys: {gemma_keys[:10]}")
|
||||
|
||||
# Check if there are any token-related keys in gemma_expert
|
||||
token_keys = [k for k in all_keys if "gemma_expert" in k and "token" in k.lower()]
|
||||
print(f" Gemma expert token-related keys: {token_keys}")
|
||||
|
||||
# Check for any keys that look like they might be embeddings
|
||||
possible_embeds = [
|
||||
k
|
||||
for k in all_keys
|
||||
if any(
|
||||
pattern in k.lower() for pattern in ["embed_token", "embedding", "wte", "word_embed"]
|
||||
)
|
||||
]
|
||||
print(f" Possible embedding alternatives: {possible_embeds}")
|
||||
|
||||
final_state_dict = {}
|
||||
transformation_count = 0
|
||||
|
||||
for key, value in transformed_state_dict.items():
|
||||
new_key = key
|
||||
original_key = key
|
||||
|
||||
# Transform vision tower keys: ADD .model between paligemma and vision_tower
|
||||
if "paligemma_with_expert.paligemma.vision_tower.vision_model" in new_key:
|
||||
new_key = new_key.replace(
|
||||
"paligemma_with_expert.paligemma.vision_tower.vision_model",
|
||||
"paligemma_with_expert.paligemma.model.vision_tower.vision_model",
|
||||
)
|
||||
print(f"Transformed vision key: {original_key} -> {new_key}")
|
||||
transformation_count += 1
|
||||
|
||||
# Transform multi_modal_projector keys: ADD .model between paligemma and multi_modal_projector
|
||||
elif "paligemma_with_expert.paligemma.multi_modal_projector" in new_key:
|
||||
new_key = new_key.replace(
|
||||
"paligemma_with_expert.paligemma.multi_modal_projector",
|
||||
"paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
)
|
||||
print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}")
|
||||
transformation_count += 1
|
||||
|
||||
# NO transformation needed for language_model keys - they're already correct!
|
||||
# The custom model already has: paligemma.model.language_model.* which is what we need
|
||||
|
||||
# NO transformation needed for lm_head - it should stay as paligemma.lm_head
|
||||
|
||||
final_state_dict[new_key] = value
|
||||
|
||||
print(f"Applied {transformation_count} key transformations")
|
||||
transformed_state_dict = final_state_dict
|
||||
else:
|
||||
print("No transformations applied (official model format)")
|
||||
|
||||
# Debug: show what keys the policy expects vs what we have
|
||||
policy_keys = set(policy.state_dict().keys())
|
||||
provided_keys = set(transformed_state_dict.keys())
|
||||
|
||||
missing_in_provided = policy_keys - provided_keys
|
||||
extra_in_provided = provided_keys - policy_keys
|
||||
|
||||
print(f"Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
|
||||
if missing_in_provided:
|
||||
print(
|
||||
f" Missing from provided: {list(missing_in_provided)[:5]}{'...' if len(missing_in_provided) > 5 else ''}"
|
||||
)
|
||||
if extra_in_provided:
|
||||
print(
|
||||
f" Extra in provided: {list(extra_in_provided)[:5]}{'...' if len(extra_in_provided) > 5 else ''}"
|
||||
)
|
||||
|
||||
# Load the weights into the policy
|
||||
msg = policy.load_state_dict(transformed_state_dict, strict=True)
|
||||
print(
|
||||
f"{model_name} - Missing keys: {len(msg.missing_keys)}, Unexpected keys: {len(msg.unexpected_keys)}"
|
||||
)
|
||||
|
||||
if msg.missing_keys:
|
||||
print(
|
||||
f" Actually missing keys: {list(msg.missing_keys)[:3]}{'...' if len(msg.missing_keys) > 3 else ''}"
|
||||
)
|
||||
if msg.unexpected_keys:
|
||||
print(
|
||||
f" Actually unexpected keys: {list(msg.unexpected_keys)[:3]}{'...' if len(msg.unexpected_keys) > 3 else ''}"
|
||||
)
|
||||
|
||||
# Set deterministic mode and move to device
|
||||
policy = policy.to(device)
|
||||
policy.eval()
|
||||
|
||||
# Reset the policy to ensure identical internal state
|
||||
policy.reset()
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
# Load both models with shared config
|
||||
print("Loading both models with shared config...")
|
||||
official_policy = load_policy_with_weights(
|
||||
official_model_path, shared_config, "Official Model", apply_transformations=False
|
||||
)
|
||||
custom_policy = load_policy_with_weights(
|
||||
custom_model_path, shared_config, "Custom Model", apply_transformations=True
|
||||
)
|
||||
|
||||
print("\nBoth models loaded successfully!")
|
||||
print(f"Shared config: {shared_config}")
|
||||
print(f"Device: {device}")
|
||||
|
||||
|
||||
# Configure input features for both policies since they're not set by default in pretrained models
|
||||
def configure_policy_features(policy: PI0Policy):
|
||||
"""Configure input and output features for a policy."""
|
||||
policy.config.input_features[OBS_IMAGE] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Channel-first RGB image
|
||||
)
|
||||
|
||||
policy.config.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(8,), # 8-dimensional state vector
|
||||
)
|
||||
|
||||
policy.config.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(8,), # 8-dimensional action vector
|
||||
)
|
||||
|
||||
# Add dummy normalization buffers to the policy (like openpi does with norm_stats)
|
||||
if hasattr(policy, "normalize_inputs"):
|
||||
# For observation.state (8-dim state vector)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_STATE.replace('.', '_')}_mean", torch.zeros(8, device=device)
|
||||
)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_STATE.replace('.', '_')}_std", torch.ones(8, device=device)
|
||||
)
|
||||
|
||||
# For observation.image (3x224x224 image)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_IMAGE.replace('.', '_')}_mean", torch.zeros(3, 224, 224, device=device)
|
||||
)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_IMAGE.replace('.', '_')}_std", torch.ones(3, 224, 224, device=device)
|
||||
)
|
||||
|
||||
|
||||
print("Configuring features for both policies...")
|
||||
configure_policy_features(official_policy)
|
||||
configure_policy_features(custom_policy)
|
||||
|
||||
# Verify that the models have identical parameters
|
||||
print("\n=== Model Parameter Comparison ===")
|
||||
official_params = dict(official_policy.named_parameters())
|
||||
custom_params = dict(custom_policy.named_parameters())
|
||||
|
||||
param_differences = []
|
||||
for name in official_params.keys():
|
||||
if name not in custom_params:
|
||||
param_differences.append(f"Missing parameter in custom model: {name}")
|
||||
else:
|
||||
diff = torch.abs(official_params[name] - custom_params[name]).max().item()
|
||||
if diff > 1e-8:
|
||||
param_differences.append(f"Parameter {name}: max difference = {diff:.2e}")
|
||||
|
||||
for name in custom_params.keys():
|
||||
if name not in official_params:
|
||||
param_differences.append(f"Extra parameter in custom model: {name}")
|
||||
|
||||
if param_differences:
|
||||
print("Parameter differences found:")
|
||||
for diff in param_differences[:10]: # Show first 10 differences
|
||||
print(f" {diff}")
|
||||
if len(param_differences) > 10:
|
||||
print(f" ... and {len(param_differences) - 10} more differences")
|
||||
else:
|
||||
print("All model parameters are identical!")
|
||||
|
||||
|
||||
# Get the raw models for direct comparison
|
||||
official_raw_model = official_policy.model
|
||||
custom_raw_model = custom_policy.model
|
||||
print("\n=== Model Details ===")
|
||||
print(f"Official raw model type: {type(official_raw_model)}")
|
||||
print(f"Custom raw model type: {type(custom_raw_model)}")
|
||||
print(f"Official model device: {next(official_raw_model.parameters()).device}")
|
||||
print(f"Custom model device: {next(custom_raw_model.parameters()).device}")
|
||||
|
||||
# Create lerobot-format input data (similar to DROID format from openpi example)
|
||||
example = {
|
||||
"joint_position": np.zeros(7, dtype=np.float32),
|
||||
"gripper_position": np.array([0.0], dtype=np.float32),
|
||||
"image": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
|
||||
"task": "pick up the object",
|
||||
}
|
||||
|
||||
print(f"\nProvided input keys: {list(example.keys())}")
|
||||
|
||||
print("\nPreparing inputs for direct model call...")
|
||||
|
||||
# Apply input transformation (similar to openpi's policy._input_transform)
|
||||
transformed_example = {}
|
||||
# Combine joint and gripper positions into state
|
||||
transformed_example[OBS_STATE] = np.concatenate([example["joint_position"], example["gripper_position"]])
|
||||
transformed_example[OBS_IMAGE] = example["image"]
|
||||
transformed_example["task"] = example["task"]
|
||||
|
||||
# Convert to PyTorch tensors and add batch dimension (as openpi example does)
|
||||
# Device is already defined above, use the official model device for consistency
|
||||
pytorch_inputs = {}
|
||||
for key, value in transformed_example.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
tensor_value = torch.from_numpy(value).to(device)
|
||||
# Add batch dimension
|
||||
if tensor_value.dim() > 0:
|
||||
tensor_value = tensor_value.unsqueeze(0)
|
||||
pytorch_inputs[key] = tensor_value
|
||||
elif isinstance(value, str):
|
||||
pytorch_inputs[key] = [value] # Convert to list format expected by policy
|
||||
else:
|
||||
pytorch_inputs[key] = value
|
||||
|
||||
# Convert image from HWC to CHW format for lerobot
|
||||
if OBS_IMAGE in pytorch_inputs:
|
||||
img = pytorch_inputs[OBS_IMAGE]
|
||||
if img.dim() == 4 and img.shape[-1] == 3: # BHWC -> BCHW
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
# Convert to float and normalize to [0, 1] range
|
||||
img = img.float() / 255.0
|
||||
pytorch_inputs[OBS_IMAGE] = img
|
||||
|
||||
print(f"Transformed input keys: {list(pytorch_inputs.keys())}")
|
||||
for key, value in pytorch_inputs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
print(f" {key}: {value.shape} {value.dtype}")
|
||||
else:
|
||||
print(f" {key}: {type(value)} - {value}")
|
||||
|
||||
# Reset both policies (clears the action queue)
|
||||
official_policy.reset()
|
||||
custom_policy.reset()
|
||||
|
||||
# Prepare inputs using the official policy (both models should have same preprocessing)
|
||||
print("Preparing inputs for both models...")
|
||||
images, img_masks = official_policy.prepare_images(pytorch_inputs)
|
||||
lang_tokens, lang_masks = official_policy.prepare_language(pytorch_inputs)
|
||||
state = official_policy.prepare_state(pytorch_inputs)
|
||||
|
||||
print("Prepared inputs:")
|
||||
print(f" Images: {len(images)} images")
|
||||
print(f" Language tokens shape: {lang_tokens.shape}")
|
||||
print(f" State shape: {state.shape}")
|
||||
for i, img in enumerate(images):
|
||||
print(f" Image {i} shape: {img.shape}")
|
||||
for i, mask in enumerate(img_masks):
|
||||
print(f" Image mask {i} shape: {mask.shape}")
|
||||
|
||||
# Compare both models with identical inputs
|
||||
print("\n🚀 Running MODEL COMPARISON...")
|
||||
|
||||
# Force torch.no_grad for consistent comparison
|
||||
with torch.no_grad():
|
||||
# Ensure reproducible noise generation for both models
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
|
||||
# Generate synthetic noise and time for the forward call
|
||||
batch_size = 1
|
||||
actions_shape = (
|
||||
batch_size,
|
||||
official_raw_model.config.n_action_steps,
|
||||
official_raw_model.config.max_action_dim,
|
||||
)
|
||||
|
||||
# Generate noise and time using direct PyTorch operations instead of model methods
|
||||
# This avoids any potential model-specific randomness
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=actions_shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Generate time using the same distribution as PI0FlowMatching.sample_time
|
||||
torch.manual_seed(RANDOM_SEED) # Reset for consistent time
|
||||
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||
time_beta = beta_dist.sample((batch_size,)).to(device=device, dtype=torch.float32)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
|
||||
print("\n=== Generated Inputs ===")
|
||||
print(f" Action shape: {actions_shape}")
|
||||
print(f" Noise shape: {noise.shape}")
|
||||
print(f" Time value: {time.item():.6f}")
|
||||
print(f" Noise sample (first 5 values): {noise.flatten()[:5].tolist()}")
|
||||
|
||||
# Create dummy actions for forward pass (required for training forward)
|
||||
dummy_actions = torch.zeros(actions_shape, dtype=torch.float32, device=device)
|
||||
|
||||
print("\n=== Running Forward Passes ===")
|
||||
|
||||
print("Running with model_addition_debugger_context for detailed analysis...")
|
||||
# Create separate debug paths for each model
|
||||
official_debug_path = os.path.join(debug_path, "official_model")
|
||||
custom_debug_path = os.path.join(debug_path, "custom_model")
|
||||
os.makedirs(official_debug_path, exist_ok=True)
|
||||
os.makedirs(custom_debug_path, exist_ok=True)
|
||||
# Set deterministic mode for forward pass
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
# Run official model with debugger
|
||||
print("Running Official Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
official_raw_model,
|
||||
debug_path=official_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
official_loss = official_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
lang_tokens=lang_tokens,
|
||||
lang_masks=lang_masks,
|
||||
state=state,
|
||||
actions=dummy_actions,
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
# Reset seed before second forward pass to ensure any internal randomness is identical
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
# Run custom model with debugger
|
||||
print("Running Custom Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
custom_raw_model,
|
||||
debug_path=custom_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
custom_loss = custom_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
lang_tokens=lang_tokens,
|
||||
lang_masks=lang_masks,
|
||||
state=state,
|
||||
actions=dummy_actions,
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
|
||||
print(f"Official model debug outputs saved to: {official_debug_path}")
|
||||
print(f"Custom model debug outputs saved to: {custom_debug_path}")
|
||||
|
||||
print("\n=== Output Comparison ===")
|
||||
print(f"Official model loss shape: {official_loss.shape}")
|
||||
print(f"Custom model loss shape: {custom_loss.shape}")
|
||||
|
||||
# Compare outputs
|
||||
loss_diff = torch.abs(official_loss - custom_loss)
|
||||
|
||||
print("\n=== Detailed Comparison ===")
|
||||
print("Loss difference stats:")
|
||||
print(f" Mean absolute difference: {loss_diff.mean().item():.8f}")
|
||||
print(f" Max absolute difference: {loss_diff.max().item():.8f}")
|
||||
print(f" Min absolute difference: {loss_diff.min().item():.8f}")
|
||||
print(f" Standard deviation of difference: {loss_diff.std().item():.8f}")
|
||||
|
||||
# Show some actual values for comparison
|
||||
print("\nSample output values:")
|
||||
print(f" Official model (first 5): {official_loss.flatten()[:5].tolist()}")
|
||||
print(f" Custom model (first 5): {custom_loss.flatten()[:5].tolist()}")
|
||||
print(f" Difference (first 5): {loss_diff.flatten()[:5].tolist()}")
|
||||
|
||||
# Determine if models are equivalent
|
||||
are_equivalent = loss_diff.max().item() < 1e-6
|
||||
print(f"\nModels are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
|
||||
print(f" (Max difference: {loss_diff.max().item():.8f}, Threshold: 1e-6)")
|
||||
|
||||
print(f"\nDetailed debugging outputs saved to: {debug_path}")
|
||||
# Save comparison results
|
||||
comparison_results = {
|
||||
"official_loss_stats": {
|
||||
"shape": list(official_loss.shape),
|
||||
"mean": official_loss.mean().item(),
|
||||
"std": official_loss.std().item(),
|
||||
"min": official_loss.min().item(),
|
||||
"max": official_loss.max().item(),
|
||||
},
|
||||
"custom_loss_stats": {
|
||||
"shape": list(custom_loss.shape),
|
||||
"mean": custom_loss.mean().item(),
|
||||
"std": custom_loss.std().item(),
|
||||
"min": custom_loss.min().item(),
|
||||
"max": custom_loss.max().item(),
|
||||
},
|
||||
"difference_stats": {
|
||||
"mean_abs_diff": loss_diff.mean().item(),
|
||||
"max_abs_diff": loss_diff.max().item(),
|
||||
"min_abs_diff": loss_diff.min().item(),
|
||||
"std_diff": loss_diff.std().item(),
|
||||
"are_equivalent": are_equivalent,
|
||||
},
|
||||
}
|
||||
|
||||
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
|
||||
with open(comparison_file, "w") as f:
|
||||
json.dump(comparison_results, f, indent=2)
|
||||
print(f" Comparison results saved to: {comparison_file}")
|
||||
|
||||
# Save and upload transformed model if requested
|
||||
if SAVE_TRANSFORMED_MODEL:
|
||||
print("\nSaving Transformed Model...")
|
||||
if are_equivalent:
|
||||
print("Models are equivalent - proceeding with transformation and upload")
|
||||
else:
|
||||
print("Models are NOT equivalent, but proceeding with upload anyway")
|
||||
print(f" Max difference: {loss_diff.max().item():.2e}")
|
||||
print(" This might be useful for debugging or partial transformations")
|
||||
|
||||
# Create timestamp for README
|
||||
transformation_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
try:
|
||||
# Use the already working custom policy as the base for transformation
|
||||
print("Using already working custom policy as base for transformed model...")
|
||||
|
||||
# Deep copy the custom policy to create the transformed version
|
||||
from copy import deepcopy
|
||||
|
||||
transformed_policy = deepcopy(custom_policy)
|
||||
|
||||
print("Custom policy copied successfully - no additional configuration needed")
|
||||
|
||||
# Save locally first
|
||||
local_save_path = "./transformed_pi0_model"
|
||||
print(f"Saving transformed model locally to: {local_save_path}")
|
||||
transformed_policy.save_pretrained(local_save_path, safe_serialization=True)
|
||||
|
||||
# Save the tokenizer as well (required for complete model)
|
||||
transformed_policy.language_tokenizer.save_pretrained(local_save_path)
|
||||
|
||||
# Create a README with transformation details
|
||||
readme_content = f"""
|
||||
# PI0 Model - LeRobot Compatible Format
|
||||
|
||||
This model is a transformed version of `{custom_model_path}` with key names corrected to match the official LeRobot PI0 format.
|
||||
|
||||
## Transformation Applied
|
||||
|
||||
The original model had a different key naming convention. This model applies the following transformations:
|
||||
|
||||
1. **Model prefix**: Added `model.` prefix to all parameter keys
|
||||
2. **Tied weights**: Applied PI0Policy's built-in tied weights logic to create `embed_tokens.weight` from `lm_head.weight`
|
||||
3. **Key structure**: Applied standard PI0 key transformations for compatibility
|
||||
|
||||
## Verification
|
||||
|
||||
{"This transformed model produces **identical outputs**" if are_equivalent else "This transformed model has **slightly different outputs**"} (max difference = {loss_diff.max().item():.2e}) compared to the official model `{official_model_path}` when tested with the same inputs.
|
||||
{"**Models are EQUIVALENT** (difference < 1e-6)" if are_equivalent else "**Models are NOT equivalent** (difference >= 1e-6) - use with caution"}
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Load the model
|
||||
policy = PI0Policy.from_pretrained("{TRANSFORMED_MODEL_NAME}")
|
||||
|
||||
# Use for inference
|
||||
action = policy.select_action(observation_batch)
|
||||
```
|
||||
|
||||
## Original Model
|
||||
|
||||
- **Source**: {custom_model_path}
|
||||
- **Verified Against**: {official_model_path}
|
||||
|
||||
## Technical Details
|
||||
|
||||
- **Total Parameters**: {sum(p.numel() for p in transformed_policy.parameters()):,}
|
||||
- **Model Type**: PI0FlowMatching with PaliGemma + Expert Gemma
|
||||
- **Configuration**: Matches official PI0 configuration
|
||||
"""
|
||||
|
||||
readme_path = os.path.join(local_save_path, "README.md")
|
||||
with open(readme_path, "w") as f:
|
||||
f.write(readme_content.strip())
|
||||
|
||||
print(f"Model saved locally to: {local_save_path}")
|
||||
|
||||
# Upload to HuggingFace Hub if requested
|
||||
if UPLOAD_TO_HUB:
|
||||
print(f"\nUploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
|
||||
|
||||
try:
|
||||
# Push to hub
|
||||
transformed_policy.push_to_hub(
|
||||
repo_id=TRANSFORMED_MODEL_NAME,
|
||||
commit_message=COMMIT_MESSAGE,
|
||||
private=False, # Make it public
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
|
||||
print("You can now use this model directly without any transformations!")
|
||||
print("\n Usage:")
|
||||
print(" from lerobot.policies.pi0.modeling_pi0 import PI0Policy")
|
||||
print(f" policy = PI0Policy.from_pretrained('{TRANSFORMED_MODEL_NAME}')")
|
||||
|
||||
except Exception as upload_error:
|
||||
print(f"Failed to upload to HuggingFace Hub: {upload_error}")
|
||||
print(f"You can manually upload the model from: {local_save_path}")
|
||||
print(" Or set UPLOAD_TO_HUB = False and upload later")
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(f"Error saving transformed model: {str(e)}")
|
||||
print("Full traceback:")
|
||||
traceback.print_exc()
|
||||
print("The model transformation logic works, but saving failed")
|
||||
|
||||
else:
|
||||
print("\nModel transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")
|
||||
@@ -13,22 +13,20 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you download any LeRobot dataset from the hub, convert it to the latest format, and
|
||||
upload it to your own repository. It will:
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
|
||||
- Download the dataset from any source repository
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Update codebase_version in `info.json` to the latest version
|
||||
- Create proper version tags
|
||||
- Push the converted dataset to your specified destination repository
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
|
||||
--source-repo-id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
|
||||
--dest-repo-id=your-username/libero_spatial_converted \
|
||||
--episodes=0,1,2,3,4
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -39,8 +37,8 @@ import logging
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, write_info
|
||||
from lerobot.datasets.v21.convert_stats import convert_stats
|
||||
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
@@ -56,133 +54,48 @@ class SuppressWarnings:
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
source_repo_id: str,
|
||||
dest_repo_id: str | None = None,
|
||||
episodes: str | None = None,
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
force_cache_sync: bool = True,
|
||||
):
|
||||
"""
|
||||
Download a dataset from source_repo_id, convert it, and upload to dest_repo_id.
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
Args:
|
||||
source_repo_id: Source repository to download from
|
||||
dest_repo_id: Destination repository to upload to (defaults to source_repo_id)
|
||||
episodes: Comma-separated list of episode indices to include (e.g. "0,1,2,3")
|
||||
branch: Branch to upload to
|
||||
num_workers: Number of workers for stats computation
|
||||
force_cache_sync: Whether to force cache synchronization
|
||||
"""
|
||||
if dest_repo_id is None:
|
||||
dest_repo_id = source_repo_id
|
||||
|
||||
# Parse episodes list if provided
|
||||
episode_list = None
|
||||
if episodes:
|
||||
try:
|
||||
episode_list = [int(ep.strip()) for ep in episodes.split(",")]
|
||||
print(f"Loading episodes: {episode_list}")
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid episodes format '{episodes}'. Use comma-separated integers like '0,1,2,3'"
|
||||
) from e
|
||||
|
||||
print(f"Downloading dataset from: {source_repo_id}")
|
||||
|
||||
# Try to load the dataset with different approaches to handle versioning issues
|
||||
dataset = None
|
||||
load_attempts = [
|
||||
{"revision": None}, # Try latest first
|
||||
{"revision": V20}, # Try v2.0
|
||||
{"revision": "main"}, # Try main branch
|
||||
]
|
||||
|
||||
for attempt in load_attempts:
|
||||
try:
|
||||
print(f"Attempting to load with revision: {attempt['revision']}")
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(
|
||||
source_repo_id, episodes=episode_list, force_cache_sync=force_cache_sync, **attempt
|
||||
)
|
||||
print("Successfully loaded dataset!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Failed with revision {attempt['revision']}: {e}")
|
||||
continue
|
||||
|
||||
if dataset is None:
|
||||
raise RuntimeError(f"Could not load dataset {source_repo_id} with any revision")
|
||||
|
||||
# Clean up old stats if present
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
print("Removed existing episodes_stats.jsonl")
|
||||
|
||||
print("Converting stats to new format...")
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
# Update dataset info
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
print(f"Updated codebase_version to {CODEBASE_VERSION}")
|
||||
|
||||
# Change repo_id for destination if different
|
||||
if dest_repo_id != source_repo_id:
|
||||
print(f"Changing repository from {source_repo_id} to {dest_repo_id}")
|
||||
dataset.repo_id = dest_repo_id
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
print(f"Pushing converted dataset to: {dest_repo_id}")
|
||||
dataset.push_to_hub(branch=branch, tag_version=False)
|
||||
|
||||
# Clean up old stats.json file locally and on hub
|
||||
if (dataset.root / STATS_PATH).is_file():
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
print("Removed local stats.json file")
|
||||
|
||||
hub_api = HfApi()
|
||||
try:
|
||||
if hub_api.file_exists(
|
||||
repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
print("Removed stats.json from hub")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remove stats.json from hub: {e}")
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
# Create version tag
|
||||
try:
|
||||
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
print(f"Created tag {CODEBASE_VERSION} for {dest_repo_id}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not create tag: {e}")
|
||||
|
||||
print(f"✅ Successfully converted and uploaded dataset to {dest_repo_id}")
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download, convert, and re-upload LeRobot datasets with proper versioning"
|
||||
)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--source-repo-id",
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Source repository identifier to download from (e.g. 'IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dest-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Destination repository identifier to upload to. Defaults to source-repo-id if not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated list of episode indices to include (e.g. '0,1,2,3,4'). If not specified, all episodes are included.",
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
@@ -196,22 +109,6 @@ if __name__ == "__main__":
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-cache-sync",
|
||||
action="store_true",
|
||||
help="Skip forcing cache synchronization (faster but may use cached data)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert args to match function signature
|
||||
convert_args = {
|
||||
"source_repo_id": args.source_repo_id,
|
||||
"dest_repo_id": args.dest_repo_id,
|
||||
"episodes": args.episodes,
|
||||
"branch": args.branch,
|
||||
"num_workers": args.num_workers,
|
||||
"force_cache_sync": not args.no_cache_sync,
|
||||
}
|
||||
|
||||
convert_dataset(**convert_args)
|
||||
convert_dataset(**vars(args))
|
||||
|
||||
Reference in New Issue
Block a user