mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 09:07:03 +00:00
refactor(pi05): update imports and rename configuration classes
- Changed imports to reflect the new naming convention for PI05 configuration and policy classes. - Renamed `PI05OpenPIConfig` to `PI05Config` and `PI05OpenPIPolicy` to `PI05Policy` for consistency. - Introduced a new processor file for PI05, implementing pre-processing and post-processing steps. - Updated tests to utilize the renamed classes, ensuring functionality and consistency across the codebase.
This commit is contained in:
@@ -87,11 +87,11 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.policies.pi0.modeling_pi0openpi import PI0Policy
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "pi05":
|
||||
from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "sac":
|
||||
@@ -152,7 +152,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
elif policy_type == "pi0_openpi":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05_openpi":
|
||||
return PI05OpenPIConfig(**kwargs)
|
||||
return PI05Config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
@@ -280,10 +280,10 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05OpenPIConfig):
|
||||
from lerobot.policies.pi05.processor_pi05openpi import make_pi05_openpi_pre_post_processors
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
processors = make_pi05_openpi_pre_post_processors(
|
||||
processors = make_pi05_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi0openpi import PI0Config
|
||||
from .modeling_pi0openpi import PI0Policy
|
||||
from .configuration_pi0 import PI0Config
|
||||
from .modeling_pi0 import PI0Policy
|
||||
from .processor_pi0_openpi import make_pi0_pre_post_processors
|
||||
|
||||
__all__ = ["PI0Config", "PI0Policy", "make_pi0_openpi_pre_post_processors"]
|
||||
__all__ = ["PI0Config", "PI0Policy", "make_pi0_pre_post_processors"]
|
||||
|
||||
@@ -31,7 +31,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0openpi import PI0Config
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0openpi import PI0Config
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
|
||||
@@ -16,5 +16,6 @@
|
||||
|
||||
from .configuration_pi05 import PI05Config
|
||||
from .modeling_pi05 import PI05Policy
|
||||
from .processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
__all__ = ["PI05Config", "PI05Policy"]
|
||||
|
||||
@@ -24,7 +24,7 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@dataclass
|
||||
class PI05OpenPIConfig(PreTrainedConfig):
|
||||
class PI05Config(PreTrainedConfig):
|
||||
# Model architecture
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
|
||||
@@ -31,7 +31,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
|
||||
@@ -492,7 +492,7 @@ class PaliGemmaWithExpertModel(
|
||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI05 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI05OpenPIConfig):
|
||||
def __init__(self, config: PI05Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -813,15 +813,15 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
|
||||
class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
class PI05Policy(PreTrainedPolicy):
|
||||
"""PI05 OpenPI Policy for LeRobot."""
|
||||
|
||||
config_class = PI05OpenPIConfig
|
||||
config_class = PI05Config
|
||||
name = "pi05"
|
||||
|
||||
def __init__( # see lerobot pi0 `__init__`
|
||||
self,
|
||||
config: PI05OpenPIConfig,
|
||||
config: PI05Config,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -858,7 +858,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
) -> T:
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
"⚠️ DISCLAIMER: The PI05 model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
" This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
|
||||
+4
-4
@@ -7,8 +7,8 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pi05.modeling_pi05openpi import pad_vector
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
@@ -77,8 +77,8 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_openpi_pre_post_processors(
|
||||
config: PI05OpenPIConfig,
|
||||
def make_pi05_pre_post_processors(
|
||||
config: PI05Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -23,15 +24,16 @@ from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05 import PI05OpenPIConfig, PI05OpenPIPolicy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05openpi import make_pi05_openpi_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
|
||||
DUMMY_MAX_TOKEN_LEN = 200 # Default for PI0 (non-pi05)
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
@@ -83,30 +85,26 @@ class PI0BaseOriginalConfig:
|
||||
def instantiate_lerobot_pi0(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI05OpenPIPolicy,
|
||||
PI05Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI05OpenPIPolicy.from_pretrained(
|
||||
pretrained_name_or_path="pepijn223/pi05_base_fp32", strict=True
|
||||
)
|
||||
policy = PI05Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi05_base_fp32", strict=True)
|
||||
else:
|
||||
config = PI05OpenPIConfig(
|
||||
max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
|
||||
)
|
||||
policy = PI05OpenPIPolicy(config)
|
||||
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI05Policy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi05_openpi_pre_post_processors(
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
|
||||
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str | None = None):
|
||||
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str | None = None) -> PI0Pytorch:
|
||||
config = PI0BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
@@ -201,21 +199,6 @@ def create_dummy_data():
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
|
||||
|
||||
class PI0Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
|
||||
@@ -238,10 +221,34 @@ class PI0Observation:
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
# if state is not None:
|
||||
# # This is the Pi05 format, where the state is part of the discrete language input.
|
||||
# discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
# state_str = " ".join(map(str, discretized_state))
|
||||
# full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
# tokens = self._tokenizer.encode(full_prompt, add_bos=True)
|
||||
|
||||
|
||||
def encode_with_state(state: torch.Tensor, prompt: list[str], max_state_dim: int = 32) -> list[str]:
|
||||
state = deepcopy(state)
|
||||
state = pad_vector(state, max_state_dim)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_state = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
encoded_with_state = []
|
||||
for i, task in enumerate(prompt):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_state[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
encoded_with_state.append(full_prompt)
|
||||
return encoded_with_state
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
state = batch["observation.state"]
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
@@ -251,12 +258,9 @@ def create_original_observation_with_openpi_preprocessing(batch):
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
# Single string: add newline if not present, then convert to list
|
||||
if not tasks.endswith("\n"):
|
||||
tasks = f"{tasks}\n"
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
|
||||
# List of strings: add newline to each if not present
|
||||
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
|
||||
if len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
tasks = tasks * batch_size
|
||||
@@ -265,8 +269,8 @@ def create_original_observation_with_openpi_preprocessing(batch):
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
else:
|
||||
# Default task if not provided
|
||||
tasks = ["Pick up the object\n"] * batch_size
|
||||
|
||||
tasks = ["Pick up the object"] * batch_size
|
||||
tasks = encode_with_state(state=state, prompt=tasks)
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
tasks,
|
||||
@@ -313,41 +317,6 @@ def create_original_observation_with_openpi_preprocessing(batch):
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
@@ -408,30 +377,3 @@ def test_pi0_original_vs_lerobot():
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
|
||||
# # Test 2: Both models with LeRobot preprocessing (isolates model differences)
|
||||
# print("\nTEST 2: Both models with LeRobot preprocessing (model comparison)")
|
||||
# print("Creating observation for OpenPI using LeRobot's preprocessing...")
|
||||
# pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch)
|
||||
|
||||
# print("Testing OpenPI with LeRobot preprocessing...")
|
||||
# torch.manual_seed(42) # Set seed for reproducibility
|
||||
# with torch.no_grad():
|
||||
# openpi_actions_lerobot_preproc = original_pi0.sample_actions(
|
||||
# device=DEVICE, observation=pi0_obs_lerobot, noise=fixed_noise, num_steps=10
|
||||
# )
|
||||
# print(f"OpenPI (LeRobot preprocessing) Actions shape: {openpi_actions_lerobot_preproc.shape}")
|
||||
# print(f"OpenPI (LeRobot preprocessing) Actions mean: {openpi_actions_lerobot_preproc.mean().item():.6f}")
|
||||
# print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}")
|
||||
|
||||
# print("\nComparing models with same preprocessing:")
|
||||
# is_close_1e4 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)
|
||||
# is_close_1e2 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)
|
||||
# max_diff = torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item()
|
||||
|
||||
# print(f"Actions close (atol=1e-4): {is_close_1e4}")
|
||||
# print(f"Actions close (atol=1e-2): {is_close_1e2}")
|
||||
# print(f"Max absolute difference: {max_diff:.6f}")
|
||||
|
||||
# # Add assertions for pytest
|
||||
# assert is_close_1e2, f"Models should produce similar results (atol=1e-2), max diff: {max_diff}"
|
||||
|
||||
Reference in New Issue
Block a user