Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-15 16:49:55 +00:00)
Compare commits
21 Commits
| SHA1 |
| --- |
| 0bda187268 |
| 59b33c0ea3 |
| 4419901e6b |
| 3f3a159cff |
| 6deebf1e47 |
| b9cb947bd2 |
| 481a956100 |
| ffde29be49 |
| 2504d00707 |
| 11cefed08a |
| 7bfedd1388 |
| 8c95a71c94 |
| 1d048c7e2b |
| 419305a4c2 |
| 753b996cda |
| 099f3ba4d7 |
| 3f3d08e5a8 |
| 9e1a67c862 |
| 54c38627bd |
| f0ef3717ca |
| bd8e1ccf70 |
```diff
@@ -61,6 +61,7 @@ jobs:
       MUJOCO_GL: egl
       HF_HOME: /mnt/cache/.cache/huggingface
       HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     steps:
       - uses: actions/checkout@v6
         with:
@@ -89,5 +90,10 @@ jobs:
       - name: Install lerobot with test extras
         run: uv sync --extra "test"

+      - name: Login to Hugging Face
+        run: |
+          uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          uv run hf auth whoami
+
       - name: Run pytest
         run: uv run pytest tests -vv --maxfail=10
```
```diff
@@ -60,6 +60,7 @@ jobs:
       MUJOCO_GL: egl
       HF_HOME: /mnt/cache/.cache/huggingface
       HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     steps:
       - uses: actions/checkout@v6
         with:
@@ -87,6 +88,11 @@ jobs:
       - name: Install lerobot with all extras
         run: uv sync --extra all # TODO(Steven): Make flash-attn optional

+      - name: Login to Hugging Face
+        run: |
+          uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          uv run hf auth whoami
+
       - name: Run pytest (all extras)
         run: uv run pytest tests -vv --maxfail=10

@@ -162,6 +168,7 @@ jobs:
       HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
       TORCH_HOME: /home/user_lerobot/.cache/torch
       TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     container:
       image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
       options: --gpus all --shm-size "16gb"
@@ -173,6 +180,10 @@ jobs:
         shell: bash
         working-directory: /lerobot
     steps:
+      - name: Login to Hugging Face
+        run: |
+          hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          hf auth whoami
      - name: Run pytest on GPU
        run: pytest tests -vv --maxfail=10
      - name: Run end-to-end tests
```
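The login steps added above use the `hf` CLI inside CI. For local debugging, a rough programmatic equivalent with `huggingface_hub` looks like the following (a hedged sketch; `add_to_git_credential` has shifted between huggingface-hub releases, so adjust to the version pinned later in this change set):

```python
import os

from huggingface_hub import login, whoami

# Mirror of the CI step: authenticate with the token exposed via the environment,
# then confirm the identity, which is what `hf auth whoami` prints.
login(token=os.environ["HF_USER_TOKEN"], add_to_git_credential=True)
print(whoami()["name"])
```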
+10 -10

```diff
@@ -52,7 +52,7 @@ This approach can transform **any existing VLM** into a VLA by training it to pr

 You have two options for the FAST tokenizer:

-1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
+1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.

 2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.

@@ -114,15 +114,15 @@ lerobot-train \

 ### Key Training Parameters

 | Parameter | Description | Default |
-| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
+| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
 | `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
 | `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
 | `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
 | `--policy.n_action_steps` | Number of action steps to execute | `50` |
 | `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
-| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
+| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
 | `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |

 ## Inference

```
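As a usage illustration of the parameters documented above, here is a hedged Python sketch of loading the FAST action tokenizer and round-tripping an action chunk. The repo id `lerobot/fast-action-tokenizer` comes from the updated docs; the `trust_remote_code` flag and the `decode` signature are assumptions carried over from the upstream `physical-intelligence/fast` tokenizer, not confirmed by this diff.

```python
import numpy as np
from transformers import AutoProcessor

# FAST action tokenizer referenced by --policy.action_tokenizer_name in the table above.
tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True)

# Dummy action chunk: (batch, chunk_size, action_dim); chunk_size=50 matches the default above.
actions = np.random.uniform(-1.0, 1.0, size=(1, 50, 6)).astype(np.float32)

tokens = tokenizer(actions)  # encode continuous actions into discrete FAST tokens
decoded = tokenizer.decode(tokens, time_horizon=50, action_dim=6)  # back to (1, 50, 6)
print(len(tokens[0]), decoded.shape)
```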
+11 -93

```diff
@@ -61,7 +61,7 @@ dependencies = [
     # Hugging Face dependencies
     "datasets>=4.0.0,<5.0.0",
     "diffusers>=0.27.2,<0.36.0",
-    "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
+    "huggingface-hub[cli]>=1.0.0,<2.0.0",
     "accelerate>=1.10.0,<2.0.0",

     # Core dependencies
@@ -96,7 +96,7 @@ dependencies = [
 # Common
 pygame-dep = ["pygame>=2.5.1,<2.7.0"]
 placo-dep = ["placo>=0.9.6,<0.10.0"]
-transformers-dep = ["transformers>=4.57.1,<5.0.0"]
+transformers-dep = ["transformers>=5.1.0,<6.0.0"]
 grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
 can-dep = ["python-can>=4.2.0,<5.0.0"]

@@ -129,13 +129,13 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]

 # Policies
 wallx = [
-    "transformers==4.49.0",
-    "peft==0.17.1",
-    "scipy==1.15.3",
-    "torchdiffeq==0.2.5",
-    "qwen_vl_utils==0.0.11"
+    "lerobot[transformers-dep]",
+    "peft>=0.18.0,<1.0.0",
+    "scipy==1.15.3", # TODO: Relax version
+    "torchdiffeq==0.2.5", # TODO: Relax version
+    "qwen-vl-utils==0.0.11" # TODO: Relax version
 ]
-pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
+pi = ["lerobot[transformers-dep]", "scipy==1.15.3"] # TODO: Relax scipy version
 smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
 groot = [
     "lerobot[transformers-dep]",
@@ -148,7 +148,7 @@ groot = [
     "ninja>=1.11.1,<2.0.0",
     "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
 ]
-sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
+sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.11,<0.1.0"]
 xvla = ["lerobot[transformers-dep]"]
 hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]

@@ -176,8 +176,8 @@ all = [
     "lerobot[reachy2]",
     "lerobot[kinematics]",
     "lerobot[intelrealsense]",
-    # "lerobot[wallx]",
-    # "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
+    "lerobot[wallx]",
+    "lerobot[pi]",
     "lerobot[smolvla]",
     # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
     "lerobot[xvla]",
@@ -394,85 +394,3 @@ ignore_errors = false
 # [[tool.mypy.overrides]]
 # module = "lerobot.scripts.*"
 # ignore_errors = false
-
-[tool.uv]
-# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
-conflicts = [
-    [
-        { extra = "wallx" },
-        { extra = "transformers-dep" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "pi" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "smolvla" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "groot" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "xvla" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "sarm" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "hilserl" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "libero" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "peft" },
-    ],
-    [
-        { extra = "wallx" },
-        { extra = "all" },
-    ],
-    # pi uses custom branch which conflicts with transformers-dep
-    [
-        { extra = "pi" },
-        { extra = "transformers-dep" },
-    ],
-    [
-        { extra = "pi" },
-        { extra = "smolvla" },
-    ],
-    [
-        { extra = "pi" },
-        { extra = "groot" },
-    ],
-    [
-        { extra = "pi" },
-        { extra = "xvla" },
-    ],
-    [
-        { extra = "pi" },
-        { extra = "sarm" },
-    ],
-    [
-        { extra = "pi" },
-        { extra = "hilserl" },
-    ],
-    [
-        { extra = "pi" },
-        { extra = "libero" },
-    ],
-    [
-        { extra = "pi" },
-        { extra = "peft" },
-    ],
-    [
-        { extra = "pi" },
-        { extra = "all" },
-    ],
-]
```
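The through-line of the dependency changes is the transformers major-version bump: `transformers-dep` now requires 5.x, and the `wallx` and `pi` extras reuse it instead of carrying their own pins, which is presumably why the `[tool.uv]` conflict table could be dropped. A tiny hedged sanity check for an existing environment (assumes the `packaging` library is installed):

```python
from importlib.metadata import version

from packaging.version import Version

# transformers-dep now pins transformers>=5.1.0,<6.0.0.
installed = Version(version("transformers"))
assert Version("5.1.0") <= installed < Version("6.0.0"), f"transformers {installed} is outside the new pin"
print(f"transformers {installed} satisfies the new range")
```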
```diff
@@ -14,7 +14,7 @@ from transformers.image_processing_utils import (
 )
 from transformers.image_processing_utils_fast import (
     BaseImageProcessorFast,
-    DefaultFastImageProcessorKwargs,
+    ImagesKwargs,
     group_images_by_shape,
     reorder_images,
 )
@@ -77,7 +77,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
     return img[:, top:bottom, left:right]


-class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
     max_dynamic_tiles: int | None
     min_dynamic_tiles: int | None
     use_thumbnail: bool | None
```
```diff
@@ -15,6 +15,7 @@
 # limitations under the License.

 import builtins
+import copy
 import logging
 import math
 from collections import deque
@@ -32,13 +33,21 @@ from lerobot.utils.import_utils import _transformers_available
 if TYPE_CHECKING or _transformers_available:
     from transformers.models.auto import CONFIG_MAPPING
     from transformers.models.gemma import modeling_gemma
-    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
-    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+
+    from lerobot.policies.pi_gemma import (
+        PaliGemmaForConditionalGenerationWithPiGemma,
+        PiGemmaForCausalLM,
+        _gated_residual,
+        layernorm_forward,
+    )
 else:
     CONFIG_MAPPING = None
     modeling_gemma = None
-    GemmaForCausalLM = None
-    PaliGemmaForConditionalGeneration = None
+    PiGemmaForCausalLM = None
+    _gated_residual = None
+    layernorm_forward = None
+    PaliGemmaForConditionalGenerationWithPiGemma = None
+

 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
@@ -191,7 +200,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
     if images.dtype == torch.uint8:
         resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
     elif images.dtype == torch.float32:
-        resized_images = resized_images.clamp(-1.0, 1.0)
+        resized_images = resized_images.clamp(0.0, 1.0)
     else:
         raise ValueError(f"Unsupported image dtype: {images.dtype}")

@@ -202,7 +211,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
     pad_w1 = pad_w0 + remainder_w

     # Pad
-    constant_value = 0 if images.dtype == torch.uint8 else -1.0
+    constant_value = 0 if images.dtype == torch.uint8 else 0.0
     padded_images = F.pad(
         resized_images,
         (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
@@ -221,14 +230,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
 def compute_layer_complete(
     layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
 ):
-    models = [paligemma.language_model, gemma_expert.model]
+    models = [paligemma.model.language_model, gemma_expert.model]
     query_states = []
     key_states = []
     value_states = []
     gates = []
     for i, hidden_states in enumerate(inputs_embeds):
         layer = models[i].layers[layer_idx]
-        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901
+        hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
         gates.append(gate)
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -254,10 +263,10 @@ def compute_layer_complete(
         query_states, key_states, cos, sin, unsqueeze_dim=1
     )
     batch_size = query_states.shape[0]
-    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
+    scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
     # Attention computation
     att_output, _ = modeling_gemma.eager_attention_forward(
-        paligemma.language_model.layers[layer_idx].self_attn,
+        paligemma.model.language_model.layers[layer_idx].self_attn,
         query_states,
         key_states,
         value_states,
@@ -265,7 +274,7 @@ def compute_layer_complete(
         scaling,
     )
     # Get head_dim from the current layer, not from the model
-    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
+    head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
     att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
     # Process layer outputs
     outputs_embeds = []
@@ -277,15 +286,15 @@ def compute_layer_complete(
         att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
         out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
         # first residual
-        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
+        out_emb = _gated_residual(hidden_states, out_emb, gates[i])
         after_first_residual = out_emb.clone()
-        out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
+        out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
         # Convert to bfloat16 if the next layer (mlp) uses bfloat16
         if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
             out_emb = out_emb.to(dtype=torch.bfloat16)
         out_emb = layer.mlp(out_emb)
         # second residual
-        out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
+        out_emb = _gated_residual(after_first_residual, out_emb, gate)
         outputs_embeds.append(out_emb)
         start_pos = end_pos
     return outputs_embeds
@@ -358,7 +367,7 @@ class PaliGemmaWithExpertModel(
         vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
         vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
         vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
-        vlm_config_hf.text_config.torch_dtype = "float32"
+        vlm_config_hf.text_config.dtype = "float32"
         vlm_config_hf.text_config.vocab_size = 257152
         vlm_config_hf.text_config.use_adarms = use_adarms[0]
         vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -366,7 +375,7 @@ class PaliGemmaWithExpertModel(
         vlm_config_hf.vision_config.intermediate_size = 4304
         vlm_config_hf.vision_config.projection_dim = 2048
         vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
-        vlm_config_hf.vision_config.torch_dtype = "float32"
+        vlm_config_hf.vision_config.dtype = "float32"

         action_expert_config_hf = CONFIG_MAPPING["gemma"](
             head_dim=action_expert_config.head_dim,
@@ -377,13 +386,13 @@ class PaliGemmaWithExpertModel(
             num_key_value_heads=action_expert_config.num_kv_heads,
             vocab_size=257152,
             hidden_activation="gelu_pytorch_tanh",
-            torch_dtype="float32",
+            dtype="float32",
             use_adarms=use_adarms[1],
             adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
         )

-        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
-        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
+        self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
+        self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
         self.gemma_expert.model.embed_tokens = None

         self.to_bfloat16_for_selected_params(precision)
@@ -398,10 +407,11 @@ class PaliGemmaWithExpertModel(
         else:
             raise ValueError(f"Invalid precision: {precision}")

+        # Keep full vision path in float32 so we never toggle (toggle causes optimizer
+        # "same dtype" error). Align with PI05.
         params_to_keep_float32 = [
-            "vision_tower.vision_model.embeddings.patch_embedding.weight",
-            "vision_tower.vision_model.embeddings.patch_embedding.bias",
-            "vision_tower.vision_model.embeddings.position_embedding.weight",
+            "vision_tower",
+            "multi_modal_projector",
             "input_layernorm",
             "post_attention_layernorm",
             "model.norm",
@@ -413,8 +423,8 @@ class PaliGemmaWithExpertModel(

     def _set_requires_grad(self):
         if self.freeze_vision_encoder:
-            self.paligemma.vision_tower.eval()
-            for param in self.paligemma.vision_tower.parameters():
+            self.paligemma.model.vision_tower.eval()
+            for param in self.paligemma.model.vision_tower.parameters():
                 param.requires_grad = False
         if self.train_expert_only:
             self.paligemma.eval()
@@ -424,15 +434,23 @@ class PaliGemmaWithExpertModel(
     def train(self, mode: bool = True):
         super().train(mode)
         if self.freeze_vision_encoder:
-            self.paligemma.vision_tower.eval()
+            self.paligemma.model.vision_tower.eval()
         if self.train_expert_only:
             self.paligemma.eval()

     def embed_image(self, image: torch.Tensor):
-        return self.paligemma.model.get_image_features(image)
+        # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
+        out_dtype = image.dtype
+        if image.dtype != torch.float32:
+            image = image.to(torch.float32)
+        image_outputs = self.paligemma.model.get_image_features(image)
+        features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
+        if features.dtype != out_dtype:
+            features = features.to(out_dtype)
+        return features

     def embed_language_tokens(self, tokens: torch.Tensor):
-        return self.paligemma.language_model.embed_tokens(tokens)
+        return self.paligemma.model.language_model.embed_tokens(tokens)

     def forward(
         self,
@@ -446,7 +464,7 @@ class PaliGemmaWithExpertModel(
         if adarms_cond is None:
             adarms_cond = [None, None]
         if inputs_embeds[1] is None:
-            prefix_output = self.paligemma.language_model.forward(
+            prefix_output = self.paligemma.model.language_model.forward(
                 inputs_embeds=inputs_embeds[0],
                 attention_mask=attention_mask,
                 position_ids=position_ids,
@@ -470,7 +488,7 @@ class PaliGemmaWithExpertModel(
             prefix_output = None
             prefix_past_key_values = None
         else:
-            models = [self.paligemma.language_model, self.gemma_expert.model]
+            models = [self.paligemma.model.language_model, self.gemma_expert.model]
             num_layers = self.paligemma.config.text_config.num_hidden_layers

             # Check if gradient checkpointing is enabled for any of the models
@@ -510,7 +528,7 @@ class PaliGemmaWithExpertModel(
             def compute_final_norms(inputs_embeds, adarms_cond):
                 outputs_embeds = []
                 for i, hidden_states in enumerate(inputs_embeds):
-                    out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
+                    out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
                     outputs_embeds.append(out_emb)
                 return outputs_embeds

@@ -576,29 +594,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
             # Also compile the main forward pass used during training
             self.forward = torch.compile(self.forward, mode=config.compile_mode)

-        msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
-
-        try:
-            from transformers.models.siglip import check
-
-            if not check.check_whether_transformers_replace_is_installed_correctly():
-                raise ValueError(msg)
-        except ImportError:
-            raise ValueError(msg) from None
-
     def gradient_checkpointing_enable(self):
         """Enable gradient checkpointing for memory optimization."""
         self.gradient_checkpointing_enabled = True
-        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
-        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
+        self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
+        self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
         self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
         logging.info("Enabled gradient checkpointing for PI0Pytorch model")

     def gradient_checkpointing_disable(self):
         """Disable gradient checkpointing."""
         self.gradient_checkpointing_enabled = False
-        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
-        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
+        self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
+        self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
         self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
         logging.info("Disabled gradient checkpointing for PI0Pytorch model")

@@ -760,7 +768,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
         suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)

         if (
-            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
+            self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
             == torch.bfloat16
         ):
             suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -834,7 +842,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
         prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

         prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
-        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001
+        self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager"  # noqa: SLF001

         _, past_key_values = self.paligemma_with_expert.forward(
             attention_mask=prefix_att_2d_masks_4d,
@@ -908,6 +916,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
         full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
         self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

+        past_key_values = copy.deepcopy(past_key_values)
         outputs_embeds, _ = self.paligemma_with_expert.forward(
             attention_mask=full_att_2d_masks_4d,
             position_ids=position_ids,
@@ -997,14 +1006,12 @@ class PI0Policy(PreTrainedPolicy):
         # Check if dataset_stats were provided in kwargs
         model = cls(config, **kwargs)

-        # Now manually load and remap the state dict
+        # Load state dict (expects keys with "model." prefix)
         try:
-            # Try to load the pytorch_model.bin or model.safetensors file
            print(f"Loading model from: {pretrained_name_or_path}")
            try:
                from transformers.utils import cached_file

-                # Try safetensors first
                resolved_file = cached_file(
                    pretrained_name_or_path,
                    "model.safetensors",
@@ -1012,7 +1019,7 @@ class PI0Policy(PreTrainedPolicy):
                    force_download=kwargs.get("force_download", False),
                    resume_download=kwargs.get("resume_download"),
                    proxies=kwargs.get("proxies"),
-                    use_auth_token=kwargs.get("use_auth_token"),
+                    token=kwargs.get("token"),
                    revision=kwargs.get("revision"),
                    local_files_only=kwargs.get("local_files_only", False),
                )
@@ -1025,7 +1032,7 @@ class PI0Policy(PreTrainedPolicy):
                print("Returning model without loading pretrained weights")
                return model

-            # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
+            # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
            fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

            # Then add "model." prefix for all keys that don't already have it
@@ -1070,7 +1077,7 @@ class PI0Policy(PreTrainedPolicy):
            print("All keys loaded successfully!")

        except Exception as e:
-            print(f"Warning: Could not remap state dict keys: {e}")
+            print(f"Warning: Could not load state dict: {e}")

        return model

@@ -1120,6 +1127,14 @@ class PI0Policy(PreTrainedPolicy):
                # Some checkpoints might have this, but current model expects different structure
                logging.warning(f"Vision embedding key might need handling: {key}")

+            if (
+                key == "model.paligemma_with_expert.paligemma.lm_head.weight"
+                or key == "paligemma_with_expert.paligemma.lm_head.weight"
+            ):
+                fixed_state_dict[
+                    "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
+                ] = value.clone()
+
            fixed_state_dict[new_key] = value

        return fixed_state_dict
```
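The `embed_image` rewrite in the diff above boils down to a float32 round-trip around the vision path: inputs are upcast so the float32-kept vision tower and projector can run, and the features are handed back in the caller's dtype. A minimal standalone sketch of that pattern (hypothetical helper, not code from the diff):

```python
import torch


def run_in_float32(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run a float32-only module inside a lower-precision pipeline and return the caller's dtype."""
    out_dtype = x.dtype
    if x.dtype != torch.float32:
        x = x.to(torch.float32)  # upcast for the float32-kept parameters
    y = module(x)
    if y.dtype != out_dtype:
        y = y.to(out_dtype)  # hand the result back in the surrounding dtype
    return y


# Example: a float32 linear layer used from a bfloat16 activation stream.
layer = torch.nn.Linear(16, 8)  # parameters stay float32
activations = torch.randn(2, 16, dtype=torch.bfloat16)
print(run_in_float32(layer, activations).dtype)  # torch.bfloat16
```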
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@@ -32,14 +33,20 @@ from lerobot.utils.import_utils import _transformers_available
|
|||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
from transformers.models.gemma import modeling_gemma
|
from transformers.models.gemma import modeling_gemma
|
||||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
|
||||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
from lerobot.policies.pi_gemma import (
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||||
|
PiGemmaForCausalLM,
|
||||||
|
_gated_residual,
|
||||||
|
layernorm_forward,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
CONFIG_MAPPING = None
|
CONFIG_MAPPING = None
|
||||||
modeling_gemma = None
|
modeling_gemma = None
|
||||||
GemmaForCausalLM = None
|
PiGemmaForCausalLM = None
|
||||||
PaliGemmaForConditionalGeneration = None
|
_gated_residual = None
|
||||||
|
layernorm_forward = None
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
@@ -189,7 +196,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
if images.dtype == torch.uint8:
|
if images.dtype == torch.uint8:
|
||||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||||
elif images.dtype == torch.float32:
|
elif images.dtype == torch.float32:
|
||||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
resized_images = resized_images.clamp(0.0, 1.0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||||
|
|
||||||
@@ -200,7 +207,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
pad_w1 = pad_w0 + remainder_w
|
pad_w1 = pad_w0 + remainder_w
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||||
padded_images = F.pad(
|
padded_images = F.pad(
|
||||||
resized_images,
|
resized_images,
|
||||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||||
@@ -219,14 +226,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
def compute_layer_complete(
|
def compute_layer_complete(
|
||||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||||
):
|
):
|
||||||
models = [paligemma.language_model, gemma_expert.model]
|
models = [paligemma.model.language_model, gemma_expert.model]
|
||||||
query_states = []
|
query_states = []
|
||||||
key_states = []
|
key_states = []
|
||||||
value_states = []
|
value_states = []
|
||||||
gates = []
|
gates = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = models[i].layers[layer_idx]
|
layer = models[i].layers[layer_idx]
|
||||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||||
gates.append(gate)
|
gates.append(gate)
|
||||||
input_shape = hidden_states.shape[:-1]
|
input_shape = hidden_states.shape[:-1]
|
||||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||||
@@ -252,10 +259,10 @@ def compute_layer_complete(
|
|||||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||||
)
|
)
|
||||||
batch_size = query_states.shape[0]
|
batch_size = query_states.shape[0]
|
||||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||||
# Attention computation
|
# Attention computation
|
||||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||||
paligemma.language_model.layers[layer_idx].self_attn,
|
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
@@ -263,7 +270,7 @@ def compute_layer_complete(
|
|||||||
scaling,
|
scaling,
|
||||||
)
|
)
|
||||||
# Get head_dim from the current layer, not from the model
|
# Get head_dim from the current layer, not from the model
|
||||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||||
# Process layer outputs
|
# Process layer outputs
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
@@ -275,15 +282,15 @@ def compute_layer_complete(
|
|||||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||||
# first residual
|
# first residual
|
||||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||||
after_first_residual = out_emb.clone()
|
after_first_residual = out_emb.clone()
|
||||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||||
out_emb = layer.mlp(out_emb)
|
out_emb = layer.mlp(out_emb)
|
||||||
# second residual
|
# second residual
|
||||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
start_pos = end_pos
|
start_pos = end_pos
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
@@ -356,7 +363,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
vlm_config_hf.text_config.dtype = "float32"
|
||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
@@ -364,7 +371,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
vlm_config_hf.vision_config.dtype = "float32"
|
||||||
|
|
||||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||||
head_dim=action_expert_config.head_dim,
|
head_dim=action_expert_config.head_dim,
|
||||||
@@ -375,13 +382,13 @@ class PaliGemmaWithExpertModel(
|
|||||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||||
vocab_size=257152,
|
vocab_size=257152,
|
||||||
hidden_activation="gelu_pytorch_tanh",
|
hidden_activation="gelu_pytorch_tanh",
|
||||||
torch_dtype="float32",
|
dtype="float32",
|
||||||
use_adarms=use_adarms[1],
|
use_adarms=use_adarms[1],
|
||||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||||
self.gemma_expert.model.embed_tokens = None
|
self.gemma_expert.model.embed_tokens = None
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
@@ -396,10 +403,11 @@ class PaliGemmaWithExpertModel(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid precision: {precision}")
|
raise ValueError(f"Invalid precision: {precision}")
|
||||||
|
|
||||||
|
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||||
|
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
|
||||||
params_to_keep_float32 = [
|
params_to_keep_float32 = [
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
"vision_tower",
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
"multi_modal_projector",
|
||||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
|
||||||
"input_layernorm",
|
"input_layernorm",
|
||||||
"post_attention_layernorm",
|
"post_attention_layernorm",
|
||||||
"model.norm",
|
"model.norm",
|
||||||
@@ -411,8 +419,8 @@ class PaliGemmaWithExpertModel(
|
|||||||
|
|
||||||
def _set_requires_grad(self):
|
def _set_requires_grad(self):
|
||||||
if self.freeze_vision_encoder:
|
if self.freeze_vision_encoder:
|
||||||
self.paligemma.vision_tower.eval()
|
self.paligemma.model.vision_tower.eval()
|
||||||
for param in self.paligemma.vision_tower.parameters():
|
for param in self.paligemma.model.vision_tower.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
if self.train_expert_only:
|
if self.train_expert_only:
|
||||||
self.paligemma.eval()
|
self.paligemma.eval()
|
||||||
@@ -422,15 +430,23 @@ class PaliGemmaWithExpertModel(
|
|||||||
def train(self, mode: bool = True):
|
def train(self, mode: bool = True):
|
||||||
super().train(mode)
|
super().train(mode)
|
||||||
if self.freeze_vision_encoder:
|
if self.freeze_vision_encoder:
|
||||||
self.paligemma.vision_tower.eval()
|
self.paligemma.model.vision_tower.eval()
|
||||||
if self.train_expert_only:
|
if self.train_expert_only:
|
||||||
self.paligemma.eval()
|
self.paligemma.eval()
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
def embed_image(self, image: torch.Tensor):
|
||||||
return self.paligemma.model.get_image_features(image)
|
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
|
||||||
|
out_dtype = image.dtype
|
||||||
|
if image.dtype != torch.float32:
|
||||||
|
image = image.to(torch.float32)
|
||||||
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
|
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||||
|
if features.dtype != out_dtype:
|
||||||
|
features = features.to(out_dtype)
|
||||||
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -444,7 +460,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
if adarms_cond is None:
|
if adarms_cond is None:
|
||||||
adarms_cond = [None, None]
|
adarms_cond = [None, None]
|
||||||
if inputs_embeds[1] is None:
|
if inputs_embeds[1] is None:
|
||||||
prefix_output = self.paligemma.language_model.forward(
|
prefix_output = self.paligemma.model.language_model.forward(
|
||||||
inputs_embeds=inputs_embeds[0],
|
inputs_embeds=inputs_embeds[0],
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -468,7 +484,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
prefix_output = None
|
prefix_output = None
|
||||||
prefix_past_key_values = None
|
prefix_past_key_values = None
|
||||||
else:
|
else:
|
||||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||||
|
|
||||||
# Check if gradient checkpointing is enabled for any of the models
|
# Check if gradient checkpointing is enabled for any of the models
|
||||||
@@ -508,7 +524,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
|
|
||||||
@@ -573,29 +589,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Also compile the main forward pass used during training
|
# Also compile the main forward pass used during training
|
||||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||||
|
|
||||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.models.siglip import check
|
|
||||||
|
|
||||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
|
||||||
raise ValueError(msg)
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(msg) from None
|
|
||||||
|
|
||||||
def gradient_checkpointing_enable(self):
|
def gradient_checkpointing_enable(self):
|
||||||
"""Enable gradient checkpointing for memory optimization."""
|
"""Enable gradient checkpointing for memory optimization."""
|
||||||
self.gradient_checkpointing_enabled = True
|
self.gradient_checkpointing_enabled = True
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||||
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
||||||
|
|
||||||
def gradient_checkpointing_disable(self):
|
def gradient_checkpointing_disable(self):
|
||||||
"""Disable gradient checkpointing."""
|
"""Disable gradient checkpointing."""
|
||||||
self.gradient_checkpointing_enabled = False
|
self.gradient_checkpointing_enabled = False
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||||
|
|
||||||
@@ -737,7 +743,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -808,7 +814,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
_, past_key_values = self.paligemma_with_expert.forward(
|
_, past_key_values = self.paligemma_with_expert.forward(
|
||||||
attention_mask=prefix_att_2d_masks_4d,
|
attention_mask=prefix_att_2d_masks_4d,
|
||||||
@@ -880,6 +886,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
|
past_key_values = copy.deepcopy(past_key_values)
|
||||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=full_att_2d_masks_4d,
|
attention_mask=full_att_2d_masks_4d,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
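# A hedged aside on the `copy.deepcopy(past_key_values)` line added in the hunk
# above: the prefix KV cache is presumably reused across several denoising
# calls, and a forward pass that appends to the cache would otherwise mutate it
# in place. Minimal illustration of that failure mode (plain dict stand-in):
import copy

prefix_cache = {"layer_0": ["k_prefix", "v_prefix"]}
for step in range(3):
    step_cache = copy.deepcopy(prefix_cache)        # isolate this step's writes
    step_cache["layer_0"].append(f"k_step_{step}")  # what a forward() would do
assert prefix_cache["layer_0"] == ["k_prefix", "v_prefix"]  # prefix untouched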
@@ -969,14 +976,12 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
# Check if dataset_stats were provided in kwargs
|
# Check if dataset_stats were provided in kwargs
|
||||||
model = cls(config, **kwargs)
|
model = cls(config, **kwargs)
|
||||||
|
|
||||||
# Now manually load and remap the state dict
|
# Load state dict (expects keys with "model." prefix)
|
||||||
try:
|
try:
|
||||||
# Try to load the pytorch_model.bin or model.safetensors file
|
|
||||||
print(f"Loading model from: {pretrained_name_or_path}")
|
print(f"Loading model from: {pretrained_name_or_path}")
|
||||||
try:
|
try:
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
# Try safetensors first
|
|
||||||
resolved_file = cached_file(
|
resolved_file = cached_file(
|
||||||
pretrained_name_or_path,
|
pretrained_name_or_path,
|
||||||
"model.safetensors",
|
"model.safetensors",
|
||||||
@@ -984,7 +989,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
force_download=kwargs.get("force_download", False),
|
force_download=kwargs.get("force_download", False),
|
||||||
resume_download=kwargs.get("resume_download"),
|
resume_download=kwargs.get("resume_download"),
|
||||||
proxies=kwargs.get("proxies"),
|
proxies=kwargs.get("proxies"),
|
||||||
use_auth_token=kwargs.get("use_auth_token"),
|
token=kwargs.get("token"),
|
||||||
revision=kwargs.get("revision"),
|
revision=kwargs.get("revision"),
|
||||||
local_files_only=kwargs.get("local_files_only", False),
|
local_files_only=kwargs.get("local_files_only", False),
|
||||||
)
|
)
|
||||||
@@ -997,7 +1002,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
print("Returning model without loading pretrained weights")
|
print("Returning model without loading pretrained weights")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||||
|
|
||||||
# Then add "model." prefix for all keys that don't already have it
|
# Then add "model." prefix for all keys that don't already have it
|
||||||
@@ -1009,8 +1014,6 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
new_key = f"model.{key}"
|
new_key = f"model.{key}"
|
||||||
remapped_state_dict[new_key] = value
|
remapped_state_dict[new_key] = value
|
||||||
remap_count += 1
|
remap_count += 1
|
||||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
|
||||||
print(f"Remapped: {key} -> {new_key}")
|
|
||||||
else:
|
else:
|
||||||
remapped_state_dict[key] = value
|
remapped_state_dict[key] = value
|
||||||
|
|
||||||
@@ -1044,7 +1047,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
print("All keys loaded successfully!")
|
print("All keys loaded successfully!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Could not remap state dict keys: {e}")
|
print(f"Warning: Could not load state dict: {e}")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -1098,6 +1101,14 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
# Some checkpoints might have this, but current model expects different structure
|
# Some checkpoints might have this, but current model expects different structure
|
||||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||||
|
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||||
|
):
|
||||||
|
fixed_state_dict[
|
||||||
|
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||||
|
] = value.clone()
|
||||||
|
|
||||||
fixed_state_dict[new_key] = value
|
fixed_state_dict[new_key] = value
|
||||||
|
|
||||||
return fixed_state_dict
|
return fixed_state_dict
|
||||||
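# Hedged sketch of the loading convention shown above (illustrative names): keys
# from an openpi-style checkpoint get a "model." prefix, and the PaliGemma
# lm_head weight is also copied into the embed_tokens slot that the remapped
# module structure expects. The tied-weights rationale is an assumption here.
def remap_checkpoint_keys(state_dict: dict) -> dict:
    embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
    fixed = {}
    for key, value in state_dict.items():
        new_key = key if key.startswith("model.") else f"model.{key}"
        if new_key == "model.paligemma_with_expert.paligemma.lm_head.weight":
            fixed[embed_key] = value  # reuse the lm_head weight for embeddings
        fixed[new_key] = value
    return fixed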
|
|||||||
@@ -23,7 +23,6 @@ import torch
|
|||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
@@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
|||||||
# TODO: check if this is necessary
|
# TODO: check if this is necessary
|
||||||
state = deepcopy(state)
|
state = deepcopy(state)
|
||||||
|
|
||||||
# Prepare state (pad to max_state_dim)
|
|
||||||
state = pad_vector(state, self.max_state_dim)
|
|
||||||
|
|
||||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||||
state_np = state.cpu().numpy()
|
state_np = state.cpu().numpy()
|
||||||
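# A minimal sketch of the discretization mentioned above, assuming a uniform
# bin layout: the state is already normalized to [-1, 1] upstream, and each
# dimension is mapped onto one of 256 integer bins before tokenization.
import numpy as np

def discretize_state(state_np: np.ndarray, n_bins: int = 256) -> np.ndarray:
    state_np = np.clip(state_np, -1.0, 1.0)
    return np.round((state_np + 1.0) / 2.0 * (n_bins - 1)).astype(np.int32)

print(discretize_state(np.array([-1.0, 0.0, 1.0])))  # -> [0, 128, 255]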
|
|||||||
@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||||
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
|
||||||
temperature: float = 0.0
|
temperature: float = 0.0
|
||||||
max_decoding_steps: int = 256
|
max_decoding_steps: int = 256
|
||||||
fast_skip_tokens: int = 128
|
fast_skip_tokens: int = 128
|
||||||
|
|||||||
@@ -38,11 +38,16 @@ else:
|
|||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
|
||||||
|
from lerobot.policies.pi_gemma import (
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||||
|
PiGemmaModel,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
CONFIG_MAPPING = None
|
CONFIG_MAPPING = None
|
||||||
PaliGemmaForConditionalGeneration = None
|
|
||||||
AutoTokenizer = None
|
AutoTokenizer = None
|
||||||
|
PiGemmaModel = None
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||||
@@ -121,7 +126,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
if images.dtype == torch.uint8:
|
if images.dtype == torch.uint8:
|
||||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||||
elif images.dtype == torch.float32:
|
elif images.dtype == torch.float32:
|
||||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
resized_images = resized_images.clamp(0.0, 1.0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||||
|
|
||||||
@@ -132,7 +137,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
pad_w1 = pad_w0 + remainder_w
|
pad_w1 = pad_w0 + remainder_w
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||||
padded_images = F.pad(
|
padded_images = F.pad(
|
||||||
resized_images,
|
resized_images,
|
||||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||||
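# Hedged sketch of the pad step shown above: with float images now clamped to
# [0, 1], padding with 0.0 (black) keeps padded pixels inside the valid range,
# matching the 0 used for uint8 images. Names below are illustrative.
import torch
import torch.nn.functional as F

def pad_to(images: torch.Tensor, height: int, width: int) -> torch.Tensor:
    *_, h, w = images.shape
    pad_h0, pad_w0 = (height - h) // 2, (width - w) // 2
    pad_h1, pad_w1 = height - h - pad_h0, width - w - pad_w0
    value = 0 if images.dtype == torch.uint8 else 0.0
    return F.pad(images, (pad_w0, pad_w1, pad_h0, pad_h1), value=value)

print(pad_to(torch.rand(1, 3, 200, 224), 224, 224).shape)  # (1, 3, 224, 224)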
@@ -206,16 +211,22 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
vlm_config_hf.text_config.dtype = "float32"
|
||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
vlm_config_hf.vision_config.dtype = "float32"
|
||||||
|
|
||||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||||
|
|
||||||
|
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
|
||||||
|
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||||
|
if use_adarms[0]:
|
||||||
|
text_config = self.paligemma.config.text_config
|
||||||
|
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
|
|
||||||
@@ -228,10 +239,11 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid precision: {precision}")
|
raise ValueError(f"Invalid precision: {precision}")
|
||||||
|
|
||||||
|
# Keep the full vision path in float32 so we never toggle dtypes (toggling causes an optimizer
|
||||||
|
# "same dtype" error). Align with PI05.
|
||||||
params_to_keep_float32 = [
|
params_to_keep_float32 = [
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
"vision_tower",
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
"multi_modal_projector",
|
||||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
|
||||||
"input_layernorm",
|
"input_layernorm",
|
||||||
"post_attention_layernorm",
|
"post_attention_layernorm",
|
||||||
"model.norm",
|
"model.norm",
|
||||||
@@ -242,10 +254,18 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
param.data = param.data.to(dtype=torch.float32)
|
param.data = param.data.to(dtype=torch.float32)
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
def embed_image(self, image: torch.Tensor):
|
||||||
return self.paligemma.model.get_image_features(image)
|
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||||
|
out_dtype = image.dtype
|
||||||
|
if image.dtype != torch.float32:
|
||||||
|
image = image.to(torch.float32)
|
||||||
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
|
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||||
|
if features.dtype != out_dtype:
|
||||||
|
features = features.to(out_dtype)
|
||||||
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
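# Hedged sketch of the precision scheme above: cast the whole model to
# bfloat16, then restore float32 for parameters whose names match the
# keep-list (the vision tower and projector), while embed_image round-trips
# image features through float32 at call time. Substring matching is an
# assumption here; the real keep-list is params_to_keep_float32.
import torch
from torch import nn

def apply_selected_precision(model: nn.Module,
                             keep_float32=("vision_tower", "multi_modal_projector")) -> nn.Module:
    model = model.to(dtype=torch.bfloat16)
    for name, param in model.named_parameters():
        if any(key in name for key in keep_float32):
            param.data = param.data.to(dtype=torch.float32)
    return model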
@@ -259,7 +279,7 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
if adarms_cond is None:
|
if adarms_cond is None:
|
||||||
adarms_cond = [None, None]
|
adarms_cond = [None, None]
|
||||||
if inputs_embeds[1] is None:
|
if inputs_embeds[1] is None:
|
||||||
prefix_output = self.paligemma.language_model.forward(
|
prefix_output = self.paligemma.model.language_model.forward(
|
||||||
inputs_embeds=inputs_embeds[0],
|
inputs_embeds=inputs_embeds[0],
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -306,24 +326,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
|
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
|
||||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||||
|
|
||||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.models.siglip import check
|
|
||||||
|
|
||||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
|
||||||
raise ValueError(msg)
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(msg) from None
|
|
||||||
|
|
||||||
def gradient_checkpointing_enable(self):
|
def gradient_checkpointing_enable(self):
|
||||||
"""Enable gradient checkpointing for memory optimization."""
|
"""Enable gradient checkpointing for memory optimization."""
|
||||||
self.gradient_checkpointing_enabled = True
|
self.gradient_checkpointing_enabled = True
|
||||||
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
|
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
|
||||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||||
)
|
)
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
|
||||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||||
)
|
)
|
||||||
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
|
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
|
||||||
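# Hedged sketch of how a flag like gradient_checkpointing_enabled is typically
# consumed (the repo's _apply_checkpoint helper is referenced below but its body
# is not shown in this diff): wrap the expensive sub-forward in
# torch.utils.checkpoint with use_reentrant=False so activations are recomputed
# during the backward pass instead of being stored.
import torch
from torch.utils.checkpoint import checkpoint

def apply_checkpoint(enabled: bool, func, *args):
    if enabled and torch.is_grad_enabled():
        return checkpoint(func, *args, use_reentrant=False)
    return func(*args)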
@@ -332,8 +342,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
"""Disable gradient checkpointing."""
|
"""Disable gradient checkpointing."""
|
||||||
self.gradient_checkpointing_enabled = False
|
self.gradient_checkpointing_enabled = False
|
||||||
# Call the proper gradient_checkpointing_disable() method
|
# Call the proper gradient_checkpointing_disable() method
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
|
||||||
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
|
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
|
||||||
|
|
||||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||||
@@ -523,7 +533,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
# Convert embeddings to bfloat16 if needed
|
# Convert embeddings to bfloat16 if needed
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -616,7 +626,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -714,7 +724,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
# Ensure correct precision (bfloat16/float32)
|
# Ensure correct precision (bfloat16/float32)
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -897,14 +907,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
# Check if dataset_stats were provided in kwargs
|
# Check if dataset_stats were provided in kwargs
|
||||||
model = cls(config, **kwargs)
|
model = cls(config, **kwargs)
|
||||||
|
|
||||||
# Now manually load and remap the state dict
|
# Load state dict (expects keys with "model." prefix)
|
||||||
try:
|
try:
|
||||||
# Try to load the pytorch_model.bin or model.safetensors file
|
|
||||||
print(f"Loading model from: {pretrained_name_or_path}")
|
print(f"Loading model from: {pretrained_name_or_path}")
|
||||||
try:
|
try:
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
# Try safetensors first
|
|
||||||
resolved_file = cached_file(
|
resolved_file = cached_file(
|
||||||
pretrained_name_or_path,
|
pretrained_name_or_path,
|
||||||
"model.safetensors",
|
"model.safetensors",
|
||||||
@@ -912,7 +920,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
force_download=kwargs.get("force_download", False),
|
force_download=kwargs.get("force_download", False),
|
||||||
resume_download=kwargs.get("resume_download"),
|
resume_download=kwargs.get("resume_download"),
|
||||||
proxies=kwargs.get("proxies"),
|
proxies=kwargs.get("proxies"),
|
||||||
use_auth_token=kwargs.get("use_auth_token"),
|
token=kwargs.get("token"),
|
||||||
revision=kwargs.get("revision"),
|
revision=kwargs.get("revision"),
|
||||||
local_files_only=kwargs.get("local_files_only", False),
|
local_files_only=kwargs.get("local_files_only", False),
|
||||||
)
|
)
|
||||||
@@ -925,8 +933,9 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
print("Returning model without loading pretrained weights")
|
print("Returning model without loading pretrained weights")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||||
|
|
||||||
# Then add "model." prefix for all keys that don't already have it
|
# Then add "model." prefix for all keys that don't already have it
|
||||||
remapped_state_dict = {}
|
remapped_state_dict = {}
|
||||||
remap_count = 0
|
remap_count = 0
|
||||||
@@ -936,8 +945,6 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
new_key = f"model.{key}"
|
new_key = f"model.{key}"
|
||||||
remapped_state_dict[new_key] = value
|
remapped_state_dict[new_key] = value
|
||||||
remap_count += 1
|
remap_count += 1
|
||||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
|
||||||
print(f"Remapped: {key} -> {new_key}")
|
|
||||||
else:
|
else:
|
||||||
remapped_state_dict[key] = value
|
remapped_state_dict[key] = value
|
||||||
|
|
||||||
@@ -971,7 +978,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
print("All keys loaded successfully!")
|
print("All keys loaded successfully!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Could not remap state dict keys: {e}")
|
print(f"Warning: Could not load state dict: {e}")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import torch
|
|||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||||
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
ActionTokenizerProcessorStep,
|
ActionTokenizerProcessorStep,
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
|||||||
# TODO: check if this is necessary
|
# TODO: check if this is necessary
|
||||||
state = deepcopy(state)
|
state = deepcopy(state)
|
||||||
|
|
||||||
# Prepare state (pad to max_state_dim)
|
|
||||||
state = pad_vector(state, self.max_state_dim)
|
|
||||||
|
|
||||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||||
state_np = state.cpu().numpy()
|
state_np = state.cpu().numpy()
|
||||||
|
|||||||
@@ -0,0 +1,363 @@
|
|||||||
|
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.cache_utils import DynamicCache
|
||||||
|
from transformers.masking_utils import create_causal_mask
|
||||||
|
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.gemma.modeling_gemma import (
|
||||||
|
GemmaAttention,
|
||||||
|
GemmaConfig,
|
||||||
|
GemmaForCausalLM,
|
||||||
|
GemmaMLP,
|
||||||
|
GemmaModel,
|
||||||
|
)
|
||||||
|
from transformers.models.paligemma.modeling_paligemma import (
|
||||||
|
PaliGemmaForConditionalGeneration,
|
||||||
|
PaliGemmaModel,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
GemmaAttention = None
|
||||||
|
GemmaConfig = None
|
||||||
|
GemmaForCausalLM = None
|
||||||
|
GemmaMLP = None
|
||||||
|
GemmaModel = None
|
||||||
|
PaliGemmaModel = None
|
||||||
|
PaliGemmaForConditionalGeneration = None
|
||||||
|
DynamicCache = None
|
||||||
|
GradientCheckpointingLayer = None
|
||||||
|
BaseModelOutputWithPast = None
|
||||||
|
create_causal_mask = None
|
||||||
|
|
||||||
|
|
||||||
|
def _gated_residual(
|
||||||
|
x: torch.Tensor | None,
|
||||||
|
y: torch.Tensor | None,
|
||||||
|
gate: torch.Tensor | None,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
"""Gated residual: x + y when gate is None, else x + y * gate."""
|
||||||
|
if x is None and y is None:
|
||||||
|
return None
|
||||||
|
if x is None or y is None:
|
||||||
|
return x if x is not None else y
|
||||||
|
if gate is None:
|
||||||
|
return x + y
|
||||||
|
return x + y * gate
|
||||||
|
|
||||||
|
|
||||||
|
def layernorm_forward(
|
||||||
|
layernorm: nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cond: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Call layernorm and return the hidden states plus the AdaRMS gate.
|
||||||
|
If cond is not None, use the conditional (adaptive) norm;
|
||||||
|
otherwise, use the standard Gemma norm.
|
||||||
|
"""
|
||||||
|
if cond is not None:
|
||||||
|
return layernorm(x, cond=cond)
|
||||||
|
else:
|
||||||
|
return layernorm(x)
|
||||||
|
|
||||||
|
|
||||||
|
class PiGemmaRMSNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
Adaptive RMSNorm for PI Gemma (AdaRMS).
|
||||||
|
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
|
||||||
|
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.dim = dim
|
||||||
|
self.cond_dim = cond_dim
|
||||||
|
if cond_dim is not None:
|
||||||
|
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
|
||||||
|
nn.init.zeros_(self.dense.weight)
|
||||||
|
else:
|
||||||
|
self.weight = nn.Parameter(torch.zeros(dim))
|
||||||
|
self.dense = None
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
# Compute variance in float32 (like the source implementation)
|
||||||
|
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
|
||||||
|
# Compute normalization in float32
|
||||||
|
normed_inputs = x * torch.rsqrt(var + self.eps)
|
||||||
|
return normed_inputs
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cond: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
dtype = x.dtype
|
||||||
|
normed = self._norm(x)
|
||||||
|
if cond is None or self.dense is None:
|
||||||
|
normed = normed * (1.0 + self.weight.float())
|
||||||
|
return normed.type_as(x), None
|
||||||
|
if cond.shape[-1] != self.cond_dim:
|
||||||
|
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
|
||||||
|
modulation = self.dense(cond)
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
modulation = modulation.unsqueeze(1)
|
||||||
|
scale, shift, gate = modulation.chunk(3, dim=-1)
|
||||||
|
normed = normed * (1 + scale.float()) + shift.float()
|
||||||
|
return normed.to(dtype), gate.to(dtype)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
if self.dense is not None:
|
||||||
|
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
|
||||||
|
return f"dim={self.dim}, eps={self.eps}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pi_gemma_decoder_layer_base():
|
||||||
|
"""base for PiGemmaDecoderLayer"""
|
||||||
|
|
||||||
|
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
|
||||||
|
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
|
||||||
|
|
||||||
|
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
||||||
|
self.mlp = GemmaMLP(config)
|
||||||
|
cond_dim = (
|
||||||
|
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||||
|
)
|
||||||
|
self.input_layernorm = PiGemmaRMSNorm(
|
||||||
|
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = PiGemmaRMSNorm(
|
||||||
|
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
position_ids: torch.LongTensor | None = None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: torch.LongTensor | None = None,
|
||||||
|
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||||
|
adarms_cond: torch.Tensor | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
|
||||||
|
hidden_states, _ = self.self_attn(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
return _PiGemmaDecoderLayerBase
|
||||||
|
|
||||||
|
|
||||||
|
class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||||
|
"""
|
||||||
|
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
|
super().__init__(config, **kwargs)
|
||||||
|
# if not getattr(config, "use_adarms", False):
|
||||||
|
# return
|
||||||
|
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||||
|
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
position_ids: torch.LongTensor | None = None,
|
||||||
|
past_key_values: DynamicCache | None = None,
|
||||||
|
inputs_embeds: torch.FloatTensor | None = None,
|
||||||
|
use_cache: bool | None = None,
|
||||||
|
output_attentions: bool | None = None,
|
||||||
|
output_hidden_states: bool | None = None,
|
||||||
|
cache_position: torch.LongTensor | None = None,
|
||||||
|
adarms_cond: torch.Tensor | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BaseModelOutputWithPast:
|
||||||
|
"""
|
||||||
|
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||||
|
Condition for ADARMS.
|
||||||
|
"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.warning(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if use_cache and past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = create_causal_mask(
|
||||||
|
config=self.config,
|
||||||
|
input_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cache_position=cache_position,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
# Convert to bfloat16 if the first layer uses bfloat16
|
||||||
|
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||||
|
hidden_states = hidden_states.to(torch.bfloat16)
|
||||||
|
|
||||||
|
# create position embeddings to be shared across the decoder layers
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
|
# normalized
|
||||||
|
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||||
|
# See https://github.com/huggingface/transformers/pull/29402
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
adarms_cond=adarms_cond,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, adarms_cond)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=past_key_values if use_cache else None,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
||||||
|
"""
|
||||||
|
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
|
||||||
|
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
|
super().__init__(config, **kwargs)
|
||||||
|
self.model = PiGemmaModel(config)
|
||||||
|
|
||||||
|
|
||||||
|
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
||||||
|
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.language_model = PiGemmaModel(config.text_config)
|
||||||
|
|
||||||
|
|
||||||
|
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
|
||||||
|
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = PaliGemmaModelWithPiGemma(config)
|
||||||
|
|
||||||
|
# Expose submodules through the conditional-generation class for backward compatibility (BC)
|
||||||
|
@property
|
||||||
|
def language_model(self):
|
||||||
|
return self.model.language_model
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PiGemmaModel",
|
||||||
|
"PiGemmaForCausalLM",
|
||||||
|
"PiGemmaRMSNorm",
|
||||||
|
"_gated_residual",
|
||||||
|
"layernorm_forward",
|
||||||
|
"PaliGemmaModelWithPiGemma",
|
||||||
|
"PaliGemmaForConditionalGenerationWithPiGemma",
|
||||||
|
]
|
||||||
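# Hypothetical usage sketch of the adaptive norm from the new module above
# (import path taken from the lerobot.policies.pi_gemma imports shown earlier):
# with a conditioning vector it returns (modulated activations, gate), and the
# gate then scales the residual via _gated_residual; without cond_dim it
# behaves like a standard Gemma RMSNorm and returns gate=None.
import torch
from lerobot.policies.pi_gemma import PiGemmaRMSNorm

ada_norm = PiGemmaRMSNorm(dim=8, cond_dim=4)
plain_norm = PiGemmaRMSNorm(dim=8)
x = torch.randn(2, 16, 8)                  # (batch, seq, hidden)
cond = torch.randn(2, 4)                   # (batch, cond_dim)
out, gate = ada_norm(x, cond=cond)         # gate broadcasts over the sequence
out_plain, no_gate = plain_norm(x)         # no_gate is None
print(out.shape, gate.shape, no_gate)      # (2, 16, 8), (2, 1, 8), None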
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
|
|||||||
latent_dim: int = 256
|
latent_dim: int = 256
|
||||||
image_embedding_pooling_dim: int = 8
|
image_embedding_pooling_dim: int = 8
|
||||||
dropout_rate: float = 0.1
|
dropout_rate: float = 0.1
|
||||||
model_name: str = "helper2424/resnet10"
|
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
model_type: str = "cnn" # "transformer" or "cnn"
|
model_type: str = "cnn" # "transformer" or "cnn"
|
||||||
num_cameras: int = 2
|
num_cameras: int = 2
|
||||||
|
|||||||
@@ -261,10 +261,15 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
and optional LoRA fine-tuning support.
|
and optional LoRA fine-tuning support.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||||
config_class = Qwen2_5_VLConfig
|
config_class = Qwen2_5_VLConfig
|
||||||
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]
|
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
if getattr(self.model, "language_model", None) is not None:
|
||||||
|
return
|
||||||
|
super().init_weights()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
@@ -312,6 +317,11 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
processor.action_processor = action_tokenizer
|
processor.action_processor = action_tokenizer
|
||||||
else:
|
else:
|
||||||
action_tokenizer = None
|
action_tokenizer = None
|
||||||
|
|
||||||
|
# add pad_token_id to config
|
||||||
|
config.pad_token_id = processor.tokenizer.pad_token_id
|
||||||
|
config.text_config.pad_token_id = processor.tokenizer.pad_token_id
|
||||||
|
|
||||||
# Initialize model with configuration and processor
|
# Initialize model with configuration and processor
|
||||||
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
|
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
|
||||||
|
|
||||||
@@ -331,7 +341,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
force_download=kwargs.get("force_download", False),
|
force_download=kwargs.get("force_download", False),
|
||||||
resume_download=kwargs.get("resume_download"),
|
resume_download=kwargs.get("resume_download"),
|
||||||
proxies=kwargs.get("proxies"),
|
proxies=kwargs.get("proxies"),
|
||||||
use_auth_token=kwargs.get("use_auth_token"),
|
token=kwargs.get("token"),
|
||||||
revision=kwargs.get("revision"),
|
revision=kwargs.get("revision"),
|
||||||
local_files_only=kwargs.get("local_files_only", False),
|
local_files_only=kwargs.get("local_files_only", False),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
|||||||
window_size=112,
|
window_size=112,
|
||||||
out_hidden_size=3584,
|
out_hidden_size=3584,
|
||||||
fullatt_block_indexes=[7, 15, 23, 31],
|
fullatt_block_indexes=[7, 15, 23, 31],
|
||||||
|
initializer_range=0.02,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -38,6 +39,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
|||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.fullatt_block_indexes = fullatt_block_indexes
|
self.fullatt_block_indexes = fullatt_block_indexes
|
||||||
self.out_hidden_size = out_hidden_size
|
self.out_hidden_size = out_hidden_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VLConfig(PretrainedConfig):
|
class Qwen2_5_VLConfig(PretrainedConfig):
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from transformers.activations import ACT2FN
|
|||||||
from transformers.cache_utils import (
|
from transformers.cache_utils import (
|
||||||
Cache,
|
Cache,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
SlidingWindowCache,
|
|
||||||
StaticCache,
|
StaticCache,
|
||||||
)
|
)
|
||||||
from transformers.generation import GenerationMixin
|
from transformers.generation import GenerationMixin
|
||||||
@@ -31,6 +30,15 @@ from transformers.utils import (
|
|||||||
|
|
||||||
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
|
||||||
|
# always return False (which is the correct behavior when no sliding window cache is in use).
|
||||||
|
class _SlidingWindowCachePlaceholder:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
SlidingWindowCache = _SlidingWindowCachePlaceholder
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
from flash_attn.layers.rotary import apply_rotary_emb
|
from flash_attn.layers.rotary import apply_rotary_emb
|
||||||
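# Hedged illustration of the placeholder pattern introduced above: pointing the
# old name at a never-instantiated sentinel class keeps existing
# isinstance(cache, SlidingWindowCache) checks valid, and they simply evaluate
# to False, which is the desired behavior when no sliding-window cache exists.
class _SlidingWindowCachePlaceholder:
    pass

SlidingWindowCache = _SlidingWindowCachePlaceholder
print(isinstance(object(), SlidingWindowCache))  # False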
@@ -594,19 +602,40 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_default_rope_parameters_qwen2_5_vl(config, device=None):
|
||||||
|
"""
|
||||||
|
Compute the default RoPE inverse frequencies and attention scaling for Qwen2_5_VL.
|
||||||
|
"""
|
||||||
|
base = config.text_config.rope_parameters["rope_theta"]
|
||||||
|
dim = config.hidden_size // config.num_attention_heads
|
||||||
|
inv_freq = 1.0 / (
|
||||||
|
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
||||||
|
)
|
||||||
|
return inv_freq, 1.0
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: Qwen2_5_VLConfig, device=None):
|
def __init__(self, config: Qwen2_5_VLConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
|
elif hasattr(config, "rope_parameters") and config.rope_parameters is not None:
|
||||||
|
self.rope_type = config.rope_parameters.get("rope_type", "default")
|
||||||
else:
|
else:
|
||||||
self.rope_type = "default"
|
self.rope_type = "default"
|
||||||
self.max_seq_len_cached = config.max_position_embeddings
|
self.max_seq_len_cached = config.max_position_embeddings
|
||||||
self.original_max_seq_len = config.max_position_embeddings
|
self.original_max_seq_len = config.max_position_embeddings
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
|
||||||
|
if self.rope_type == "default":
|
||||||
|
self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl
|
||||||
|
self.rope_kwargs = {}
|
||||||
|
else:
|
||||||
|
rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type
|
||||||
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key]
|
||||||
|
self.rope_kwargs = {}
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
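# Worked toy example of the default RoPE parameters computed above:
# inv_freq[i] = 1 / base**(2i / dim) for even channel indices, with an
# attention scaling of 1.0 on the default path.
import torch

base, dim = 10000.0, 8
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq)  # 1.0, 0.1, 0.01, 0.001
attention_scaling = 1.0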
|
|||||||
@@ -144,7 +144,7 @@ def preprocesser_call(
|
|||||||
"""
|
"""
|
||||||
# Process image inputs
|
# Process image inputs
|
||||||
if images is not None and len(images) > 0:
|
if images is not None and len(images) > 0:
|
||||||
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
image_inputs = processor.image_processor(images=images, return_tensors=return_tensors)
|
||||||
image_grid_thw = image_inputs["image_grid_thw"]
|
image_grid_thw = image_inputs["image_grid_thw"]
|
||||||
else:
|
else:
|
||||||
image_inputs = {}
|
image_inputs = {}
|
||||||
@@ -152,7 +152,7 @@ def preprocesser_call(
|
|||||||
|
|
||||||
# Process video inputs
|
# Process video inputs
|
||||||
if videos is not None:
|
if videos is not None:
|
||||||
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors)
|
||||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||||
else:
|
else:
|
||||||
videos_inputs = {}
|
videos_inputs = {}
|
||||||
|
|||||||
@@ -13,12 +13,9 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
""" Florence-2 configuration"""
|
""" Florence-2 configuration"""
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Florence2VisionConfig(PretrainedConfig):
|
class Florence2VisionConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
@@ -276,6 +273,8 @@ class Florence2LanguageConfig(PretrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ensure backward compatibility for BART CNN models
|
# ensure backward compatibility for BART CNN models
|
||||||
|
if not hasattr(self, "forced_bos_token_id"):
|
||||||
|
self.forced_bos_token_id = None
|
||||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||||
self.forced_bos_token_id = self.bos_token_id
|
self.forced_bos_token_id = self.bos_token_id
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ from transformers.utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
is_flash_attn_greater_or_equal_2_10,
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
logging,
|
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -57,8 +56,6 @@ if is_flash_attn_2_available():
|
|||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "Florence2Config"
|
_CONFIG_FOR_DOC = "Florence2Config"
|
||||||
|
|
||||||
|
|
||||||
@@ -992,12 +989,6 @@ class Florence2FlashAttention2(Florence2Attention):
|
|||||||
else:
|
else:
|
||||||
target_dtype = self.q_proj.weight.dtype
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
logger.warning_once(
|
|
||||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
|
||||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
|
||||||
f" {target_dtype}."
|
|
||||||
)
|
|
||||||
|
|
||||||
query_states = query_states.to(target_dtype)
|
query_states = query_states.to(target_dtype)
|
||||||
key_states = key_states.to(target_dtype)
|
key_states = key_states.to(target_dtype)
|
||||||
value_states = value_states.to(target_dtype)
|
value_states = value_states.to(target_dtype)
|
||||||
@@ -1135,11 +1126,6 @@ class Florence2SdpaAttention(Florence2Attention):
|
|||||||
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
if output_attentions or layer_head_mask is not None:
|
if output_attentions or layer_head_mask is not None:
|
||||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
|
||||||
logger.warning_once(
|
|
||||||
"Florence2Model is using Florence2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
|
||||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
|
||||||
)
|
|
||||||
         return super().forward(
             hidden_states,
             key_value_states=key_value_states,
@@ -1860,9 +1846,6 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
         if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
             use_cache = False
 
         # decoder layers
@@ -1951,7 +1934,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
 
 
 class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = {
+        "encoder.embed_tokens.weight": "shared.weight",
+        "decoder.embed_tokens.weight": "shared.weight",
+    }
 
     def __init__(self, config: Florence2LanguageConfig):
         super().__init__(config)
@@ -2076,7 +2062,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
 
 class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = {
+        "model.encoder.embed_tokens.weight": "model.shared.weight",
+        "model.decoder.embed_tokens.weight": "model.shared.weight",
+    }
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
 
     def __init__(self, config: Florence2LanguageConfig):
@@ -2154,8 +2143,6 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
-            if use_cache:
-                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
             use_cache = False
             if decoder_input_ids is None and decoder_inputs_embeds is None:
                 decoder_input_ids = shift_tokens_right(
@@ -2436,11 +2423,10 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
     FLORENCE2_START_DOCSTRING,
 )
 class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
-    _tied_weights_keys = [
-        "language_model.encoder.embed_tokens.weight",
-        "language_model.decoder.embed_tokens.weight",
-        "language_model.lm_head.weight",
-    ]
+    _tied_weights_keys = {
+        "language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
+        "language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
+    }
 
     def __init__(self, config: Florence2Config):
         super().__init__(config)
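Note on the hunks above: `_tied_weights_keys` moves from a flat list of tied parameter names to a mapping from each tied parameter to the weight whose storage it shares (here the `shared` token embedding). As a minimal, self-contained sketch of what such a tie means at the tensor level (the toy module and its names are illustrative, not part of the diff):

import torch
from torch import nn


class TinyTiedModel(nn.Module):
    """Toy model whose decoder embedding shares storage with a `shared` embedding."""

    def __init__(self, vocab_size: int = 10, dim: int = 4):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, dim)
        self.decoder_embed_tokens = nn.Embedding(vocab_size, dim)
        # This assignment is what an entry such as
        # "decoder.embed_tokens.weight": "shared.weight" declares.
        self.decoder_embed_tokens.weight = self.shared.weight


model = TinyTiedModel()
assert model.decoder_embed_tokens.weight.data_ptr() == model.shared.weight.data_ptr()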
@@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
     Requires the `transformers` library to be installed.
 
     Attributes:
-        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
+        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
         tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
        trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
        action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
@@ -306,7 +306,7 @@ def train_fast_tokenizer(
 
     # download the tokenizer source code (not pretrained weights)
     # we'll train a new tokenizer on our own data
-    base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
+    base_tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True)
 
     # convert action_chunks array to list of arrays (expected by .fit())
     action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
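The repository id swap above only changes where the base FAST processor definition is downloaded from; the surrounding training flow is unchanged. A minimal sketch of that flow under stated assumptions (the `action_chunks` array, its shape, and `.fit()` returning the fitted tokenizer follow the upstream FAST usage and are illustrative here):

import numpy as np
from transformers import AutoProcessor

# Download the tokenizer source code (not pretrained weights); it is then fitted on our own data.
base_tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True)

# Hypothetical action chunks: 1000 chunks of 50 timesteps with 6-dim actions.
action_chunks = np.random.randn(1000, 50, 6).astype(np.float32)

# .fit() expects a list of arrays, one per chunk.
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
trained_tokenizer = base_tokenizer.fit(action_data_list)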
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pytest
 import torch
 
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@@ -37,6 +38,9 @@ def test_classifier_output():
 
 
 @require_package("transformers")
+@pytest.mark.skip(
+    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
+)
 def test_binary_classifier_with_default_params():
     from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
 
@@ -78,6 +82,9 @@ def test_binary_classifier_with_default_params():
 
 
 @require_package("transformers")
+@pytest.mark.skip(
+    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
+)
 def test_multiclass_classifier():
     from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
 
@@ -117,6 +124,9 @@ def test_multiclass_classifier():
 
 
 @require_package("transformers")
+@pytest.mark.skip(
+    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
+)
 def test_default_device():
     from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
 
@@ -129,6 +139,9 @@ def test_default_device():
 
 
 @require_package("transformers")
+@pytest.mark.skip(
+    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
+)
 def test_explicit_device_setup():
     from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
 
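The four hunks above apply the same pattern: an unconditional `pytest.mark.skip` stacked on the existing `require_package` guard, so the tests are still collected but reported as skipped with the stated reason. A standalone sketch of the pattern (the test name and body are placeholders):

import pytest


@pytest.mark.skip(
    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_placeholder_classifier():
    # Never executed while the skip marker is present.
    raise AssertionError("unreachable")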
@@ -17,7 +17,6 @@
 """Test script to verify PI0Fast policy integration with LeRobot vs the original implementation"""
 # ruff: noqa: E402
 
-import os
 import random
 from copy import deepcopy
 from typing import Any
@@ -28,10 +27,6 @@ import torch
 
 pytest.importorskip("transformers")
 pytest.importorskip("scipy")
-pytestmark = pytest.mark.skipif(
-    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
-    reason="This test requires accepting the model license",
-)
 
 from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
 from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy
@@ -54,19 +49,19 @@ IMAGE_HEIGHT = 224
 IMAGE_WIDTH = 224
 NUM_VIEWS = 2  # Number of camera views
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-MODEL_PATH_LEROBOT = "lerobot/pi0fast-base"
+MODEL_PATH_LEROBOT = "jadechoghari/pi0fast-base"
 
 # Expected action token shape: (batch_size, max_decoding_steps)
 EXPECTED_ACTION_TOKENS_SHAPE = (1, 2)
 
 # Expected first 5 action tokens (for reproducibility check)
-EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362])
+EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255425])
 
 # Expected actions after detokenization
 EXPECTED_ACTIONS_SHAPE = (1, 2, 32)  # (batch_size, n_action_steps, action_dim)
-EXPECTED_ACTIONS_MEAN = 0.04419417306780815
-EXPECTED_ACTIONS_STD = 0.26231569051742554
-EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000])
+EXPECTED_ACTIONS_MEAN = 0.046403881162405014
+EXPECTED_ACTIONS_STD = 0.2607129216194153
+EXPECTED_ACTIONS_FIRST_5 = torch.tensor([-0.0707, 1.4849, 0.0000, 0.0000, 0.0000])
 
 
 def set_seed_all(seed: int):
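The reference values updated above track the checkpoint swap to jadechoghari/pi0fast-base; in practice such constants are compared with tolerances rather than exact equality. A minimal sketch, assuming `actions` is the detokenized float32 policy output (the helper name and tolerances are illustrative, not taken from the test):

import torch

EXPECTED_ACTIONS_SHAPE = (1, 2, 32)
EXPECTED_ACTIONS_MEAN = 0.046403881162405014
EXPECTED_ACTIONS_STD = 0.2607129216194153
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([-0.0707, 1.4849, 0.0000, 0.0000, 0.0000])


def check_action_stats(actions: torch.Tensor) -> None:
    """Compare detokenized actions against the reference statistics with a small tolerance."""
    assert tuple(actions.shape) == EXPECTED_ACTIONS_SHAPE
    torch.testing.assert_close(actions.mean().item(), EXPECTED_ACTIONS_MEAN, atol=1e-3, rtol=0.0)
    torch.testing.assert_close(actions.std().item(), EXPECTED_ACTIONS_STD, atol=1e-3, rtol=0.0)
    torch.testing.assert_close(actions.flatten()[:5], EXPECTED_ACTIONS_FIRST_5, atol=1e-3, rtol=0.0)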
@@ -16,17 +16,8 @@
 
 """Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""
 
-import os
-
-import pytest
 import torch
 
-# Skip this entire module in CI
-pytestmark = pytest.mark.skipif(
-    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
-    reason="This test requires local OpenPI installation and is not meant for CI",
-)
-
 from lerobot.policies.factory import make_policy_config  # noqa: E402
 from lerobot.policies.pi0 import (  # noqa: E402
     PI0Config,
@@ -16,25 +16,15 @@
 
 """Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""
 
-import os
-
-import pytest
 import torch
-
-from lerobot.utils.random_utils import set_seed
-
-# Skip this entire module in CI
-pytestmark = pytest.mark.skipif(
-    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
-    reason="This test requires local OpenPI installation and is not meant for CI",
-)
 
 from lerobot.policies.factory import make_policy_config  # noqa: E402
 from lerobot.policies.pi05 import (  # noqa: E402
     PI05Config,
     PI05Policy,
     make_pi05_pre_post_processors,  # noqa: E402
 )
+from lerobot.utils.random_utils import set_seed
 from tests.utils import require_cuda  # noqa: E402
 
 
@@ -24,9 +24,10 @@ import torch
 # Skip this entire module in CI
 pytestmark = pytest.mark.skipif(
     os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
-    reason="This test requires local OpenPI installation and is not meant for CI",
+    reason="TODO: This test seems to hang the CI",
 )
 
+
 from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule  # noqa: E402
 from lerobot.policies.pi05 import PI05Config, PI05Policy, make_pi05_pre_post_processors  # noqa: E402
 from lerobot.policies.rtc.configuration_rtc import RTCConfig  # noqa: E402
@@ -24,9 +24,10 @@ import torch
 # Skip this entire module in CI
 pytestmark = pytest.mark.skipif(
     os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
-    reason="This test requires local OpenPI installation and is not meant for CI",
+    reason="TODO: This test seems to hang the CI",
 )
 
+
 from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule  # noqa: E402
 from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors  # noqa: E402
 from lerobot.policies.rtc.configuration_rtc import RTCConfig  # noqa: E402
@@ -305,6 +305,9 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di
     [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
 )
 @pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
+@pytest.mark.skip(
+    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
+)
 def test_sac_policy_with_pretrained_encoder(
     batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
 ):
@@ -16,8 +16,6 @@
 
 """Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
 
-import os
-
 import pytest
 import torch
 
@@ -26,12 +24,6 @@ pytest.importorskip("peft")
 pytest.importorskip("transformers")
 pytest.importorskip("torchdiffeq")
 
-# Skip this entire module in CI
-pytestmark = pytest.mark.skipif(
-    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
-    reason="This test requires local Wall-X installation and is not meant for CI",
-)
-
 from lerobot.policies.factory import make_policy_config  # noqa: E402
 from lerobot.policies.wall_x import WallXConfig  # noqa: E402
 from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy  # noqa: E402