mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
add annotation pipeline
This commit is contained in:
@@ -58,6 +58,7 @@ from lerobot.datasets.utils import (
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
load_tasks_high_level,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
@@ -162,6 +163,7 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.tasks_high_level = load_tasks_high_level(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -1060,6 +1062,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
|
||||
# optionally add high level task index
|
||||
if "task_index_high_level" in self.features:
|
||||
high_level_task_idx = item["task_index_high_level"].item()
|
||||
item["robot_utterance"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["robot_utterance"]
|
||||
item["user_prompt"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["user_prompt"]
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -61,6 +61,7 @@ VIDEO_DIR = "videos"
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_TASKS_HIGH_LEVEL_PATH = "meta/tasks_high_level.parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
@@ -352,6 +353,9 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
return tasks
|
||||
|
||||
def load_tasks_high_level(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_HIGH_LEVEL_PATH)
|
||||
return tasks
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
# π₀.₅ (pi05)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action model with open-world generalization**.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False (Uses `state_proj` layer) | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
|
||||
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
|
||||
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
|
||||
year = {2025},
|
||||
eprint = {2504.16054},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2504.16054},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi05 import PI05Config
|
||||
from .modeling_pi05 import PI05Policy
|
||||
from .processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
print(batch.keys())
|
||||
print(batch['task_index_high_level'].shape)
|
||||
print(batch['task_index_high_level'])
|
||||
print(batch['user_prompt'][0])
|
||||
print(batch['robot_utterance'][0])
|
||||
print(batch['task'][0])
|
||||
breakpoint()
|
||||
@@ -0,0 +1,42 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="jadechoghari/collect-data"
|
||||
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen"
|
||||
BATCH_SIZE=32
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# run synthetic data generation (all episodes processed)
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--repo-id "$REPO_ID" \
|
||||
--model "$MODEL" \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--temperature "$TEMPERATURE" \
|
||||
--batch-size "$BATCH_SIZE" \
|
||||
--sample-interval "$SAMPLE_INTERVAL" \
|
||||
--image-key observation.images.base \
|
||||
--num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --video-mode \
|
||||
# --video-key observation.images.up \
|
||||
# --video-batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval 1.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@dataclass
|
||||
class PI05Config(PreTrainedConfig):
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
|
||||
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
|
||||
}
|
||||
)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = False # Freeze only the vision encoder
|
||||
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution), # Use configured image resolution
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
||||
@dataclass
|
||||
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to prepare the state and tokenize the language input.
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("State is required for PI05")
|
||||
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_pre_post_processors(
|
||||
config: PI05Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
Reference in New Issue
Block a user