mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
feet(pi0/pi0.5): add pipeline (#2009)
* feat(processor): convert openpi model with processor * TODO: Make test works * fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests - Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`. - Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`. - Enhanced task handling in tests to ensure proper formatting and batch size consistency. - Cleaned up commented-out test code for clarity. * refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy - Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions. * refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration - Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions. - Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`. - Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter. - Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability. - Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility. * feat(processor): convert openpi model with processor * TODO: Make test works * fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests - Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`. - Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`. - Enhanced task handling in tests to ensure proper formatting and batch size consistency. - Cleaned up commented-out test code for clarity. * refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy - Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions. * refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration - Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions. - Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`. - Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter. - Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability. - Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility. * refactor(pi05): update imports and rename configuration classes - Changed imports to reflect the new naming convention for PI05 configuration and policy classes. - Renamed `PI05OpenPIConfig` to `PI05Config` and `PI05OpenPIPolicy` to `PI05Policy` for consistency. - Introduced a new processor file for PI05, implementing pre-processing and post-processing steps. - Updated tests to utilize the renamed classes, ensuring functionality and consistency across the codebase. * update(pi05): increase tokenizer_max_length for improved processing - Changed the `tokenizer_max_length` from 48 to 200 to enhance the model's capability in handling longer sequences. - This adjustment aims to improve the overall performance and flexibility of the PI05 configuration. * add default for state (max_state_dim) * correct naming * fix import * cleanup code * remove unused test * us quantiles for action * move to device * remove discrete state assert * fix pi05 test * move pi05 to device * use base models in comparison tests * small renames for tests * change number of tokens pi05 test * fix openpi tokenization in test * fix hub test * fix test * assert lerobot vs openpi tests --------- Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
@@ -268,6 +268,22 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0Config):
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
|
||||
|
||||
processors = make_pi0_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
processors = make_pi05_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
|
||||
|
||||
@@ -16,5 +16,6 @@
|
||||
|
||||
from .configuration_pi0 import PI0Config
|
||||
from .modeling_pi0 import PI0Policy
|
||||
from .processor_pi0 import make_pi0_pre_post_processors
|
||||
|
||||
__all__ = ["PI0Config", "PI0Policy"]
|
||||
__all__ = ["PI0Config", "PI0Policy", "make_pi0_pre_post_processors"]
|
||||
|
||||
@@ -24,15 +24,13 @@ from typing import Literal
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
@@ -50,7 +48,7 @@ def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (e
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
@@ -848,31 +846,15 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def __init__( # see lerobot pi0 `__init__`
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance.
|
||||
dataset_stats: Dataset statistics to be used for normalization.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Create tokenizer for language input
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Set max token length for tokenizer (from OpenPI)
|
||||
self.max_token_len = config.tokenizer_max_length
|
||||
|
||||
# Initialize the core PI0 model
|
||||
self.model = PI0Pytorch(config)
|
||||
|
||||
@@ -880,6 +862,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
if config.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
@@ -923,8 +907,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
# Initialize model without loading weights
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline
|
||||
model = cls(config, dataset_stats=dataset_stats, **kwargs)
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
@@ -962,10 +945,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
remap_count = 0
|
||||
|
||||
for key, value in fixed_state_dict.items():
|
||||
if not key.startswith("model.") and not any(
|
||||
key.startswith(prefix)
|
||||
for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."]
|
||||
):
|
||||
if not key.startswith("model."):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
@@ -1140,44 +1120,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def _tokenize_language(
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
|
||||
"""Tokenize language input using PaliGemma tokenizer."""
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Get task description
|
||||
if "task" in batch:
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
||||
tasks = tasks * batch_size
|
||||
else:
|
||||
# Default task if not provided
|
||||
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
||||
tasks = ["Pick up the object"] * batch_size
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = self.tokenizer(
|
||||
tasks,
|
||||
padding="max_length", # Use max_length padding as per OpenPI
|
||||
padding_side="right", # from lerobot pi0 `prepare_language`
|
||||
truncation=True,
|
||||
max_length=self.max_token_len, # Use the max token length from config
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy)
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
@@ -1206,11 +1148,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
# Sample actions using the model
|
||||
@@ -1220,17 +1160,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: # see lerobot pi0 `forward`
|
||||
"""Run the batch through the model and compute the loss for training."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Ensures that the task description string ends with a newline character.
|
||||
|
||||
This processing step is required for compatibility with the PaliGemma tokenizer,
|
||||
which expects a newline at the end of the text prompt. It handles both single
|
||||
strings and lists of strings for the 'task' key in complementary data.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""
|
||||
Adds a newline to the 'task' field if it doesn't already have one.
|
||||
|
||||
Args:
|
||||
complementary_data: A dictionary that may contain a 'task' key with a
|
||||
string or list of strings.
|
||||
|
||||
Returns:
|
||||
A new dictionary with the modified 'task' field.
|
||||
"""
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
|
||||
Args:
|
||||
features: The input feature dictionary.
|
||||
|
||||
Returns:
|
||||
The unchanged feature dictionary.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_pi0_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -16,5 +16,6 @@
|
||||
|
||||
from .configuration_pi05 import PI05Config
|
||||
from .modeling_pi05 import PI05Policy
|
||||
from .processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
__all__ = ["PI05Config", "PI05Policy"]
|
||||
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
|
||||
|
||||
@@ -24,7 +24,7 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@dataclass
|
||||
class PI05OpenPIConfig(PreTrainedConfig):
|
||||
class PI05Config(PreTrainedConfig):
|
||||
# Model architecture
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
@@ -56,12 +56,14 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
tokenizer_max_length: int = 200 # pi0.5=48, see openpi `__post_init__`
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY, # Images are normalized to [-1, 1] in preprocessing
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
"VISUAL": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for images
|
||||
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
|
||||
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -19,21 +19,18 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
@@ -53,7 +50,7 @@ def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (e
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
@@ -814,7 +811,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
|
||||
class PI05Policy(PreTrainedPolicy):
|
||||
"""PI05 OpenPI Policy for LeRobot."""
|
||||
"""PI05 Policy for LeRobot."""
|
||||
|
||||
config_class = PI05Config
|
||||
name = "pi05"
|
||||
@@ -822,31 +819,15 @@ class PI05Policy(PreTrainedPolicy):
|
||||
def __init__( # see lerobot pi0 `__init__`
|
||||
self,
|
||||
config: PI05Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance.
|
||||
dataset_stats: Dataset statistics to be used for normalization.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Create tokenizer for language input
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Set max token length for tokenizer (from OpenPI)
|
||||
self.max_token_len = config.tokenizer_max_length
|
||||
|
||||
# Initialize the core PI05 model
|
||||
self.model = PI05Pytorch(config)
|
||||
|
||||
@@ -854,6 +835,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
if config.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
@@ -897,8 +880,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Initialize model without loading weights
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline
|
||||
model = cls(config, dataset_stats=dataset_stats, **kwargs)
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
@@ -936,10 +918,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
remap_count = 0
|
||||
|
||||
for key, value in fixed_state_dict.items():
|
||||
if not key.startswith("model.") and not any(
|
||||
key.startswith(prefix)
|
||||
for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."]
|
||||
):
|
||||
if not key.startswith("model."):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
@@ -1118,63 +1097,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def _tokenize_language_and_state(
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
|
||||
"""Tokenize language input using PaliGemma tokenizer."""
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Get task description
|
||||
if "task" in batch:
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
||||
tasks = tasks * batch_size
|
||||
else:
|
||||
# Default task if not provided
|
||||
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
||||
tasks = ["Pick up the object"] * batch_size
|
||||
|
||||
# Handle discrete state input for PI05 (always the case for pi05)
|
||||
# Get state from batch and discretize it
|
||||
state: Any | None = batch.get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("Robot state is required for PI05")
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Create full prompts with state included (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
# Tokenize the full prompts with state
|
||||
tokenized = self.tokenizer(
|
||||
full_prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=self.max_token_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
tokens = tokenized["input_ids"].to(device)
|
||||
masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
return tokens, masks
|
||||
|
||||
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
@@ -1198,11 +1120,9 @@ class PI05Policy(PreTrainedPolicy):
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Sample actions using the model (no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||
@@ -1211,17 +1131,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: # see lerobot pi0 `forward`
|
||||
"""Run the batch through the model and compute the loss for training."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
||||
@dataclass
|
||||
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to prepare the state and tokenize the language input.
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("State is required for PI05")
|
||||
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_pre_post_processors(
|
||||
config: PI05Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -302,6 +302,65 @@ def clean_state_dict(
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_state_dict_with_missing_key_handling(
|
||||
policy: torch.nn.Module,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
policy_type: str,
|
||||
known_missing_keys_whitelist: dict[str, list[str]],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Load state dict into policy with graceful handling of missing keys.
|
||||
|
||||
This function loads the state dict with strict=False, filters out whitelisted
|
||||
missing keys, and provides detailed reporting about any issues found.
|
||||
|
||||
Args:
|
||||
policy: The policy model to load the state dict into.
|
||||
state_dict: The cleaned state dictionary to load.
|
||||
policy_type: The type of policy (used for whitelist lookup).
|
||||
known_missing_keys_whitelist: Dictionary mapping policy types to lists of
|
||||
known acceptable missing keys.
|
||||
|
||||
Returns:
|
||||
List of problematic missing keys that weren't in the whitelist.
|
||||
"""
|
||||
# Load the cleaned state dict with strict=False to capture missing/unexpected keys
|
||||
load_result = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Check for missing keys
|
||||
missing_keys = load_result.missing_keys
|
||||
unexpected_keys = load_result.unexpected_keys
|
||||
|
||||
# Filter out whitelisted missing keys
|
||||
policy_type_lower = policy_type.lower()
|
||||
whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, [])
|
||||
problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys]
|
||||
|
||||
if missing_keys:
|
||||
if problematic_missing_keys:
|
||||
print(f"⚠️ WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:")
|
||||
for key in problematic_missing_keys:
|
||||
print(f" - {key}")
|
||||
|
||||
if len(missing_keys) > len(problematic_missing_keys):
|
||||
whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys]
|
||||
print(f"ℹ️ INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):")
|
||||
for key in whitelisted_missing:
|
||||
print(f" - {key}")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"⚠️ WARNING: Found {len(unexpected_keys)} unexpected keys:")
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✅ Successfully loaded cleaned state dict into policy model (all keys matched)")
|
||||
else:
|
||||
print("⚠️ State dict loaded with some missing/unexpected keys (see details above)")
|
||||
|
||||
return problematic_missing_keys
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
|
||||
@@ -335,9 +394,45 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
||||
return converted_features
|
||||
|
||||
|
||||
def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None:
|
||||
"""
|
||||
Display final migration summary with warnings about problematic missing keys.
|
||||
|
||||
Args:
|
||||
problematic_missing_keys: List of missing keys that weren't in the whitelist.
|
||||
"""
|
||||
if not problematic_missing_keys:
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🚨 IMPORTANT: MIGRATION COMPLETED WITH WARNINGS")
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:"
|
||||
)
|
||||
print()
|
||||
for key in problematic_missing_keys:
|
||||
print(f" ❌ {key}")
|
||||
print()
|
||||
print("These missing keys may indicate:")
|
||||
print(" • The model architecture has changed")
|
||||
print(" • Some components were not properly saved in the original model")
|
||||
print(" • The migration script needs to be updated for this policy type")
|
||||
print()
|
||||
print("What to do next:")
|
||||
print(" 1. Test your migrated model carefully to ensure it works as expected")
|
||||
print(" 2. If you encounter issues, please open an issue at:")
|
||||
print(" https://github.com/huggingface/lerobot/issues")
|
||||
print(" 3. Include this migration log and the missing keys listed above")
|
||||
print()
|
||||
print("If the model works correctly despite these warnings, the missing keys")
|
||||
print("might be expected for your policy type and can be added to the whitelist.")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def load_model_from_hub(
|
||||
repo_id: str, revision: str | None = None
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]:
|
||||
"""
|
||||
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
|
||||
|
||||
@@ -347,13 +442,12 @@ def load_model_from_hub(
|
||||
|
||||
Returns:
|
||||
A tuple containing the model's state dictionary, the policy configuration,
|
||||
and the training configuration.
|
||||
and the training configuration (None if train_config.json is not found).
|
||||
"""
|
||||
# Download files.
|
||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||
|
||||
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
||||
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||
|
||||
# Load state_dict
|
||||
state_dict = load_safetensors(safetensors_path)
|
||||
@@ -362,8 +456,14 @@ def load_model_from_hub(
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
# Try to load train_config (optional)
|
||||
train_config = None
|
||||
try:
|
||||
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
except FileNotFoundError:
|
||||
print("train_config.json not found - continuing without training configuration")
|
||||
|
||||
return state_dict, config, train_config
|
||||
|
||||
@@ -409,8 +509,15 @@ def main():
|
||||
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
|
||||
with open(os.path.join(args.pretrained_path, "config.json")) as f:
|
||||
config = json.load(f)
|
||||
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
|
||||
train_config = json.load(f)
|
||||
|
||||
# Try to load train_config (optional)
|
||||
train_config = None
|
||||
train_config_path = os.path.join(args.pretrained_path, "train_config.json")
|
||||
if os.path.exists(train_config_path):
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
else:
|
||||
print("train_config.json not found - continuing without training configuration")
|
||||
else:
|
||||
# Hub repository
|
||||
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
|
||||
@@ -487,10 +594,20 @@ def main():
|
||||
policy_class = get_policy_class(policy_type)
|
||||
policy = policy_class(policy_config)
|
||||
|
||||
# Load the cleaned state dict
|
||||
policy.load_state_dict(new_state_dict, strict=True)
|
||||
print("Successfully loaded cleaned state dict into policy model")
|
||||
# Define whitelist of known missing keys that are acceptable (for example weight tie) for certain policy types
|
||||
known_missing_keys_whitelist = {
|
||||
"pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"],
|
||||
# Add other policy types and their known missing keys here as needed
|
||||
}
|
||||
|
||||
# Load state dict with graceful missing key handling
|
||||
problematic_missing_keys = load_state_dict_with_missing_key_handling(
|
||||
policy=policy,
|
||||
state_dict=new_state_dict,
|
||||
policy_type=policy_type,
|
||||
known_missing_keys_whitelist=known_missing_keys_whitelist,
|
||||
)
|
||||
policy.to(torch.float32)
|
||||
# Create preprocessor and postprocessor using the factory
|
||||
print("Creating preprocessor and postprocessor using make_pre_post_processors...")
|
||||
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
|
||||
@@ -520,7 +637,9 @@ def main():
|
||||
# Generate and save model card
|
||||
print("Generating model card...")
|
||||
# Get metadata from original config
|
||||
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||
dataset_repo_id = "unknown"
|
||||
if train_config is not None:
|
||||
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||
license = config.get("license", "apache-2.0")
|
||||
|
||||
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
||||
@@ -641,6 +760,9 @@ final_action = postprocessor(action)
|
||||
else:
|
||||
print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
|
||||
|
||||
# Display final summary about any problematic missing keys
|
||||
display_migration_summary_with_warnings(problematic_missing_keys)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -7,26 +7,28 @@ import os
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi05 import ( # noqa: E402
|
||||
PI05Config,
|
||||
PI05Policy,
|
||||
make_pi05_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_model_architecture():
|
||||
"""Test that pi05=True creates the correct model architecture."""
|
||||
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
config = PI05Config(
|
||||
max_action_dim=7,
|
||||
max_state_dim=14,
|
||||
dtype="float32",
|
||||
)
|
||||
set_seed(42)
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -52,9 +54,6 @@ def test_pi05_model_architecture():
|
||||
assert config.tokenizer_max_length == 200, (
|
||||
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
|
||||
)
|
||||
assert config.discrete_state_input == True, ( # noqa: E712
|
||||
f"Expected discrete_state_input=True for pi05, got {config.discrete_state_input}"
|
||||
)
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
@@ -73,7 +72,35 @@ def test_pi05_model_architecture():
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI05Policy(config, dataset_stats)
|
||||
policy = PI05Policy(config)
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
device = config.device
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
), # Use rand for [0,1] range
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
except Exception as e:
|
||||
print(f"Forward pass failed: {e}")
|
||||
raise
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
action = postprocessor(action)
|
||||
print(f"Action: {action}")
|
||||
print(f"Action prediction successful. Action shape: {action.shape}")
|
||||
except Exception as e:
|
||||
print(f"Action prediction failed: {e}")
|
||||
raise
|
||||
|
||||
# Verify pi05 model components exist
|
||||
# Check that time_mlp layers exist (for AdaRMS conditioning)
|
||||
@@ -100,88 +127,18 @@ def test_pi05_model_architecture():
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_forward_pass():
|
||||
"""Test forward pass with"""
|
||||
|
||||
# Create config
|
||||
config = PI05Config(
|
||||
max_action_dim=7,
|
||||
max_state_dim=14,
|
||||
dtype="float32",
|
||||
chunk_size=16, # Shorter chunk_size for testing
|
||||
n_action_steps=16, # Shorter action steps for testing
|
||||
)
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(14,),
|
||||
),
|
||||
"observation.images.base_0_rgb": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
),
|
||||
}
|
||||
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(7,),
|
||||
),
|
||||
}
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI05Policy(config, dataset_stats)
|
||||
|
||||
# Create test batch
|
||||
batch_size = 2
|
||||
device = next(policy.parameters()).device
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
# Test forward pass
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
assert not torch.isnan(loss), "Loss is NaN"
|
||||
assert loss.item() >= 0, "Loss should be non-negative"
|
||||
config = make_policy_config(
|
||||
policy_type="pi0",
|
||||
max_action_dim=7,
|
||||
max_state_dim=14,
|
||||
)
|
||||
print("Config created successfully through factory")
|
||||
print(f" Config type: {type(config).__name__}")
|
||||
print(f" PaliGemma variant: {config.paligemma_variant}")
|
||||
print(f" Action expert variant: {config.action_expert_variant}")
|
||||
except Exception as e:
|
||||
print(f"Forward pass failed: {e}")
|
||||
raise
|
||||
|
||||
# Test action prediction
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"Action prediction successful. Action shape: {action.shape}")
|
||||
# When batch_size > 1, select_action returns (batch_size, action_dim)
|
||||
assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}"
|
||||
assert not torch.isnan(action).any(), "Action contains NaN values"
|
||||
except Exception as e:
|
||||
print(f"Action prediction failed: {e}")
|
||||
print(f"Config creation failed: {e}")
|
||||
raise
|
||||
|
||||
@@ -0,0 +1,419 @@
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
"base_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
"left_wrist_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
"right_wrist_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI05BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
precision: str = "float32"
|
||||
pi05: bool = True
|
||||
dtype: str = "float32"
|
||||
|
||||
|
||||
def instantiate_lerobot_pi05(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI05Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI05Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi05_base", strict=True)
|
||||
else:
|
||||
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI05Policy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
|
||||
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
|
||||
config = PI05BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi05_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="pepijn223/pi05_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
|
||||
|
||||
class PI05Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description (pi05 processor handles all text formatting)
|
||||
tasks = batch.get("task", ["Pick up the object"] * batch_size)
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks] * batch_size
|
||||
elif len(tasks) == 1:
|
||||
tasks = tasks * batch_size
|
||||
|
||||
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
|
||||
state = batch["observation.state"]
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
|
||||
state = pad_vector(state, DUMMY_STATE_DIM)
|
||||
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Create pi05-formatted prompts that include state information
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
full_prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi05_original_vs_lerobot():
|
||||
"""Test PI05 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi05.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
@@ -14,13 +14,19 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0 import ( # noqa: E402
|
||||
PI0Config,
|
||||
PI0Policy,
|
||||
make_pi0_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
@@ -61,11 +67,11 @@ def test_policy_instantiation():
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0Policy(config, dataset_stats)
|
||||
|
||||
policy = PI0Policy(config)
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
device = policy.device if hasattr(policy, "device") else "cpu"
|
||||
device = config.device
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
@@ -74,7 +80,7 @@ def test_policy_instantiation():
|
||||
), # Use rand for [0,1] range
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
batch = preprocessor(batch)
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
@@ -85,6 +91,8 @@ def test_policy_instantiation():
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
action = postprocessor(action)
|
||||
print(f"Action: {action}")
|
||||
print(f"Action prediction successful. Action shape: {action.shape}")
|
||||
except Exception as e:
|
||||
print(f"Action prediction failed: {e}")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -22,7 +24,10 @@ from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
@@ -33,23 +38,33 @@ DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
"base_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
"left_wrist_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
"right_wrist_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -65,27 +80,26 @@ class PI0BaseOriginalConfig:
|
||||
dtype: str = "float32"
|
||||
|
||||
|
||||
def instantiate_lerobot_pi0(from_pretrained: bool = False):
|
||||
def instantiate_lerobot_pi0(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI0Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI0Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True)
|
||||
# Then reinitialize the normalization with proper stats
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
||||
)
|
||||
policy = PI0Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi0_base", strict=True)
|
||||
else:
|
||||
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI0Policy(config, DUMMY_DATASET_STATS)
|
||||
policy = PI0Policy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
return policy
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
|
||||
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
|
||||
@@ -94,7 +108,7 @@ def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = No
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi0_base_fp32)...")
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi0_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
@@ -105,7 +119,7 @@ def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = No
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="pepijn223/pi0_base_fp32", repo_type="model")
|
||||
cache_dir = snapshot_download(repo_id="pepijn223/pi0_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
@@ -178,7 +192,7 @@ def create_dummy_data():
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt],
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
@@ -232,13 +246,22 @@ def create_original_observation_with_openpi_preprocessing(batch):
|
||||
if "task" in batch:
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
# Single string: add newline if not present, then convert to list
|
||||
if not tasks.endswith("\n"):
|
||||
tasks = f"{tasks}\n"
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
tasks = tasks * batch_size
|
||||
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
|
||||
# List of strings: add newline to each if not present
|
||||
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
|
||||
if len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
tasks = tasks * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
else:
|
||||
# Default task if not provided
|
||||
tasks = ["Pick up the object"] * batch_size
|
||||
tasks = ["Pick up the object\n"] * batch_size
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
@@ -324,16 +347,19 @@ def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model
|
||||
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test 1: Each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTEST 1: Each model with its own preprocessing")
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
@@ -353,16 +379,24 @@ def test_pi0_original_vs_lerobot():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi0.predict_action_chunk(batch)
|
||||
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
|
||||
@@ -371,29 +405,6 @@ def test_pi0_original_vs_lerobot():
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
|
||||
# Test 2: Both models with LeRobot preprocessing (isolates model differences)
|
||||
print("\nTEST 2: Both models with LeRobot preprocessing (model comparison)")
|
||||
print("Creating observation for OpenPI using LeRobot's preprocessing...")
|
||||
pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch)
|
||||
|
||||
print("Testing OpenPI with LeRobot preprocessing...")
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
with torch.no_grad():
|
||||
openpi_actions_lerobot_preproc = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_lerobot, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
print(f"OpenPI (LeRobot preprocessing) Actions shape: {openpi_actions_lerobot_preproc.shape}")
|
||||
print(f"OpenPI (LeRobot preprocessing) Actions mean: {openpi_actions_lerobot_preproc.mean().item():.6f}")
|
||||
print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}")
|
||||
|
||||
print("\nComparing models with same preprocessing:")
|
||||
is_close_1e4 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)
|
||||
is_close_1e2 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)
|
||||
max_diff = torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item()
|
||||
|
||||
print(f"Actions close (atol=1e-4): {is_close_1e4}")
|
||||
print(f"Actions close (atol=1e-2): {is_close_1e2}")
|
||||
print(f"Max absolute difference: {max_diff:.6f}")
|
||||
|
||||
# Add assertions for pytest
|
||||
assert is_close_1e2, f"Models should produce similar results (atol=1e-2), max diff: {max_diff}"
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
@@ -19,7 +19,9 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
|
||||
|
||||
def create_dummy_stats(config):
|
||||
@@ -48,13 +50,11 @@ def create_dummy_stats(config):
|
||||
# Test data for all 6 base models
|
||||
MODEL_TEST_PARAMS = [
|
||||
# PI0 models
|
||||
("pepijn223/pi0_base_fp32", "PI0", PI0Policy),
|
||||
("pepijn223/pi0_droid_fp32", "PI0", PI0Policy),
|
||||
("pepijn223/pi0_libero_fp32", "PI0", PI0Policy),
|
||||
("pepijn223/pi0_base", "PI0", PI0Policy),
|
||||
("pepijn223/pi0_libero", "PI0", PI0Policy),
|
||||
# PI0.5 models
|
||||
("pepijn223/pi05_base_fp32", "PI0.5", PI05Policy),
|
||||
("pepijn223/pi05_droid_fp32", "PI0.5", PI05Policy),
|
||||
("pepijn223/pi05_libero_fp32", "PI0.5", PI05Policy),
|
||||
("pepijn223/pi05_base", "PI0.5", PI05Policy),
|
||||
("pepijn223/pi05_libero", "PI0.5", PI05Policy),
|
||||
]
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
"""Test loading and basic functionality of all 6 base models from HuggingFace Hub.
|
||||
|
||||
Args:
|
||||
model_id: HuggingFace model ID (e.g., "pepijn223/pi0_base_fp32")
|
||||
model_id: HuggingFace model ID (e.g., "pepijn223/pi0_base")
|
||||
model_type: Model type ("PI0" or "PI0.5")
|
||||
policy_class: Policy class to use (PI0Policy or PI05Policy)
|
||||
"""
|
||||
@@ -79,37 +79,11 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
print(f"✗ Failed to load model {model_id}: {e}")
|
||||
raise
|
||||
|
||||
# Set up input_features and output_features in the config (not set by from_pretrained)
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
policy.config.input_features = {
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(policy.config.max_state_dim,),
|
||||
),
|
||||
"observation.images.base_0_rgb": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
),
|
||||
}
|
||||
|
||||
policy.config.output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(policy.config.max_action_dim,),
|
||||
),
|
||||
}
|
||||
|
||||
# Get model info
|
||||
device = next(policy.parameters()).device
|
||||
|
||||
# Set device for policy config
|
||||
policy.config.device = device
|
||||
print("\nModel configuration:")
|
||||
print(f" - Model ID: {model_id}")
|
||||
print(f" - Model type: {model_type}")
|
||||
@@ -124,7 +98,6 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
|
||||
# Verify model-specific architecture
|
||||
if model_type == "PI0.5":
|
||||
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
|
||||
# Verify PI0.5 specific features
|
||||
assert hasattr(policy.model, "time_mlp_in"), f"{model_id}: PI0.5 should have time_mlp_in"
|
||||
assert hasattr(policy.model, "time_mlp_out"), f"{model_id}: PI0.5 should have time_mlp_out"
|
||||
@@ -155,18 +128,15 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
"std": stats["std"].to(device),
|
||||
}
|
||||
|
||||
# Initialize normalization layers with dummy stats
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
# Create processor pipeline based on model type
|
||||
if model_type == "PI0.5":
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||
config=policy.config, dataset_stats=dummy_stats
|
||||
)
|
||||
else: # PI0
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config=policy.config, dataset_stats=dummy_stats
|
||||
)
|
||||
|
||||
# Create test batch
|
||||
batch_size = 1
|
||||
@@ -188,11 +158,14 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
for key in policy.config.image_features.keys():
|
||||
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
||||
|
||||
# Process batch with pipeline
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
print(f"\nTesting forward pass for {model_id}...")
|
||||
try:
|
||||
policy.train()
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
loss, loss_dict = policy.forward(processed_batch)
|
||||
assert not torch.isnan(loss), f"{model_id}: Forward pass produced NaN loss"
|
||||
assert loss.item() >= 0, f"{model_id}: Loss should be non-negative"
|
||||
print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
|
||||
@@ -205,11 +178,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
try:
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
expected_shape = (batch_size, policy.config.max_action_dim)
|
||||
assert action.shape == expected_shape, (
|
||||
f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
|
||||
)
|
||||
action = policy.predict_action_chunk(processed_batch)
|
||||
assert not torch.isnan(action).any(), f"{model_id}: Action contains NaN values"
|
||||
print(f"✓ Action prediction successful - Shape: {action.shape}")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user