mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 23:29:52 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c75df5c3b9 | |||
| e2740fe555 |
@@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env python
|
||||
"""Simple script to check buffer naming in the transformed model."""
|
||||
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Load the model with strict=False to see what buffers we have
|
||||
print("Loading model...")
|
||||
policy = PI0Policy.from_pretrained("pepijn223/pi0_libero_lerobot", strict=False)
|
||||
|
||||
# Check what buffer keys exist
|
||||
state_dict = policy.state_dict()
|
||||
buffer_keys = [k for k in state_dict.keys() if "buffer" in k]
|
||||
normalize_keys = [k for k in state_dict.keys() if "normalize" in k]
|
||||
|
||||
print("\nAll buffer keys:")
|
||||
for key in buffer_keys:
|
||||
print(f" {key}")
|
||||
|
||||
print("\nAll normalize keys:")
|
||||
for key in normalize_keys:
|
||||
print(f" {key}")
|
||||
|
||||
print("\nAll keys (first 20):")
|
||||
for i, key in enumerate(state_dict.keys()):
|
||||
if i < 20:
|
||||
print(f" {key}")
|
||||
@@ -1,47 +0,0 @@
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# config_follower_right = ViperXConfig(
|
||||
# port="/dev/tty.usbserial-FT891KBG",
|
||||
# id="viperx_right",
|
||||
# )
|
||||
|
||||
# follower_right = ViperX(config_follower_right)
|
||||
# follower_right.connect(calibrate=False)
|
||||
# follower_right.calibrate()
|
||||
# follower_right.disconnect()
|
||||
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# config_leader_right = WidowXConfig(
|
||||
# port="/dev/tty.usbserial-FT89FM77",
|
||||
# id="widowx_right",
|
||||
# )
|
||||
|
||||
# leader_right = WidowX(config_leader_right)
|
||||
# leader_right.connect(calibrate=False)
|
||||
# leader_right.calibrate()
|
||||
# leader_right.disconnect()
|
||||
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# config_follower_left = ViperXConfig(
|
||||
# port="/dev/tty.usbserial-FT89FM09",
|
||||
# id="viperx_left",
|
||||
# )
|
||||
|
||||
# follower_left = ViperX(config_follower_left)
|
||||
# follower_left.connect(calibrate=False)
|
||||
# follower_left.calibrate()
|
||||
# follower_left.disconnect()
|
||||
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# config_leader_left = WidowXConfig(
|
||||
# port="/dev/tty.usbserial-FT891KPN",
|
||||
# id="widowx_left",
|
||||
# )
|
||||
|
||||
# leader_left = WidowX(config_leader_left)
|
||||
# leader_left.connect(calibrate=False)
|
||||
# leader_left.calibrate()
|
||||
# leader_left.disconnect()
|
||||
@@ -1,172 +0,0 @@
|
||||
"""
|
||||
ALOHA Bimanual Recording Script
|
||||
|
||||
This script records episodes using ALOHA dual-arm system (ViperX followers + WidowX leaders).
|
||||
|
||||
Usage:
|
||||
1. New dataset: Set RESUME = False
|
||||
2. Resume/append: Set RESUME = True (will continue from existing episodes)
|
||||
|
||||
The script will:
|
||||
- Record NUM_EPISODES new episodes
|
||||
- Show progress with total episode count
|
||||
- Push dataset to HuggingFace Hub when complete
|
||||
"""
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.aloha import Aloha, AlohaConfig
|
||||
from lerobot.teleoperators.aloha_teleop import AlohaTeleop, AlohaTeleopConfig
|
||||
from lerobot.utils.control_utils import (
|
||||
init_keyboard_listener,
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
# Recording configuration
|
||||
NUM_EPISODES = 0
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 200
|
||||
RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "First put the Hugging Face t shirt with both arms in the box, then place the hat with the right arm in the box."
|
||||
REPO_ID = "pepijn223/aloha_box_2"
|
||||
RESUME = True # Set to True to resume/append to existing dataset
|
||||
|
||||
# Create camera configuration
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS),
|
||||
"wrist_right": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=FPS),
|
||||
"wrist_left": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
# ALOHA Robot Configuration (dual ViperX followers)
|
||||
aloha_robot_config = AlohaConfig(
|
||||
id="aloha",
|
||||
left_arm_port="/dev/tty.usbserial-FT89FM09",
|
||||
right_arm_port="/dev/tty.usbserial-FT891KBG",
|
||||
left_arm_max_relative_target=20.0,
|
||||
right_arm_max_relative_target=20.0,
|
||||
left_arm_use_degrees=True,
|
||||
right_arm_use_degrees=True,
|
||||
cameras=camera_config,
|
||||
)
|
||||
|
||||
# ALOHA Teleoperator Configuration (dual WidowX leaders)
|
||||
aloha_teleop_config = AlohaTeleopConfig(
|
||||
id="aloha_teleop",
|
||||
left_arm_port="/dev/tty.usbserial-FT891KPN",
|
||||
right_arm_port="/dev/tty.usbserial-FT89FM77",
|
||||
left_arm_gripper_motor="xl430-w250",
|
||||
right_arm_gripper_motor="xc430-w150",
|
||||
left_arm_use_degrees=True,
|
||||
right_arm_use_degrees=True,
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = Aloha(aloha_robot_config)
|
||||
teleop = AlohaTeleop(aloha_teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create or resume the dataset
|
||||
if RESUME:
|
||||
print(f"Resuming existing dataset: {REPO_ID}")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=REPO_ID,
|
||||
root=None, # Use default root
|
||||
)
|
||||
|
||||
# Start image writer for cameras
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
num_processes=0, # Use threads only
|
||||
num_threads=4 * len(robot.cameras), # 4 threads per camera
|
||||
)
|
||||
|
||||
# Sanity check compatibility
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, FPS, dataset_features)
|
||||
print(f"Resumed dataset with {dataset.num_episodes} existing episodes")
|
||||
else:
|
||||
print(f"Creating new dataset: {REPO_ID}")
|
||||
# Sanity check dataset name
|
||||
sanity_check_dataset_name(REPO_ID, None)
|
||||
|
||||
# Create new dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=REPO_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4 * len(robot.cameras), # 4 threads per camera
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="aloha_recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
episode_idx = 0
|
||||
total_episodes_to_record = NUM_EPISODES
|
||||
existing_episodes = dataset.num_episodes if RESUME else 0
|
||||
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
current_episode = existing_episodes + episode_idx + 1
|
||||
log_say(f"Recording episode {current_episode} (batch: {episode_idx + 1}/{NUM_EPISODES})")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
|
||||
# Summary
|
||||
final_episodes = dataset.num_episodes
|
||||
log_say(f"Dataset now contains {final_episodes} episodes total")
|
||||
|
||||
# Push to hub
|
||||
dataset.push_to_hub()
|
||||
log_say(f"Dataset '{REPO_ID}' pushed to HuggingFace Hub")
|
||||
@@ -1,93 +0,0 @@
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.viperx import ViperX, ViperXConfig
|
||||
from lerobot.teleoperators.widowx import WidowX, WidowXConfig
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"wrist_right": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"wrist_left": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
config_follower_right = ViperXConfig(
|
||||
port="/dev/tty.usbserial-FT891KBG",
|
||||
id="viperx_right",
|
||||
max_relative_target=10.0, # increased from default 5.0 to 10.0
|
||||
use_degrees=True,
|
||||
cameras=camera_config,
|
||||
)
|
||||
|
||||
config_leader_right = WidowXConfig(
|
||||
port="/dev/tty.usbserial-FT89FM77",
|
||||
id="widowx_right",
|
||||
gripper_motor="xc430-w150",
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
config_follower_left = ViperXConfig(
|
||||
port="/dev/tty.usbserial-FT89FM09",
|
||||
id="viperx_left",
|
||||
max_relative_target=10.0, # increased from default 5.0 to 10.0
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
config_leader_left = WidowXConfig(
|
||||
port="/dev/tty.usbserial-FT891KPN",
|
||||
id="widowx_left",
|
||||
gripper_motor="xl430-w250",
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
_init_rerun(session_name="teleop")
|
||||
|
||||
follower_right = ViperX(config_follower_right)
|
||||
follower_right.connect()
|
||||
|
||||
leader_right = WidowX(config_leader_right)
|
||||
leader_right.connect()
|
||||
|
||||
follower_left = ViperX(config_follower_left)
|
||||
follower_left.connect()
|
||||
|
||||
leader_left = WidowX(config_leader_left)
|
||||
leader_left.connect()
|
||||
|
||||
|
||||
while True:
|
||||
act_right = leader_right.get_action()
|
||||
obs_right = follower_right.get_observation()
|
||||
|
||||
act_left = leader_left.get_action()
|
||||
obs_left = follower_left.get_observation()
|
||||
|
||||
print("=" * 60)
|
||||
print("ACTION (Leader Right):")
|
||||
for key, value in act_right.items():
|
||||
if key.endswith(".pos"):
|
||||
print(f" {key:20}: {value:8.3f}")
|
||||
|
||||
print("\nOBSERVATION (Follower Right):")
|
||||
for key, value in obs_right.items():
|
||||
if key.endswith(".pos"):
|
||||
print(f" {key:20}: {value:8.3f}")
|
||||
|
||||
print("=" * 60)
|
||||
print("ACTION (Leader Left):")
|
||||
for key, value in act_left.items():
|
||||
if key.endswith(".pos"):
|
||||
print(f" {key:20}: {value:8.3f}")
|
||||
|
||||
print("\nOBSERVATION (Follower Left):")
|
||||
for key, value in obs_left.items():
|
||||
if key.endswith(".pos"):
|
||||
print(f" {key:20}: {value:8.3f}")
|
||||
print("=" * 60)
|
||||
|
||||
log_rerun_data({**obs_right, **obs_left}, {**act_right, **act_left})
|
||||
|
||||
follower_right.send_action(act_right)
|
||||
follower_left.send_action(act_left)
|
||||
|
||||
time.sleep(0.02)
|
||||
+347
@@ -0,0 +1,347 @@
|
||||
#!/usr/bin/env python
|
||||
"""Script for Pi0 pretrained policy inference and Hub upload."""
|
||||
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Set seed
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Pi0 policy inference and Hub upload")
|
||||
parser.add_argument(
|
||||
"--source-model-id",
|
||||
type=str,
|
||||
default="pepijn223/pi0_libero_lerobot",
|
||||
help="Source model repository ID on Hugging Face Hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id", type=str, default="pepijn223/libero", help="Dataset repository ID on Hugging Face Hub"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-model-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output model repository ID to upload to (e.g., 'your-username/pi0-libero-fixed')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run inference on"
|
||||
)
|
||||
parser.add_argument("--episode", type=int, default=0, help="Episode index to load from dataset")
|
||||
parser.add_argument(
|
||||
"--sample-idx", type=int, default=10, help="Sample index within episode to use for inference"
|
||||
)
|
||||
parser.add_argument("--private", action="store_true", help="Make the uploaded model private")
|
||||
parser.add_argument(
|
||||
"--commit-message", type=str, default=None, help="Custom commit message for the upload"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _inject_normalization_stats(policy: PI0Policy, dataset_meta: LeRobotDatasetMetadata, key_mapping: dict):
|
||||
"""Recreate normalization layers with proper stats from the dataset."""
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
# Convert numpy stats to the format expected by normalization layers and remap keys
|
||||
stats = {}
|
||||
for dataset_key, stat_dict in dataset_meta.stats.items():
|
||||
# Use mapped key if available, otherwise use original key
|
||||
policy_key = key_mapping.get(dataset_key, dataset_key)
|
||||
|
||||
stats[policy_key] = {
|
||||
stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
|
||||
for stat_type, stat_array in stat_dict.items()
|
||||
}
|
||||
|
||||
print(f"Available stats keys: {list(stats.keys())}")
|
||||
print(
|
||||
f"Policy expects keys: input={list(policy.config.input_features.keys())}, output={list(policy.config.output_features.keys())}"
|
||||
)
|
||||
|
||||
# Recreate normalization layers with proper stats
|
||||
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
# Replace the normalization layers on the policy
|
||||
policy.normalize_inputs = normalize_inputs
|
||||
policy.normalize_targets = normalize_targets
|
||||
policy.unnormalize_outputs = unnormalize_outputs
|
||||
|
||||
print("Normalization layers recreated with dataset stats.")
|
||||
|
||||
|
||||
def configure_policy_features(policy: PI0Policy, dataset: LeRobotDataset):
|
||||
"""Configure policy input and output features based on dataset metadata."""
|
||||
print(f"Dataset features: {list(dataset.meta.features.keys())}")
|
||||
|
||||
# Create a proper mapping from dataset keys to policy keys
|
||||
dataset_to_policy_mapping = {}
|
||||
|
||||
# Handle images
|
||||
if "image" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["image"] = "observation.images.image"
|
||||
if "wrist_image" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["wrist_image"] = "observation.images.image2"
|
||||
|
||||
# Handle state
|
||||
if "state" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["state"] = "observation.state"
|
||||
|
||||
# Handle actions
|
||||
if "actions" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["actions"] = "action"
|
||||
|
||||
print(f"Key mapping: {dataset_to_policy_mapping}")
|
||||
|
||||
# Clear existing input features and reconfigure with proper mapping
|
||||
policy.config.input_features = {}
|
||||
policy.config.output_features = {}
|
||||
|
||||
# Map visual features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key in ["image", "wrist_image"]:
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
# Convert HWC to CHW format and resize
|
||||
shape = (3, 224, 224) # Pi0 expects CHW format
|
||||
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.VISUAL, shape=shape)
|
||||
|
||||
# Map state features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key == "state":
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
shape = tuple(feature_info["shape"])
|
||||
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.STATE, shape=shape)
|
||||
|
||||
# Map action features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key == "actions":
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
shape = tuple(feature_info["shape"])
|
||||
policy.config.output_features[policy_key] = PolicyFeature(type=FeatureType.ACTION, shape=shape)
|
||||
|
||||
print(f"Policy input_features: {list(policy.config.input_features.keys())}")
|
||||
print(f"Policy output_features: {list(policy.config.output_features.keys())}")
|
||||
print(f"Policy image_features: {list(policy.config.image_features.keys())}")
|
||||
print(f"Policy action_feature: {policy.config.action_feature}")
|
||||
|
||||
return dataset_to_policy_mapping
|
||||
|
||||
|
||||
def fix_buffer_naming(policy: PI0Policy):
|
||||
"""Fix buffer naming issues in the loaded policy state dict."""
|
||||
print("Fixing normalization buffer naming issues...")
|
||||
|
||||
state_dict = policy.state_dict()
|
||||
corrected_state_dict = {}
|
||||
fixes_applied = 0
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
|
||||
# Fix buffer naming: buffer_observation_state_mean -> buffer_observation_state.mean
|
||||
if "buffer_observation_state_mean" in key:
|
||||
new_key = key.replace("buffer_observation_state_mean", "buffer_observation_state.mean")
|
||||
fixes_applied += 1
|
||||
print(f" Fixed: {key} -> {new_key}")
|
||||
elif "buffer_observation_state_std" in key:
|
||||
new_key = key.replace("buffer_observation_state_std", "buffer_observation_state.std")
|
||||
fixes_applied += 1
|
||||
print(f" Fixed: {key} -> {new_key}")
|
||||
# Remove image buffers that aren't expected (they cause conflicts)
|
||||
elif "buffer_observation_image_mean" in key or "buffer_observation_image_std" in key:
|
||||
print(f" Removed unexpected buffer: {key}")
|
||||
continue # Skip this buffer
|
||||
|
||||
corrected_state_dict[new_key] = value
|
||||
|
||||
# Add missing action buffers with dummy values (will be replaced by dataset stats)
|
||||
missing_buffers = [
|
||||
"normalize_targets.buffer_action.mean",
|
||||
"normalize_targets.buffer_action.std",
|
||||
"unnormalize_outputs.buffer_action.mean",
|
||||
"unnormalize_outputs.buffer_action.std",
|
||||
]
|
||||
|
||||
for buffer_key in missing_buffers:
|
||||
if buffer_key not in corrected_state_dict:
|
||||
# Use dummy values - these will be overwritten by proper dataset stats later
|
||||
if "mean" in buffer_key:
|
||||
corrected_state_dict[buffer_key] = torch.zeros(8) # Assume 8-dim action
|
||||
else: # std
|
||||
corrected_state_dict[buffer_key] = torch.ones(8) # Assume 8-dim action
|
||||
fixes_applied += 1
|
||||
print(f" Added missing buffer: {buffer_key}")
|
||||
|
||||
print(f"Applied {fixes_applied} buffer fixes")
|
||||
|
||||
# Load the corrected state dict back into the policy
|
||||
policy.load_state_dict(corrected_state_dict)
|
||||
return policy
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the Pi0 inference and upload."""
|
||||
args = parse_args()
|
||||
|
||||
# Load pretrained Pi0 model directly from Hugging Face Hub
|
||||
print(f"Loading pretrained Pi0 model from {args.source_model_id}...")
|
||||
|
||||
# Load with strict=False to allow missing/unexpected keys, then fix them manually
|
||||
policy = PI0Policy.from_pretrained(args.source_model_id, strict=False)
|
||||
policy = fix_buffer_naming(policy)
|
||||
policy.eval()
|
||||
policy.to(args.device)
|
||||
|
||||
# Load dataset and get a sample
|
||||
print(f"Loading dataset: {args.dataset_id}")
|
||||
dataset = LeRobotDataset(args.dataset_id, episodes=[args.episode])
|
||||
meta: LeRobotDatasetMetadata = dataset.meta
|
||||
sample = dataset[args.sample_idx]
|
||||
|
||||
# Configure policy features
|
||||
key_mapping = configure_policy_features(policy, dataset)
|
||||
|
||||
# Inject normalization stats with proper key mapping
|
||||
_inject_normalization_stats(policy, meta, key_mapping)
|
||||
|
||||
# Prepare batch for PI0 (handle temporal dimensions)
|
||||
batch = {}
|
||||
|
||||
# Map dataset sample keys to policy keys
|
||||
reverse_mapping = {v: k for k, v in key_mapping.items()}
|
||||
|
||||
for policy_key in policy.config.input_features:
|
||||
# Find the corresponding dataset key
|
||||
dataset_key = reverse_mapping.get(policy_key, policy_key)
|
||||
|
||||
if dataset_key in sample:
|
||||
data = sample[dataset_key]
|
||||
|
||||
# Handle image data: convert from HWC to CHW and normalize
|
||||
if policy_key.startswith("observation.images."):
|
||||
if data.dim() == 3 and data.shape[-1] == 3: # HWC format
|
||||
data = data.permute(2, 0, 1) # Convert to CHW
|
||||
# Normalize to [0, 1] range if needed
|
||||
if data.dtype == torch.uint8:
|
||||
data = data.float() / 255.0
|
||||
# Resize to expected size if needed
|
||||
if data.shape[-2:] != (224, 224):
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
data = F.interpolate(
|
||||
data.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
|
||||
)[0]
|
||||
|
||||
# Remove temporal dimension if present
|
||||
if data.dim() > len(policy.config.input_features[policy_key].shape):
|
||||
data = data[0]
|
||||
|
||||
batch[policy_key] = data.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# Debug: print what's in the sample
|
||||
print(f"Sample keys: {list(sample.keys())}")
|
||||
print(f"Batch keys prepared: {list(batch.keys())}")
|
||||
|
||||
# Pi0 requires task description - add a default if not available
|
||||
if "task" in sample:
|
||||
batch["task"] = [sample["task"]] # Keep as list of strings
|
||||
else:
|
||||
print("No task in sample, using default task description")
|
||||
batch["task"] = ["Complete the manipulation task"]
|
||||
|
||||
print(f"Task: {batch['task'][0]}")
|
||||
print(f"Final batch keys: {list(batch.keys())}")
|
||||
|
||||
# Run inference
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"Predicted action shape: {action.shape}")
|
||||
print(f"Predicted action: {action.tolist()}")
|
||||
|
||||
print("✅ Pi0 pretrained inference completed successfully!")
|
||||
|
||||
# Upload to Hugging Face Hub
|
||||
print(f"\n📤 Uploading model to Hugging Face Hub: {args.output_model_id}")
|
||||
|
||||
# Create commit message
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
commit_message = (
|
||||
args.commit_message
|
||||
or f"Pi0 model with injected normalization stats from {args.dataset_id} - {timestamp}"
|
||||
)
|
||||
|
||||
# Update model configuration with dataset info
|
||||
policy.config.push_to_hub = True
|
||||
policy.config.repo_id = args.output_model_id
|
||||
policy.config.private = args.private
|
||||
|
||||
# Add metadata about the adaptation
|
||||
adaptation_info = {
|
||||
"source_model": args.source_model_id,
|
||||
"dataset_used": args.dataset_id,
|
||||
"adaptation_date": timestamp,
|
||||
"stats_injected": True,
|
||||
"key_mapping": key_mapping,
|
||||
"inference_test_passed": True,
|
||||
"sample_action_shape": list(action.shape),
|
||||
}
|
||||
|
||||
try:
|
||||
# Push to hub
|
||||
policy.push_to_hub(
|
||||
repo_id=args.output_model_id,
|
||||
private=args.private,
|
||||
commit_message=commit_message,
|
||||
create_pr=False,
|
||||
)
|
||||
|
||||
# Also save the adaptation info as a separate file
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# Create a temporary file with adaptation info
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(adaptation_info, f, indent=2)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
api.upload_file(
|
||||
path_or_fileobj=temp_path,
|
||||
path_in_repo="adaptation_info.json",
|
||||
repo_id=args.output_model_id,
|
||||
commit_message=f"Add adaptation metadata - {timestamp}",
|
||||
)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
print(f"✅ Model successfully uploaded to: https://huggingface.co/{args.output_model_id}")
|
||||
print("📋 Adaptation info:")
|
||||
for key, value in adaptation_info.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error uploading to Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+704
@@ -0,0 +1,704 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download # noqa: E402
|
||||
from safetensors.torch import load_file # noqa: E402
|
||||
from transformers.model_debugging_utils import model_addition_debugger_context
|
||||
|
||||
from lerobot.configs.policies import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
RANDOM_SEED = 42 # Set to fixed value for reproducible results
|
||||
|
||||
|
||||
def set_all_seeds(seed=42):
|
||||
"""Set all random seeds for reproducible results."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
torch.use_deterministic_algorithms(True)
|
||||
print(f"All random seeds set to {seed} for reproducible results (deterministic mode enabled)")
|
||||
|
||||
|
||||
# Set seeds at the start
|
||||
set_all_seeds(RANDOM_SEED)
|
||||
|
||||
config_model_path = "lerobot/pi0" # Use config from official model
|
||||
official_model_path = "lerobot/pi0" # Official model
|
||||
custom_model_path = "pepijn223/pi0_base_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
|
||||
device = "mps"
|
||||
|
||||
USE_FULL_TENSORS = True
|
||||
SAVE_TENSORS_TO_DISK = False
|
||||
|
||||
# Model transformation and upload settings
|
||||
SAVE_TRANSFORMED_MODEL = True # Set to True to save the transformed model
|
||||
UPLOAD_TO_HUB = True # Set to True to upload to HuggingFace Hub
|
||||
TRANSFORMED_MODEL_NAME = "pepijn223/pi0_base_fp32_lerobot_format" # Target repo name
|
||||
COMMIT_MESSAGE = "Add transformed PI0 model with correct key format for lerobot"
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
|
||||
os.makedirs(debug_path, exist_ok=True)
|
||||
print(f"Model debugging enabled - outputs will be saved to: {debug_path}")
|
||||
|
||||
# Download and load the config manually to avoid draccus parsing issues
|
||||
config_file = hf_hub_download(repo_id=config_model_path, filename="config.json")
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Remove the 'type' field that causes draccus issues
|
||||
if "type" in config_dict:
|
||||
config_dict.pop("type")
|
||||
print("Removed 'type' field from config")
|
||||
|
||||
# Create shared PI0Config
|
||||
print("Creating shared PI0Config...")
|
||||
shared_config = PI0Config(**config_dict)
|
||||
|
||||
|
||||
def load_policy_with_weights(
|
||||
model_path: str, config: PI0Config, model_name: str, apply_transformations: bool = False
|
||||
):
|
||||
"""Load a policy with specified weights but shared config."""
|
||||
print(f"\n=== Loading {model_name} from {model_path} ===")
|
||||
|
||||
# Set deterministic seed before creating the policy to ensure identical initialization
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
np.random.seed(RANDOM_SEED)
|
||||
random.seed(RANDOM_SEED)
|
||||
|
||||
policy = PI0Policy(config)
|
||||
|
||||
# Download and load weights
|
||||
model_file = hf_hub_download(repo_id=model_path, filename="model.safetensors")
|
||||
print(f"Downloaded {model_name} weights to: {model_file}")
|
||||
|
||||
# Load state dict and apply transformations
|
||||
print(f"Investigating safetensors file: {model_file}")
|
||||
|
||||
# First, check what's in the metadata
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
all_keys_in_file = f.keys()
|
||||
print(f" Total keys in safetensors file: {len(list(all_keys_in_file))}")
|
||||
|
||||
# Check for embed_tokens in the file keys
|
||||
embed_keys_in_file = [k for k in f.keys() if "embed_tokens" in k]
|
||||
print(f" embed_tokens keys in safetensors: {embed_keys_in_file}")
|
||||
|
||||
if metadata:
|
||||
print(f" Metadata exists: {list(metadata.keys()) if metadata else 'None'}")
|
||||
except Exception as e:
|
||||
print(f" Could not inspect safetensors file directly: {e}")
|
||||
|
||||
# Now load normally and see what we get
|
||||
state_dict = load_file(model_file)
|
||||
print(f" Keys loaded by load_file(): {len(state_dict)} keys")
|
||||
|
||||
# Check for embed_tokens in loaded state_dict
|
||||
loaded_embed_keys = [k for k in state_dict.keys() if "embed_tokens" in k]
|
||||
print(f" embed_tokens keys in loaded state_dict: {loaded_embed_keys}")
|
||||
|
||||
# Check if we need to add "model." prefix (for custom models that don't have it)
|
||||
sample_key = next(iter(state_dict.keys()))
|
||||
if not sample_key.startswith("model."):
|
||||
print(f"Adding 'model.' prefix to all keys (detected format: {sample_key})")
|
||||
state_dict = {f"model.{k}": v for k, v in state_dict.items()}
|
||||
|
||||
# IMPORTANT: Call PI0Policy._transform_state_dict_keys AFTER adding model. prefix
|
||||
# This ensures tied weights logic can find the correct key pattern
|
||||
transformed_state_dict = PI0Policy._transform_state_dict_keys(state_dict)
|
||||
|
||||
# Apply specific PaliGemma key transformations only for custom models
|
||||
if apply_transformations:
|
||||
print("Applying custom model key transformations...")
|
||||
|
||||
# First, let's debug what keys we actually have
|
||||
all_keys = list(transformed_state_dict.keys())
|
||||
sample_keys = all_keys[:10]
|
||||
print(f"Sample keys to transform: {sample_keys}")
|
||||
|
||||
# Look for specific keys we need to transform and missing keys
|
||||
embed_tokens_keys = [k for k in all_keys if "embed_tokens" in k]
|
||||
embedding_keys = [k for k in all_keys if "embed" in k]
|
||||
lm_head_keys = [k for k in all_keys if "lm_head" in k]
|
||||
paligemma_keys = [
|
||||
k for k in all_keys if "paligemma_with_expert.paligemma" in k and "gemma_expert" not in k
|
||||
]
|
||||
language_model_keys = [k for k in all_keys if "language_model" in k]
|
||||
|
||||
print(f"Found embed_tokens keys: {embed_tokens_keys}")
|
||||
print(f"Found any embedding keys: {embedding_keys}")
|
||||
print(f"Found lm_head keys: {lm_head_keys}")
|
||||
print(
|
||||
f"Found paligemma keys (non-expert): {paligemma_keys[:5]}{'...' if len(paligemma_keys) > 5 else ''}"
|
||||
)
|
||||
print(
|
||||
f"Found language_model keys: {language_model_keys[:5]}{'...' if len(language_model_keys) > 5 else ''}"
|
||||
)
|
||||
print(f"Total keys in model: {len(all_keys)}")
|
||||
|
||||
# Check if the embed_tokens is in gemma_expert instead
|
||||
gemma_expert_embed = [k for k in all_keys if "gemma_expert" in k and "embed_tokens" in k]
|
||||
print(f"Found gemma_expert embed_tokens keys: {gemma_expert_embed}")
|
||||
|
||||
# Check what we're missing and what we actually have
|
||||
expected_embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
if expected_embed_key not in all_keys:
|
||||
print(f" Missing expected embed_tokens key: {expected_embed_key}")
|
||||
|
||||
# Let's see what keys we actually have for debugging
|
||||
print("Debugging: Looking for any embedding-related keys...")
|
||||
all_embed_related = [k for k in all_keys if "embed" in k.lower()]
|
||||
print(f"Keys containing 'embed': {all_embed_related}")
|
||||
|
||||
# Look for any keys that might contain embeddings
|
||||
potential_embed_keys = [
|
||||
k for k in all_keys if any(word in k for word in ["embed", "embedding", "token"])
|
||||
]
|
||||
print(f" Potential embedding keys: {potential_embed_keys}")
|
||||
|
||||
# Try to find a suitable replacement
|
||||
if gemma_expert_embed:
|
||||
print(f" Will try to copy from: {gemma_expert_embed[0]}")
|
||||
else:
|
||||
print(" No gemma_expert embed_tokens found either!")
|
||||
|
||||
# Check if there's an embed_tokens in the gemma_expert that we missed
|
||||
gemma_keys = [k for k in all_keys if "gemma_expert" in k]
|
||||
print(f" First 10 gemma_expert keys: {gemma_keys[:10]}")
|
||||
|
||||
# Check if there are any token-related keys in gemma_expert
|
||||
token_keys = [k for k in all_keys if "gemma_expert" in k and "token" in k.lower()]
|
||||
print(f" Gemma expert token-related keys: {token_keys}")
|
||||
|
||||
# Check for any keys that look like they might be embeddings
|
||||
possible_embeds = [
|
||||
k
|
||||
for k in all_keys
|
||||
if any(
|
||||
pattern in k.lower() for pattern in ["embed_token", "embedding", "wte", "word_embed"]
|
||||
)
|
||||
]
|
||||
print(f" Possible embedding alternatives: {possible_embeds}")
|
||||
|
||||
final_state_dict = {}
|
||||
transformation_count = 0
|
||||
|
||||
for key, value in transformed_state_dict.items():
|
||||
new_key = key
|
||||
original_key = key
|
||||
|
||||
# Transform vision tower keys: ADD .model between paligemma and vision_tower
|
||||
if "paligemma_with_expert.paligemma.vision_tower.vision_model" in new_key:
|
||||
new_key = new_key.replace(
|
||||
"paligemma_with_expert.paligemma.vision_tower.vision_model",
|
||||
"paligemma_with_expert.paligemma.model.vision_tower.vision_model",
|
||||
)
|
||||
print(f"Transformed vision key: {original_key} -> {new_key}")
|
||||
transformation_count += 1
|
||||
|
||||
# Transform multi_modal_projector keys: ADD .model between paligemma and multi_modal_projector
|
||||
elif "paligemma_with_expert.paligemma.multi_modal_projector" in new_key:
|
||||
new_key = new_key.replace(
|
||||
"paligemma_with_expert.paligemma.multi_modal_projector",
|
||||
"paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
)
|
||||
print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}")
|
||||
transformation_count += 1
|
||||
|
||||
# NO transformation needed for language_model keys - they're already correct!
|
||||
# The custom model already has: paligemma.model.language_model.* which is what we need
|
||||
|
||||
# NO transformation needed for lm_head - it should stay as paligemma.lm_head
|
||||
|
||||
final_state_dict[new_key] = value
|
||||
|
||||
print(f"Applied {transformation_count} key transformations")
|
||||
transformed_state_dict = final_state_dict
|
||||
else:
|
||||
print("No transformations applied (official model format)")
|
||||
|
||||
# Debug: show what keys the policy expects vs what we have
|
||||
policy_keys = set(policy.state_dict().keys())
|
||||
provided_keys = set(transformed_state_dict.keys())
|
||||
|
||||
missing_in_provided = policy_keys - provided_keys
|
||||
extra_in_provided = provided_keys - policy_keys
|
||||
|
||||
print(f"Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
|
||||
if missing_in_provided:
|
||||
print(
|
||||
f" Missing from provided: {list(missing_in_provided)[:5]}{'...' if len(missing_in_provided) > 5 else ''}"
|
||||
)
|
||||
if extra_in_provided:
|
||||
print(
|
||||
f" Extra in provided: {list(extra_in_provided)[:5]}{'...' if len(extra_in_provided) > 5 else ''}"
|
||||
)
|
||||
|
||||
# Load the weights into the policy
|
||||
msg = policy.load_state_dict(transformed_state_dict, strict=True)
|
||||
print(
|
||||
f"{model_name} - Missing keys: {len(msg.missing_keys)}, Unexpected keys: {len(msg.unexpected_keys)}"
|
||||
)
|
||||
|
||||
if msg.missing_keys:
|
||||
print(
|
||||
f" Actually missing keys: {list(msg.missing_keys)[:3]}{'...' if len(msg.missing_keys) > 3 else ''}"
|
||||
)
|
||||
if msg.unexpected_keys:
|
||||
print(
|
||||
f" Actually unexpected keys: {list(msg.unexpected_keys)[:3]}{'...' if len(msg.unexpected_keys) > 3 else ''}"
|
||||
)
|
||||
|
||||
# Set deterministic mode and move to device
|
||||
policy = policy.to(device)
|
||||
policy.eval()
|
||||
|
||||
# Reset the policy to ensure identical internal state
|
||||
policy.reset()
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
# Load both models with shared config
|
||||
print("Loading both models with shared config...")
|
||||
official_policy = load_policy_with_weights(
|
||||
official_model_path, shared_config, "Official Model", apply_transformations=False
|
||||
)
|
||||
custom_policy = load_policy_with_weights(
|
||||
custom_model_path, shared_config, "Custom Model", apply_transformations=True
|
||||
)
|
||||
|
||||
print("\nBoth models loaded successfully!")
|
||||
print(f"Shared config: {shared_config}")
|
||||
print(f"Device: {device}")
|
||||
|
||||
|
||||
# Configure input features for both policies since they're not set by default in pretrained models
|
||||
def configure_policy_features(policy: PI0Policy):
|
||||
"""Configure input and output features for a policy."""
|
||||
policy.config.input_features[OBS_IMAGE] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Channel-first RGB image
|
||||
)
|
||||
|
||||
policy.config.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(8,), # 8-dimensional state vector
|
||||
)
|
||||
|
||||
policy.config.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(8,), # 8-dimensional action vector
|
||||
)
|
||||
|
||||
# Add dummy normalization buffers to the policy (like openpi does with norm_stats)
|
||||
if hasattr(policy, "normalize_inputs"):
|
||||
# For observation.state (8-dim state vector)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_STATE.replace('.', '_')}_mean", torch.zeros(8, device=device)
|
||||
)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_STATE.replace('.', '_')}_std", torch.ones(8, device=device)
|
||||
)
|
||||
|
||||
# For observation.image (3x224x224 image)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_IMAGE.replace('.', '_')}_mean", torch.zeros(3, 224, 224, device=device)
|
||||
)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_IMAGE.replace('.', '_')}_std", torch.ones(3, 224, 224, device=device)
|
||||
)
|
||||
|
||||
|
||||
print("Configuring features for both policies...")
|
||||
configure_policy_features(official_policy)
|
||||
configure_policy_features(custom_policy)
|
||||
|
||||
# Verify that the models have identical parameters
|
||||
print("\n=== Model Parameter Comparison ===")
|
||||
official_params = dict(official_policy.named_parameters())
|
||||
custom_params = dict(custom_policy.named_parameters())
|
||||
|
||||
param_differences = []
|
||||
for name in official_params.keys():
|
||||
if name not in custom_params:
|
||||
param_differences.append(f"Missing parameter in custom model: {name}")
|
||||
else:
|
||||
diff = torch.abs(official_params[name] - custom_params[name]).max().item()
|
||||
if diff > 1e-8:
|
||||
param_differences.append(f"Parameter {name}: max difference = {diff:.2e}")
|
||||
|
||||
for name in custom_params.keys():
|
||||
if name not in official_params:
|
||||
param_differences.append(f"Extra parameter in custom model: {name}")
|
||||
|
||||
if param_differences:
|
||||
print("Parameter differences found:")
|
||||
for diff in param_differences[:10]: # Show first 10 differences
|
||||
print(f" {diff}")
|
||||
if len(param_differences) > 10:
|
||||
print(f" ... and {len(param_differences) - 10} more differences")
|
||||
else:
|
||||
print("All model parameters are identical!")
|
||||
|
||||
|
||||
# Get the raw models for direct comparison
|
||||
official_raw_model = official_policy.model
|
||||
custom_raw_model = custom_policy.model
|
||||
print("\n=== Model Details ===")
|
||||
print(f"Official raw model type: {type(official_raw_model)}")
|
||||
print(f"Custom raw model type: {type(custom_raw_model)}")
|
||||
print(f"Official model device: {next(official_raw_model.parameters()).device}")
|
||||
print(f"Custom model device: {next(custom_raw_model.parameters()).device}")
|
||||
|
||||
# Create lerobot-format input data (similar to DROID format from openpi example)
|
||||
example = {
|
||||
"joint_position": np.zeros(7, dtype=np.float32),
|
||||
"gripper_position": np.array([0.0], dtype=np.float32),
|
||||
"image": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
|
||||
"task": "pick up the object",
|
||||
}
|
||||
|
||||
print(f"\nProvided input keys: {list(example.keys())}")
|
||||
|
||||
print("\nPreparing inputs for direct model call...")
|
||||
|
||||
# Apply input transformation (similar to openpi's policy._input_transform)
|
||||
transformed_example = {}
|
||||
# Combine joint and gripper positions into state
|
||||
transformed_example[OBS_STATE] = np.concatenate([example["joint_position"], example["gripper_position"]])
|
||||
transformed_example[OBS_IMAGE] = example["image"]
|
||||
transformed_example["task"] = example["task"]
|
||||
|
||||
# Convert to PyTorch tensors and add batch dimension (as openpi example does)
|
||||
# Device is already defined above, use the official model device for consistency
|
||||
pytorch_inputs = {}
|
||||
for key, value in transformed_example.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
tensor_value = torch.from_numpy(value).to(device)
|
||||
# Add batch dimension
|
||||
if tensor_value.dim() > 0:
|
||||
tensor_value = tensor_value.unsqueeze(0)
|
||||
pytorch_inputs[key] = tensor_value
|
||||
elif isinstance(value, str):
|
||||
pytorch_inputs[key] = [value] # Convert to list format expected by policy
|
||||
else:
|
||||
pytorch_inputs[key] = value
|
||||
|
||||
# Convert image from HWC to CHW format for lerobot
|
||||
if OBS_IMAGE in pytorch_inputs:
|
||||
img = pytorch_inputs[OBS_IMAGE]
|
||||
if img.dim() == 4 and img.shape[-1] == 3: # BHWC -> BCHW
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
# Convert to float and normalize to [0, 1] range
|
||||
img = img.float() / 255.0
|
||||
pytorch_inputs[OBS_IMAGE] = img
|
||||
|
||||
print(f"Transformed input keys: {list(pytorch_inputs.keys())}")
|
||||
for key, value in pytorch_inputs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
print(f" {key}: {value.shape} {value.dtype}")
|
||||
else:
|
||||
print(f" {key}: {type(value)} - {value}")
|
||||
|
||||
# Reset both policies (clears the action queue)
|
||||
official_policy.reset()
|
||||
custom_policy.reset()
|
||||
|
||||
# Prepare inputs using the official policy (both models should have same preprocessing)
|
||||
print("Preparing inputs for both models...")
|
||||
images, img_masks = official_policy.prepare_images(pytorch_inputs)
|
||||
lang_tokens, lang_masks = official_policy.prepare_language(pytorch_inputs)
|
||||
state = official_policy.prepare_state(pytorch_inputs)
|
||||
|
||||
print("Prepared inputs:")
|
||||
print(f" Images: {len(images)} images")
|
||||
print(f" Language tokens shape: {lang_tokens.shape}")
|
||||
print(f" State shape: {state.shape}")
|
||||
for i, img in enumerate(images):
|
||||
print(f" Image {i} shape: {img.shape}")
|
||||
for i, mask in enumerate(img_masks):
|
||||
print(f" Image mask {i} shape: {mask.shape}")
|
||||
|
||||
# Compare both models with identical inputs
|
||||
print("\n🚀 Running MODEL COMPARISON...")
|
||||
|
||||
# Force torch.no_grad for consistent comparison
|
||||
with torch.no_grad():
|
||||
# Ensure reproducible noise generation for both models
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
|
||||
# Generate synthetic noise and time for the forward call
|
||||
batch_size = 1
|
||||
actions_shape = (
|
||||
batch_size,
|
||||
official_raw_model.config.n_action_steps,
|
||||
official_raw_model.config.max_action_dim,
|
||||
)
|
||||
|
||||
# Generate noise and time using direct PyTorch operations instead of model methods
|
||||
# This avoids any potential model-specific randomness
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=actions_shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Generate time using the same distribution as PI0FlowMatching.sample_time
|
||||
torch.manual_seed(RANDOM_SEED) # Reset for consistent time
|
||||
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||
time_beta = beta_dist.sample((batch_size,)).to(device=device, dtype=torch.float32)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
|
||||
print("\n=== Generated Inputs ===")
|
||||
print(f" Action shape: {actions_shape}")
|
||||
print(f" Noise shape: {noise.shape}")
|
||||
print(f" Time value: {time.item():.6f}")
|
||||
print(f" Noise sample (first 5 values): {noise.flatten()[:5].tolist()}")
|
||||
|
||||
# Create dummy actions for forward pass (required for training forward)
|
||||
dummy_actions = torch.zeros(actions_shape, dtype=torch.float32, device=device)
|
||||
|
||||
print("\n=== Running Forward Passes ===")
|
||||
|
||||
print("Running with model_addition_debugger_context for detailed analysis...")
|
||||
# Create separate debug paths for each model
|
||||
official_debug_path = os.path.join(debug_path, "official_model")
|
||||
custom_debug_path = os.path.join(debug_path, "custom_model")
|
||||
os.makedirs(official_debug_path, exist_ok=True)
|
||||
os.makedirs(custom_debug_path, exist_ok=True)
|
||||
# Set deterministic mode for forward pass
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
# Run official model with debugger
|
||||
print("Running Official Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
official_raw_model,
|
||||
debug_path=official_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
official_loss = official_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
lang_tokens=lang_tokens,
|
||||
lang_masks=lang_masks,
|
||||
state=state,
|
||||
actions=dummy_actions,
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
# Reset seed before second forward pass to ensure any internal randomness is identical
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
# Run custom model with debugger
|
||||
print("Running Custom Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
custom_raw_model,
|
||||
debug_path=custom_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
custom_loss = custom_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
lang_tokens=lang_tokens,
|
||||
lang_masks=lang_masks,
|
||||
state=state,
|
||||
actions=dummy_actions,
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
|
||||
print(f"Official model debug outputs saved to: {official_debug_path}")
|
||||
print(f"Custom model debug outputs saved to: {custom_debug_path}")
|
||||
|
||||
print("\n=== Output Comparison ===")
|
||||
print(f"Official model loss shape: {official_loss.shape}")
|
||||
print(f"Custom model loss shape: {custom_loss.shape}")
|
||||
|
||||
# Compare outputs
|
||||
loss_diff = torch.abs(official_loss - custom_loss)
|
||||
|
||||
print("\n=== Detailed Comparison ===")
|
||||
print("Loss difference stats:")
|
||||
print(f" Mean absolute difference: {loss_diff.mean().item():.8f}")
|
||||
print(f" Max absolute difference: {loss_diff.max().item():.8f}")
|
||||
print(f" Min absolute difference: {loss_diff.min().item():.8f}")
|
||||
print(f" Standard deviation of difference: {loss_diff.std().item():.8f}")
|
||||
|
||||
# Show some actual values for comparison
|
||||
print("\nSample output values:")
|
||||
print(f" Official model (first 5): {official_loss.flatten()[:5].tolist()}")
|
||||
print(f" Custom model (first 5): {custom_loss.flatten()[:5].tolist()}")
|
||||
print(f" Difference (first 5): {loss_diff.flatten()[:5].tolist()}")
|
||||
|
||||
# Determine if models are equivalent
|
||||
are_equivalent = loss_diff.max().item() < 1e-6
|
||||
print(f"\nModels are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
|
||||
print(f" (Max difference: {loss_diff.max().item():.8f}, Threshold: 1e-6)")
|
||||
|
||||
print(f"\nDetailed debugging outputs saved to: {debug_path}")
|
||||
# Save comparison results
|
||||
comparison_results = {
|
||||
"official_loss_stats": {
|
||||
"shape": list(official_loss.shape),
|
||||
"mean": official_loss.mean().item(),
|
||||
"std": official_loss.std().item(),
|
||||
"min": official_loss.min().item(),
|
||||
"max": official_loss.max().item(),
|
||||
},
|
||||
"custom_loss_stats": {
|
||||
"shape": list(custom_loss.shape),
|
||||
"mean": custom_loss.mean().item(),
|
||||
"std": custom_loss.std().item(),
|
||||
"min": custom_loss.min().item(),
|
||||
"max": custom_loss.max().item(),
|
||||
},
|
||||
"difference_stats": {
|
||||
"mean_abs_diff": loss_diff.mean().item(),
|
||||
"max_abs_diff": loss_diff.max().item(),
|
||||
"min_abs_diff": loss_diff.min().item(),
|
||||
"std_diff": loss_diff.std().item(),
|
||||
"are_equivalent": are_equivalent,
|
||||
},
|
||||
}
|
||||
|
||||
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
|
||||
with open(comparison_file, "w") as f:
|
||||
json.dump(comparison_results, f, indent=2)
|
||||
print(f" Comparison results saved to: {comparison_file}")
|
||||
|
||||
# Save and upload transformed model if requested
|
||||
if SAVE_TRANSFORMED_MODEL:
|
||||
print("\nSaving Transformed Model...")
|
||||
if are_equivalent:
|
||||
print("Models are equivalent - proceeding with transformation and upload")
|
||||
else:
|
||||
print("Models are NOT equivalent, but proceeding with upload anyway")
|
||||
print(f" Max difference: {loss_diff.max().item():.2e}")
|
||||
print(" This might be useful for debugging or partial transformations")
|
||||
|
||||
# Create timestamp for README
|
||||
transformation_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
try:
|
||||
# Use the already working custom policy as the base for transformation
|
||||
print("Using already working custom policy as base for transformed model...")
|
||||
|
||||
# Deep copy the custom policy to create the transformed version
|
||||
from copy import deepcopy
|
||||
|
||||
transformed_policy = deepcopy(custom_policy)
|
||||
|
||||
print("Custom policy copied successfully - no additional configuration needed")
|
||||
|
||||
# Save locally first
|
||||
local_save_path = "./transformed_pi0_model"
|
||||
print(f"Saving transformed model locally to: {local_save_path}")
|
||||
transformed_policy.save_pretrained(local_save_path, safe_serialization=True)
|
||||
|
||||
# Save the tokenizer as well (required for complete model)
|
||||
transformed_policy.language_tokenizer.save_pretrained(local_save_path)
|
||||
|
||||
# Create a README with transformation details
|
||||
readme_content = f"""
|
||||
# PI0 Model - LeRobot Compatible Format
|
||||
|
||||
This model is a transformed version of `{custom_model_path}` with key names corrected to match the official LeRobot PI0 format.
|
||||
|
||||
## Transformation Applied
|
||||
|
||||
The original model had a different key naming convention. This model applies the following transformations:
|
||||
|
||||
1. **Model prefix**: Added `model.` prefix to all parameter keys
|
||||
2. **Tied weights**: Applied PI0Policy's built-in tied weights logic to create `embed_tokens.weight` from `lm_head.weight`
|
||||
3. **Key structure**: Applied standard PI0 key transformations for compatibility
|
||||
|
||||
## Verification
|
||||
|
||||
{"This transformed model produces **identical outputs**" if are_equivalent else "This transformed model has **slightly different outputs**"} (max difference = {loss_diff.max().item():.2e}) compared to the official model `{official_model_path}` when tested with the same inputs.
|
||||
{"**Models are EQUIVALENT** (difference < 1e-6)" if are_equivalent else "**Models are NOT equivalent** (difference >= 1e-6) - use with caution"}
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Load the model
|
||||
policy = PI0Policy.from_pretrained("{TRANSFORMED_MODEL_NAME}")
|
||||
|
||||
# Use for inference
|
||||
action = policy.select_action(observation_batch)
|
||||
```
|
||||
|
||||
## Original Model
|
||||
|
||||
- **Source**: {custom_model_path}
|
||||
- **Verified Against**: {official_model_path}
|
||||
|
||||
## Technical Details
|
||||
|
||||
- **Total Parameters**: {sum(p.numel() for p in transformed_policy.parameters()):,}
|
||||
- **Model Type**: PI0FlowMatching with PaliGemma + Expert Gemma
|
||||
- **Configuration**: Matches official PI0 configuration
|
||||
"""
|
||||
|
||||
readme_path = os.path.join(local_save_path, "README.md")
|
||||
with open(readme_path, "w") as f:
|
||||
f.write(readme_content.strip())
|
||||
|
||||
print(f"Model saved locally to: {local_save_path}")
|
||||
|
||||
# Upload to HuggingFace Hub if requested
|
||||
if UPLOAD_TO_HUB:
|
||||
print(f"\nUploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
|
||||
|
||||
try:
|
||||
# Push to hub
|
||||
transformed_policy.push_to_hub(
|
||||
repo_id=TRANSFORMED_MODEL_NAME,
|
||||
commit_message=COMMIT_MESSAGE,
|
||||
private=False, # Make it public
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
|
||||
print("You can now use this model directly without any transformations!")
|
||||
print("\n Usage:")
|
||||
print(" from lerobot.policies.pi0.modeling_pi0 import PI0Policy")
|
||||
print(f" policy = PI0Policy.from_pretrained('{TRANSFORMED_MODEL_NAME}')")
|
||||
|
||||
except Exception as upload_error:
|
||||
print(f"Failed to upload to HuggingFace Hub: {upload_error}")
|
||||
print(f"You can manually upload the model from: {local_save_path}")
|
||||
print(" Or set UPLOAD_TO_HUB = False and upload later")
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(f"Error saving transformed model: {str(e)}")
|
||||
print("Full traceback:")
|
||||
traceback.print_exc()
|
||||
print("The model transformation logic works, but saving failed")
|
||||
|
||||
else:
|
||||
print("\nModel transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")
|
||||
@@ -13,20 +13,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
This script will help you download any LeRobot dataset from the hub, convert it to the latest format, and
|
||||
upload it to your own repository. It will:
|
||||
|
||||
- Download the dataset from any source repository
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
- Update codebase_version in `info.json` to the latest version
|
||||
- Create proper version tags
|
||||
- Push the converted dataset to your specified destination repository
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
--source-repo-id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
|
||||
--dest-repo-id=your-username/libero_spatial_converted \
|
||||
--episodes=0,1,2,3,4
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -37,8 +39,8 @@ import logging
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, write_info
|
||||
from lerobot.datasets.v21.convert_stats import convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
@@ -54,48 +56,133 @@ class SuppressWarnings:
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
source_repo_id: str,
|
||||
dest_repo_id: str | None = None,
|
||||
episodes: str | None = None,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
force_cache_sync: bool = True,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
"""
|
||||
Download a dataset from source_repo_id, convert it, and upload to dest_repo_id.
|
||||
|
||||
Args:
|
||||
source_repo_id: Source repository to download from
|
||||
dest_repo_id: Destination repository to upload to (defaults to source_repo_id)
|
||||
episodes: Comma-separated list of episode indices to include (e.g. "0,1,2,3")
|
||||
branch: Branch to upload to
|
||||
num_workers: Number of workers for stats computation
|
||||
force_cache_sync: Whether to force cache synchronization
|
||||
"""
|
||||
if dest_repo_id is None:
|
||||
dest_repo_id = source_repo_id
|
||||
|
||||
# Parse episodes list if provided
|
||||
episode_list = None
|
||||
if episodes:
|
||||
try:
|
||||
episode_list = [int(ep.strip()) for ep in episodes.split(",")]
|
||||
print(f"Loading episodes: {episode_list}")
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid episodes format '{episodes}'. Use comma-separated integers like '0,1,2,3'"
|
||||
) from e
|
||||
|
||||
print(f"Downloading dataset from: {source_repo_id}")
|
||||
|
||||
# Try to load the dataset with different approaches to handle versioning issues
|
||||
dataset = None
|
||||
load_attempts = [
|
||||
{"revision": None}, # Try latest first
|
||||
{"revision": V20}, # Try v2.0
|
||||
{"revision": "main"}, # Try main branch
|
||||
]
|
||||
|
||||
for attempt in load_attempts:
|
||||
try:
|
||||
print(f"Attempting to load with revision: {attempt['revision']}")
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(
|
||||
source_repo_id, episodes=episode_list, force_cache_sync=force_cache_sync, **attempt
|
||||
)
|
||||
print("Successfully loaded dataset!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Failed with revision {attempt['revision']}: {e}")
|
||||
continue
|
||||
|
||||
if dataset is None:
|
||||
raise RuntimeError(f"Could not load dataset {source_repo_id} with any revision")
|
||||
|
||||
# Clean up old stats if present
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
print("Removed existing episodes_stats.jsonl")
|
||||
|
||||
print("Converting stats to new format...")
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
# Update dataset info
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
print(f"Updated codebase_version to {CODEBASE_VERSION}")
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
# Change repo_id for destination if different
|
||||
if dest_repo_id != source_repo_id:
|
||||
print(f"Changing repository from {source_repo_id} to {dest_repo_id}")
|
||||
dataset.repo_id = dest_repo_id
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
print(f"Pushing converted dataset to: {dest_repo_id}")
|
||||
dataset.push_to_hub(branch=branch, tag_version=False)
|
||||
|
||||
# Clean up old stats.json file locally and on hub
|
||||
if (dataset.root / STATS_PATH).is_file():
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
print("Removed local stats.json file")
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
try:
|
||||
if hub_api.file_exists(
|
||||
repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
print("Removed stats.json from hub")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remove stats.json from hub: {e}")
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
# Create version tag
|
||||
try:
|
||||
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
print(f"Created tag {CODEBASE_VERSION} for {dest_repo_id}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not create tag: {e}")
|
||||
|
||||
print(f"✅ Successfully converted and uploaded dataset to {dest_repo_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download, convert, and re-upload LeRobot datasets with proper versioning"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
"--source-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
help="Source repository identifier to download from (e.g. 'IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dest-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Destination repository identifier to upload to. Defaults to source-repo-id if not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated list of episode indices to include (e.g. '0,1,2,3,4'). If not specified, all episodes are included.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
@@ -109,6 +196,22 @@ if __name__ == "__main__":
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-cache-sync",
|
||||
action="store_true",
|
||||
help="Skip forcing cache synchronization (faster but may use cached data)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
|
||||
# Convert args to match function signature
|
||||
convert_args = {
|
||||
"source_repo_id": args.source_repo_id,
|
||||
"dest_repo_id": args.dest_repo_id,
|
||||
"episodes": args.episodes,
|
||||
"branch": args.branch,
|
||||
"num_workers": args.num_workers,
|
||||
"force_cache_sync": not args.no_cache_sync,
|
||||
}
|
||||
|
||||
convert_dataset(**convert_args)
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from .aloha import Aloha
|
||||
from .config_aloha import AlohaConfig
|
||||
|
||||
__all__ = ["Aloha", "AlohaConfig"]
|
||||
@@ -1,161 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.robots.viperx import ViperX
|
||||
from lerobot.robots.viperx.config_viperx import ViperXConfig
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_aloha import AlohaConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Aloha(Robot):
|
||||
"""
|
||||
ALOHA Bimanual Robot System using dual ViperX follower arms.
|
||||
Based on the ALOHA (A Low-cost Open-source Hardware System for Bimanual Teleoperation) design.
|
||||
"""
|
||||
|
||||
config_class = AlohaConfig
|
||||
name = "aloha"
|
||||
|
||||
def __init__(self, config: AlohaConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = ViperXConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_port,
|
||||
max_relative_target=config.left_arm_max_relative_target,
|
||||
use_degrees=config.left_arm_use_degrees,
|
||||
cameras={},
|
||||
)
|
||||
|
||||
right_arm_config = ViperXConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_port,
|
||||
max_relative_target=config.right_arm_max_relative_target,
|
||||
use_degrees=config.right_arm_use_degrees,
|
||||
cameras={},
|
||||
)
|
||||
|
||||
self.left_arm = ViperX(left_arm_config)
|
||||
self.right_arm = ViperX(right_arm_config)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
return {f"left_{motor}.pos": float for motor in self.left_arm.bus.motors} | {
|
||||
f"right_{motor}.pos": float for motor in self.right_arm.bus.motors
|
||||
}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return (
|
||||
self.left_arm.bus.is_connected
|
||||
and self.right_arm.bus.is_connected
|
||||
and all(cam.is_connected for cam in self.cameras.values())
|
||||
)
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_obs = self.right_arm.get_observation()
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
|
||||
}
|
||||
# Remove "right_" prefix
|
||||
right_action = {
|
||||
key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_")
|
||||
}
|
||||
|
||||
send_action_left = self.left_arm.send_action(left_action)
|
||||
send_action_right = self.right_arm.send_action(right_action)
|
||||
|
||||
# Add prefixes back
|
||||
prefixed_send_action_left = {f"left_{key}": value for key, value in send_action_left.items()}
|
||||
prefixed_send_action_right = {f"right_{key}": value for key, value in send_action_right.items()}
|
||||
|
||||
return {**prefixed_send_action_left, **prefixed_send_action_right}
|
||||
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
@@ -1,39 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaConfig(RobotConfig):
|
||||
left_arm_port: str
|
||||
right_arm_port: str
|
||||
|
||||
# Optional parameters for left arm (ViperX)
|
||||
left_arm_max_relative_target: float | dict[str, float] = 20.0
|
||||
left_arm_use_degrees: bool = True
|
||||
|
||||
# Optional parameters for right arm (ViperX)
|
||||
right_arm_max_relative_target: float | dict[str, float] = 20.0
|
||||
right_arm_use_degrees: bool = True
|
||||
|
||||
# cameras (shared between both arms)
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
@@ -43,6 +43,3 @@ class ViperXConfig(RobotConfig):
|
||||
# Troubleshooting: If one of your IntelRealSense cameras freeze during
|
||||
# data recording due to bandwidth limit, you might need to plug the camera
|
||||
# on another USB hub or PCIe card.
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
@@ -18,6 +18,7 @@ from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.dynamixel import (
|
||||
@@ -44,23 +45,22 @@ class ViperX(Robot):
|
||||
self,
|
||||
config: ViperXConfig,
|
||||
):
|
||||
raise NotImplementedError
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100
|
||||
self.bus = DynamixelMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"waist": Motor(1, "xm540-w270", norm_mode_body),
|
||||
"shoulder": Motor(2, "xm540-w270", norm_mode_body),
|
||||
"shoulder_shadow": Motor(3, "xm540-w270", norm_mode_body),
|
||||
"elbow": Motor(4, "xm540-w270", norm_mode_body),
|
||||
"elbow_shadow": Motor(5, "xm540-w270", norm_mode_body),
|
||||
"forearm_roll": Motor(6, "xm540-w270", norm_mode_body),
|
||||
"wrist_angle": Motor(7, "xm540-w270", norm_mode_body),
|
||||
"wrist_rotate": Motor(8, "xm430-w350", norm_mode_body),
|
||||
"waist": Motor(1, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder_shadow": Motor(3, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"elbow": Motor(4, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"elbow_shadow": Motor(5, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"forearm_roll": Motor(6, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_angle": Motor(7, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_rotate": Motor(8, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"gripper": Motor(9, "xm430-w350", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@@ -96,9 +96,6 @@ class ViperX(Robot):
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||
)
|
||||
self.calibrate()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
@@ -112,24 +109,16 @@ class ViperX(Robot):
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
raise NotImplementedError # TODO(aliberts): adapt code below (copied from koch
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.bus.disable_torque()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
input(f"Move {self} to the middle of its range of motion and press ENTER....")
|
||||
input("Move robot to the middle of its range of motion and press ENTER....")
|
||||
homing_offsets = self.bus.set_half_turn_homings()
|
||||
|
||||
full_turn_motors = ["shoulder", "forearm_roll", "wrist_rotate"]
|
||||
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
||||
unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors]
|
||||
print(
|
||||
f"Move all joints except {full_turn_motors} sequentially through their entire "
|
||||
@@ -164,23 +153,33 @@ class ViperX(Robot):
|
||||
self.bus.write("Secondary_ID", "shoulder_shadow", 2)
|
||||
self.bus.write("Secondary_ID", "elbow_shadow", 4)
|
||||
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
# Set a velocity limit of 131 as advised by Trossen Robotics
|
||||
# TODO(aliberts): remove as it's actually useless in position control
|
||||
self.bus.write("Velocity_Limit", 131)
|
||||
|
||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
|
||||
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
|
||||
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point.
|
||||
# See: https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
|
||||
for motor in self.bus.motors:
|
||||
if motor != "gripper":
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
# TODO(pepijn): Re enable this
|
||||
# Use 'position control current based' for follower gripper to be limited by the limit of the
|
||||
# current. It can grasp an object without forcing too much even tho, it's goal position is a
|
||||
# complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
# self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
"""The returned observations do not have a batch dimension."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
obs_dict[OBS_STATE] = self.bus.sync_read("Present_Position")
|
||||
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from .aloha_teleop import AlohaTeleop
|
||||
from .config_aloha_teleop import AlohaTeleopConfig
|
||||
|
||||
__all__ = ["AlohaTeleop", "AlohaTeleopConfig"]
|
||||
@@ -1,125 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.teleoperators.widowx.config_widowx import WidowXConfig
|
||||
from lerobot.teleoperators.widowx.widowx import WidowX
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_aloha_teleop import AlohaTeleopConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlohaTeleop(Teleoperator):
|
||||
"""
|
||||
ALOHA Bimanual Teleoperator System using dual WidowX leader arms.
|
||||
Based on the ALOHA (A Low-cost Open-source Hardware System for Bimanual Teleoperation) design.
|
||||
"""
|
||||
|
||||
config_class = AlohaTeleopConfig
|
||||
name = "aloha_teleop"
|
||||
|
||||
def __init__(self, config: AlohaTeleopConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = WidowXConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_port,
|
||||
gripper_motor=config.left_arm_gripper_motor,
|
||||
use_degrees=config.left_arm_use_degrees,
|
||||
)
|
||||
|
||||
right_arm_config = WidowXConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_port,
|
||||
gripper_motor=config.right_arm_gripper_motor,
|
||||
use_degrees=config.right_arm_use_degrees,
|
||||
)
|
||||
|
||||
self.left_arm = WidowX(left_arm_config)
|
||||
self.right_arm = WidowX(right_arm_config)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"left_{motor}.pos": float for motor in self.left_arm.bus.motors} | {
|
||||
f"right_{motor}.pos": float for motor in self.right_arm.bus.motors
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_action = self.left_arm.get_action()
|
||||
action_dict.update({f"left_{key}": value for key, value in left_action.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_action = self.right_arm.get_action()
|
||||
action_dict.update({f"right_{key}": value for key, value in right_action.items()})
|
||||
|
||||
return action_dict
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# Remove "left_" prefix
|
||||
left_feedback = {
|
||||
key.removeprefix("left_"): value for key, value in feedback.items() if key.startswith("left_")
|
||||
}
|
||||
# Remove "right_" prefix
|
||||
right_feedback = {
|
||||
key.removeprefix("right_"): value for key, value in feedback.items() if key.startswith("right_")
|
||||
}
|
||||
|
||||
if left_feedback:
|
||||
self.left_arm.send_feedback(left_feedback)
|
||||
if right_feedback:
|
||||
self.right_arm.send_feedback(right_feedback)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -1,34 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("aloha_teleop")
|
||||
@dataclass
|
||||
class AlohaTeleopConfig(TeleoperatorConfig):
|
||||
left_arm_port: str
|
||||
right_arm_port: str
|
||||
|
||||
# Parameters for left arm (WidowX)
|
||||
left_arm_gripper_motor: str = "xl430-w250"
|
||||
left_arm_use_degrees: bool = True
|
||||
|
||||
# Parameters for right arm (WidowX)
|
||||
right_arm_gripper_motor: str = "xc430-w150"
|
||||
right_arm_use_degrees: bool = True
|
||||
@@ -23,7 +23,3 @@ from ..config import TeleoperatorConfig
|
||||
@dataclass
|
||||
class WidowXConfig(TeleoperatorConfig):
|
||||
port: str # Port to connect to the arm
|
||||
|
||||
gripper_motor: str = "xl430-w250" # This could be xc430-w150 or xl430-w250
|
||||
|
||||
use_degrees: bool = False
|
||||
|
||||
@@ -20,6 +20,7 @@ import time
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.dynamixel import (
|
||||
DriveMode,
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
@@ -39,27 +40,22 @@ class WidowX(Teleoperator):
|
||||
name = "widowx"
|
||||
|
||||
def __init__(self, config: WidowXConfig):
|
||||
raise NotImplementedError
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100
|
||||
self.bus = DynamixelMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"waist": Motor(1, "xm430-w350", norm_mode_body),
|
||||
"shoulder": Motor(2, "xm430-w350", norm_mode_body),
|
||||
"shoulder_shadow": Motor(3, "xm430-w350", norm_mode_body),
|
||||
"elbow": Motor(4, "xm430-w350", norm_mode_body),
|
||||
"elbow_shadow": Motor(5, "xm430-w350", norm_mode_body),
|
||||
"forearm_roll": Motor(6, "xm430-w350", norm_mode_body),
|
||||
"wrist_angle": Motor(7, "xm430-w350", norm_mode_body),
|
||||
"wrist_rotate": Motor(
|
||||
8, "xm430-w350", norm_mode_body
|
||||
), # This could be xl430-w250 or xm430-w350
|
||||
"gripper": Motor(
|
||||
9, self.config.gripper_motor, MotorNormMode.RANGE_0_100
|
||||
), # This could be xc430-w150 or xl430-w250
|
||||
"waist": Motor(1, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder": Motor(2, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder_shadow": Motor(3, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"elbow": Motor(4, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"elbow_shadow": Motor(5, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"forearm_roll": Motor(6, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_angle": Motor(7, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_rotate": Motor(8, "xl430-w250", MotorNormMode.RANGE_M100_100),
|
||||
"gripper": Motor(9, "xc430-w150", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -80,9 +76,6 @@ class WidowX(Teleoperator):
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||
)
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
@@ -93,27 +86,19 @@ class WidowX(Teleoperator):
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
raise NotImplementedError # TODO(aliberts): adapt code below (copied from koch)
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.bus.disable_torque()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
# self.bus.write("Drive_Mode", "el", DriveMode.INVERTED.value)
|
||||
# drive_modes = {motor: 1 if motor == ["elbow_shadow", "shoulder_shadow"] else 0 for motor in self.bus.motors}
|
||||
self.bus.write("Drive_Mode", "elbow_flex", DriveMode.INVERTED.value)
|
||||
drive_modes = {motor: 1 if motor == "elbow_flex" else 0 for motor in self.bus.motors}
|
||||
|
||||
input(f"Move {self} to the middle of its range of motion and press ENTER....")
|
||||
input("Move robot to the middle of its range of motion and press ENTER....")
|
||||
homing_offsets = self.bus.set_half_turn_homings()
|
||||
|
||||
full_turn_motors = ["shoulder", "forearm_roll", "wrist_rotate"]
|
||||
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
||||
unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors]
|
||||
print(
|
||||
f"Move all joints except {full_turn_motors} sequentially through their "
|
||||
@@ -128,7 +113,7 @@ class WidowX(Teleoperator):
|
||||
for motor, m in self.bus.motors.items():
|
||||
self.calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=0,
|
||||
drive_mode=drive_modes[motor],
|
||||
homing_offset=homing_offsets[motor],
|
||||
range_min=range_mins[motor],
|
||||
range_max=range_maxes[motor],
|
||||
@@ -148,22 +133,6 @@ class WidowX(Teleoperator):
|
||||
self.bus.write("Secondary_ID", "shoulder_shadow", 2)
|
||||
self.bus.write("Secondary_ID", "elbow_shadow", 4)
|
||||
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
# TODO(pepijn): Re enable this
|
||||
# Use 'position control current based' for gripper to be limited by the limit of the current.
|
||||
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
|
||||
# its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
|
||||
# to make it move, and it will move back to its original target position when we release the force.
|
||||
# self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||
# Set gripper's goal pos in current position mode so that we can use it as a trigger.
|
||||
# self.bus.enable_torque("gripper")
|
||||
|
||||
if self.is_calibrated:
|
||||
self.bus.write("Goal_Position", "gripper", self.config.gripper_open_pos)
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
Reference in New Issue
Block a user