mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
Merge branch 'chore/bump_transformers_v5' into ci/add_hf_account
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -52,7 +52,7 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
|
|||||||
|
|
||||||
You have two options for the FAST tokenizer:
|
You have two options for the FAST tokenizer:
|
||||||
|
|
||||||
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
||||||
|
|
||||||
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
|
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
|
||||||
|
|
||||||
@@ -115,13 +115,13 @@ lerobot-train \
|
|||||||
### Key Training Parameters
|
### Key Training Parameters
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
|
| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
|
||||||
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
||||||
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
||||||
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
||||||
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
||||||
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
||||||
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
|
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
|
||||||
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
||||||
|
|
||||||
## Inference
|
## Inference
|
||||||
|
|||||||
+11
-93
@@ -61,7 +61,7 @@ dependencies = [
|
|||||||
# Hugging Face dependencies
|
# Hugging Face dependencies
|
||||||
"datasets>=4.0.0,<5.0.0",
|
"datasets>=4.0.0,<5.0.0",
|
||||||
"diffusers>=0.27.2,<0.36.0",
|
"diffusers>=0.27.2,<0.36.0",
|
||||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
"huggingface-hub[cli]>=1.0.0,<2.0.0",
|
||||||
"accelerate>=1.10.0,<2.0.0",
|
"accelerate>=1.10.0,<2.0.0",
|
||||||
|
|
||||||
# Core dependencies
|
# Core dependencies
|
||||||
@@ -96,7 +96,7 @@ dependencies = [
|
|||||||
# Common
|
# Common
|
||||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
transformers-dep = ["transformers>=5.1.0,<6.0.0"]
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||||
|
|
||||||
@@ -129,13 +129,13 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
|||||||
|
|
||||||
# Policies
|
# Policies
|
||||||
wallx = [
|
wallx = [
|
||||||
"transformers==4.49.0",
|
"lerobot[transformers-dep]",
|
||||||
"peft==0.17.1",
|
"peft>=0.18.0,<1.0.0",
|
||||||
"scipy==1.15.3",
|
"scipy==1.15.3", # TODO: Relax version
|
||||||
"torchdiffeq==0.2.5",
|
"torchdiffeq==0.2.5", # TODO: Relax version
|
||||||
"qwen_vl_utils==0.0.11"
|
"qwen-vl-utils==0.0.11" # TODO: Relax version
|
||||||
]
|
]
|
||||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
|
pi = ["lerobot[transformers-dep]", "scipy==1.15.3"] # TODO: Relax scipy version
|
||||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||||
groot = [
|
groot = [
|
||||||
"lerobot[transformers-dep]",
|
"lerobot[transformers-dep]",
|
||||||
@@ -148,7 +148,7 @@ groot = [
|
|||||||
"ninja>=1.11.1,<2.0.0",
|
"ninja>=1.11.1,<2.0.0",
|
||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
|
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
@@ -176,8 +176,8 @@ all = [
|
|||||||
"lerobot[reachy2]",
|
"lerobot[reachy2]",
|
||||||
"lerobot[kinematics]",
|
"lerobot[kinematics]",
|
||||||
"lerobot[intelrealsense]",
|
"lerobot[intelrealsense]",
|
||||||
# "lerobot[wallx]",
|
"lerobot[wallx]",
|
||||||
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
|
"lerobot[pi]",
|
||||||
"lerobot[smolvla]",
|
"lerobot[smolvla]",
|
||||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||||
"lerobot[xvla]",
|
"lerobot[xvla]",
|
||||||
@@ -394,85 +394,3 @@ ignore_errors = false
|
|||||||
# [[tool.mypy.overrides]]
|
# [[tool.mypy.overrides]]
|
||||||
# module = "lerobot.scripts.*"
|
# module = "lerobot.scripts.*"
|
||||||
# ignore_errors = false
|
# ignore_errors = false
|
||||||
|
|
||||||
[tool.uv]
|
|
||||||
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
|
|
||||||
conflicts = [
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "transformers-dep" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "pi" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "smolvla" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "groot" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "xvla" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "sarm" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "hilserl" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "libero" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "peft" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "wallx" },
|
|
||||||
{ extra = "all" },
|
|
||||||
],
|
|
||||||
# pi uses custom branch which conflicts with transformers-dep
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "transformers-dep" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "smolvla" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "groot" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "xvla" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "sarm" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "hilserl" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "libero" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "peft" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ extra = "pi" },
|
|
||||||
{ extra = "all" },
|
|
||||||
],
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from transformers.image_processing_utils import (
|
|||||||
)
|
)
|
||||||
from transformers.image_processing_utils_fast import (
|
from transformers.image_processing_utils_fast import (
|
||||||
BaseImageProcessorFast,
|
BaseImageProcessorFast,
|
||||||
DefaultFastImageProcessorKwargs,
|
ImagesKwargs,
|
||||||
group_images_by_shape,
|
group_images_by_shape,
|
||||||
reorder_images,
|
reorder_images,
|
||||||
)
|
)
|
||||||
@@ -77,7 +77,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
|
|||||||
return img[:, top:bottom, left:right]
|
return img[:, top:bottom, left:right]
|
||||||
|
|
||||||
|
|
||||||
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
|
||||||
max_dynamic_tiles: int | None
|
max_dynamic_tiles: int | None
|
||||||
min_dynamic_tiles: int | None
|
min_dynamic_tiles: int | None
|
||||||
use_thumbnail: bool | None
|
use_thumbnail: bool | None
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@@ -32,13 +33,21 @@ from lerobot.utils.import_utils import _transformers_available
|
|||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
from transformers.models.gemma import modeling_gemma
|
from transformers.models.gemma import modeling_gemma
|
||||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
|
||||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
from lerobot.policies.pi_gemma import (
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||||
|
PiGemmaForCausalLM,
|
||||||
|
_gated_residual,
|
||||||
|
layernorm_forward,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
CONFIG_MAPPING = None
|
CONFIG_MAPPING = None
|
||||||
modeling_gemma = None
|
modeling_gemma = None
|
||||||
GemmaForCausalLM = None
|
PiGemmaForCausalLM = None
|
||||||
PaliGemmaForConditionalGeneration = None
|
_gated_residual = None
|
||||||
|
layernorm_forward = None
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||||
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||||
@@ -191,7 +200,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
if images.dtype == torch.uint8:
|
if images.dtype == torch.uint8:
|
||||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||||
elif images.dtype == torch.float32:
|
elif images.dtype == torch.float32:
|
||||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
resized_images = resized_images.clamp(0.0, 1.0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||||
|
|
||||||
@@ -202,7 +211,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
pad_w1 = pad_w0 + remainder_w
|
pad_w1 = pad_w0 + remainder_w
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||||
padded_images = F.pad(
|
padded_images = F.pad(
|
||||||
resized_images,
|
resized_images,
|
||||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||||
@@ -221,14 +230,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
def compute_layer_complete(
|
def compute_layer_complete(
|
||||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||||
):
|
):
|
||||||
models = [paligemma.language_model, gemma_expert.model]
|
models = [paligemma.model.language_model, gemma_expert.model]
|
||||||
query_states = []
|
query_states = []
|
||||||
key_states = []
|
key_states = []
|
||||||
value_states = []
|
value_states = []
|
||||||
gates = []
|
gates = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = models[i].layers[layer_idx]
|
layer = models[i].layers[layer_idx]
|
||||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||||
gates.append(gate)
|
gates.append(gate)
|
||||||
input_shape = hidden_states.shape[:-1]
|
input_shape = hidden_states.shape[:-1]
|
||||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||||
@@ -254,10 +263,10 @@ def compute_layer_complete(
|
|||||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||||
)
|
)
|
||||||
batch_size = query_states.shape[0]
|
batch_size = query_states.shape[0]
|
||||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||||
# Attention computation
|
# Attention computation
|
||||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||||
paligemma.language_model.layers[layer_idx].self_attn,
|
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
@@ -265,7 +274,7 @@ def compute_layer_complete(
|
|||||||
scaling,
|
scaling,
|
||||||
)
|
)
|
||||||
# Get head_dim from the current layer, not from the model
|
# Get head_dim from the current layer, not from the model
|
||||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||||
# Process layer outputs
|
# Process layer outputs
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
@@ -277,15 +286,15 @@ def compute_layer_complete(
|
|||||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||||
# first residual
|
# first residual
|
||||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||||
after_first_residual = out_emb.clone()
|
after_first_residual = out_emb.clone()
|
||||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||||
out_emb = layer.mlp(out_emb)
|
out_emb = layer.mlp(out_emb)
|
||||||
# second residual
|
# second residual
|
||||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
start_pos = end_pos
|
start_pos = end_pos
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
@@ -358,7 +367,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
vlm_config_hf.text_config.dtype = "float32"
|
||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
@@ -366,7 +375,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
vlm_config_hf.vision_config.dtype = "float32"
|
||||||
|
|
||||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||||
head_dim=action_expert_config.head_dim,
|
head_dim=action_expert_config.head_dim,
|
||||||
@@ -377,13 +386,13 @@ class PaliGemmaWithExpertModel(
|
|||||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||||
vocab_size=257152,
|
vocab_size=257152,
|
||||||
hidden_activation="gelu_pytorch_tanh",
|
hidden_activation="gelu_pytorch_tanh",
|
||||||
torch_dtype="float32",
|
dtype="float32",
|
||||||
use_adarms=use_adarms[1],
|
use_adarms=use_adarms[1],
|
||||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||||
self.gemma_expert.model.embed_tokens = None
|
self.gemma_expert.model.embed_tokens = None
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
@@ -398,10 +407,11 @@ class PaliGemmaWithExpertModel(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid precision: {precision}")
|
raise ValueError(f"Invalid precision: {precision}")
|
||||||
|
|
||||||
|
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||||
|
# "same dtype" error). Align with PI05.
|
||||||
params_to_keep_float32 = [
|
params_to_keep_float32 = [
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
"vision_tower",
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
"multi_modal_projector",
|
||||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
|
||||||
"input_layernorm",
|
"input_layernorm",
|
||||||
"post_attention_layernorm",
|
"post_attention_layernorm",
|
||||||
"model.norm",
|
"model.norm",
|
||||||
@@ -413,8 +423,8 @@ class PaliGemmaWithExpertModel(
|
|||||||
|
|
||||||
def _set_requires_grad(self):
|
def _set_requires_grad(self):
|
||||||
if self.freeze_vision_encoder:
|
if self.freeze_vision_encoder:
|
||||||
self.paligemma.vision_tower.eval()
|
self.paligemma.model.vision_tower.eval()
|
||||||
for param in self.paligemma.vision_tower.parameters():
|
for param in self.paligemma.model.vision_tower.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
if self.train_expert_only:
|
if self.train_expert_only:
|
||||||
self.paligemma.eval()
|
self.paligemma.eval()
|
||||||
@@ -424,15 +434,23 @@ class PaliGemmaWithExpertModel(
|
|||||||
def train(self, mode: bool = True):
|
def train(self, mode: bool = True):
|
||||||
super().train(mode)
|
super().train(mode)
|
||||||
if self.freeze_vision_encoder:
|
if self.freeze_vision_encoder:
|
||||||
self.paligemma.vision_tower.eval()
|
self.paligemma.model.vision_tower.eval()
|
||||||
if self.train_expert_only:
|
if self.train_expert_only:
|
||||||
self.paligemma.eval()
|
self.paligemma.eval()
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
def embed_image(self, image: torch.Tensor):
|
||||||
return self.paligemma.model.get_image_features(image)
|
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||||
|
out_dtype = image.dtype
|
||||||
|
if image.dtype != torch.float32:
|
||||||
|
image = image.to(torch.float32)
|
||||||
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
|
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||||
|
if features.dtype != out_dtype:
|
||||||
|
features = features.to(out_dtype)
|
||||||
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -446,7 +464,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
if adarms_cond is None:
|
if adarms_cond is None:
|
||||||
adarms_cond = [None, None]
|
adarms_cond = [None, None]
|
||||||
if inputs_embeds[1] is None:
|
if inputs_embeds[1] is None:
|
||||||
prefix_output = self.paligemma.language_model.forward(
|
prefix_output = self.paligemma.model.language_model.forward(
|
||||||
inputs_embeds=inputs_embeds[0],
|
inputs_embeds=inputs_embeds[0],
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -470,7 +488,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
prefix_output = None
|
prefix_output = None
|
||||||
prefix_past_key_values = None
|
prefix_past_key_values = None
|
||||||
else:
|
else:
|
||||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||||
|
|
||||||
# Check if gradient checkpointing is enabled for any of the models
|
# Check if gradient checkpointing is enabled for any of the models
|
||||||
@@ -510,7 +528,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
|
|
||||||
@@ -576,29 +594,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Also compile the main forward pass used during training
|
# Also compile the main forward pass used during training
|
||||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||||
|
|
||||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.models.siglip import check
|
|
||||||
|
|
||||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
|
||||||
raise ValueError(msg)
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(msg) from None
|
|
||||||
|
|
||||||
def gradient_checkpointing_enable(self):
|
def gradient_checkpointing_enable(self):
|
||||||
"""Enable gradient checkpointing for memory optimization."""
|
"""Enable gradient checkpointing for memory optimization."""
|
||||||
self.gradient_checkpointing_enabled = True
|
self.gradient_checkpointing_enabled = True
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
||||||
|
|
||||||
def gradient_checkpointing_disable(self):
|
def gradient_checkpointing_disable(self):
|
||||||
"""Disable gradient checkpointing."""
|
"""Disable gradient checkpointing."""
|
||||||
self.gradient_checkpointing_enabled = False
|
self.gradient_checkpointing_enabled = False
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||||
|
|
||||||
@@ -760,7 +768,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -834,7 +842,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
_, past_key_values = self.paligemma_with_expert.forward(
|
_, past_key_values = self.paligemma_with_expert.forward(
|
||||||
attention_mask=prefix_att_2d_masks_4d,
|
attention_mask=prefix_att_2d_masks_4d,
|
||||||
@@ -908,6 +916,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
|
past_key_values = copy.deepcopy(past_key_values)
|
||||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=full_att_2d_masks_4d,
|
attention_mask=full_att_2d_masks_4d,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -997,14 +1006,12 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
# Check if dataset_stats were provided in kwargs
|
# Check if dataset_stats were provided in kwargs
|
||||||
model = cls(config, **kwargs)
|
model = cls(config, **kwargs)
|
||||||
|
|
||||||
# Now manually load and remap the state dict
|
# Load state dict (expects keys with "model." prefix)
|
||||||
try:
|
try:
|
||||||
# Try to load the pytorch_model.bin or model.safetensors file
|
|
||||||
print(f"Loading model from: {pretrained_name_or_path}")
|
print(f"Loading model from: {pretrained_name_or_path}")
|
||||||
try:
|
try:
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
# Try safetensors first
|
|
||||||
resolved_file = cached_file(
|
resolved_file = cached_file(
|
||||||
pretrained_name_or_path,
|
pretrained_name_or_path,
|
||||||
"model.safetensors",
|
"model.safetensors",
|
||||||
@@ -1012,7 +1019,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
force_download=kwargs.get("force_download", False),
|
force_download=kwargs.get("force_download", False),
|
||||||
resume_download=kwargs.get("resume_download"),
|
resume_download=kwargs.get("resume_download"),
|
||||||
proxies=kwargs.get("proxies"),
|
proxies=kwargs.get("proxies"),
|
||||||
use_auth_token=kwargs.get("use_auth_token"),
|
token=kwargs.get("token"),
|
||||||
revision=kwargs.get("revision"),
|
revision=kwargs.get("revision"),
|
||||||
local_files_only=kwargs.get("local_files_only", False),
|
local_files_only=kwargs.get("local_files_only", False),
|
||||||
)
|
)
|
||||||
@@ -1025,7 +1032,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
print("Returning model without loading pretrained weights")
|
print("Returning model without loading pretrained weights")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||||
|
|
||||||
# Then add "model." prefix for all keys that don't already have it
|
# Then add "model." prefix for all keys that don't already have it
|
||||||
@@ -1070,7 +1077,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
print("All keys loaded successfully!")
|
print("All keys loaded successfully!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Could not remap state dict keys: {e}")
|
print(f"Warning: Could not load state dict: {e}")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -1120,6 +1127,14 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
# Some checkpoints might have this, but current model expects different structure
|
# Some checkpoints might have this, but current model expects different structure
|
||||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||||
|
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||||
|
):
|
||||||
|
fixed_state_dict[
|
||||||
|
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||||
|
] = value.clone()
|
||||||
|
|
||||||
fixed_state_dict[new_key] = value
|
fixed_state_dict[new_key] = value
|
||||||
|
|
||||||
return fixed_state_dict
|
return fixed_state_dict
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@@ -32,14 +33,20 @@ from lerobot.utils.import_utils import _transformers_available
|
|||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
from transformers.models.gemma import modeling_gemma
|
from transformers.models.gemma import modeling_gemma
|
||||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
|
||||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
from lerobot.policies.pi_gemma import (
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||||
|
PiGemmaForCausalLM,
|
||||||
|
_gated_residual,
|
||||||
|
layernorm_forward,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
CONFIG_MAPPING = None
|
CONFIG_MAPPING = None
|
||||||
modeling_gemma = None
|
modeling_gemma = None
|
||||||
GemmaForCausalLM = None
|
PiGemmaForCausalLM = None
|
||||||
PaliGemmaForConditionalGeneration = None
|
_gated_residual = None
|
||||||
|
layernorm_forward = None
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
@@ -189,7 +196,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
if images.dtype == torch.uint8:
|
if images.dtype == torch.uint8:
|
||||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||||
elif images.dtype == torch.float32:
|
elif images.dtype == torch.float32:
|
||||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
resized_images = resized_images.clamp(0.0, 1.0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||||
|
|
||||||
@@ -200,7 +207,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
pad_w1 = pad_w0 + remainder_w
|
pad_w1 = pad_w0 + remainder_w
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||||
padded_images = F.pad(
|
padded_images = F.pad(
|
||||||
resized_images,
|
resized_images,
|
||||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||||
@@ -219,14 +226,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
def compute_layer_complete(
|
def compute_layer_complete(
|
||||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||||
):
|
):
|
||||||
models = [paligemma.language_model, gemma_expert.model]
|
models = [paligemma.model.language_model, gemma_expert.model]
|
||||||
query_states = []
|
query_states = []
|
||||||
key_states = []
|
key_states = []
|
||||||
value_states = []
|
value_states = []
|
||||||
gates = []
|
gates = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = models[i].layers[layer_idx]
|
layer = models[i].layers[layer_idx]
|
||||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||||
gates.append(gate)
|
gates.append(gate)
|
||||||
input_shape = hidden_states.shape[:-1]
|
input_shape = hidden_states.shape[:-1]
|
||||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||||
@@ -252,10 +259,10 @@ def compute_layer_complete(
|
|||||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||||
)
|
)
|
||||||
batch_size = query_states.shape[0]
|
batch_size = query_states.shape[0]
|
||||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||||
# Attention computation
|
# Attention computation
|
||||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||||
paligemma.language_model.layers[layer_idx].self_attn,
|
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
@@ -263,7 +270,7 @@ def compute_layer_complete(
|
|||||||
scaling,
|
scaling,
|
||||||
)
|
)
|
||||||
# Get head_dim from the current layer, not from the model
|
# Get head_dim from the current layer, not from the model
|
||||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||||
# Process layer outputs
|
# Process layer outputs
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
@@ -275,15 +282,15 @@ def compute_layer_complete(
|
|||||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||||
# first residual
|
# first residual
|
||||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||||
after_first_residual = out_emb.clone()
|
after_first_residual = out_emb.clone()
|
||||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||||
out_emb = layer.mlp(out_emb)
|
out_emb = layer.mlp(out_emb)
|
||||||
# second residual
|
# second residual
|
||||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
start_pos = end_pos
|
start_pos = end_pos
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
@@ -356,7 +363,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
vlm_config_hf.text_config.dtype = "float32"
|
||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
@@ -364,7 +371,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
vlm_config_hf.vision_config.dtype = "float32"
|
||||||
|
|
||||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||||
head_dim=action_expert_config.head_dim,
|
head_dim=action_expert_config.head_dim,
|
||||||
@@ -375,13 +382,13 @@ class PaliGemmaWithExpertModel(
|
|||||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||||
vocab_size=257152,
|
vocab_size=257152,
|
||||||
hidden_activation="gelu_pytorch_tanh",
|
hidden_activation="gelu_pytorch_tanh",
|
||||||
torch_dtype="float32",
|
dtype="float32",
|
||||||
use_adarms=use_adarms[1],
|
use_adarms=use_adarms[1],
|
||||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||||
self.gemma_expert.model.embed_tokens = None
|
self.gemma_expert.model.embed_tokens = None
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
@@ -396,10 +403,11 @@ class PaliGemmaWithExpertModel(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid precision: {precision}")
|
raise ValueError(f"Invalid precision: {precision}")
|
||||||
|
|
||||||
|
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||||
|
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
|
||||||
params_to_keep_float32 = [
|
params_to_keep_float32 = [
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
"vision_tower",
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
"multi_modal_projector",
|
||||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
|
||||||
"input_layernorm",
|
"input_layernorm",
|
||||||
"post_attention_layernorm",
|
"post_attention_layernorm",
|
||||||
"model.norm",
|
"model.norm",
|
||||||
@@ -411,8 +419,8 @@ class PaliGemmaWithExpertModel(
|
|||||||
|
|
||||||
def _set_requires_grad(self):
|
def _set_requires_grad(self):
|
||||||
if self.freeze_vision_encoder:
|
if self.freeze_vision_encoder:
|
||||||
self.paligemma.vision_tower.eval()
|
self.paligemma.model.vision_tower.eval()
|
||||||
for param in self.paligemma.vision_tower.parameters():
|
for param in self.paligemma.model.vision_tower.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
if self.train_expert_only:
|
if self.train_expert_only:
|
||||||
self.paligemma.eval()
|
self.paligemma.eval()
|
||||||
@@ -422,15 +430,23 @@ class PaliGemmaWithExpertModel(
|
|||||||
def train(self, mode: bool = True):
|
def train(self, mode: bool = True):
|
||||||
super().train(mode)
|
super().train(mode)
|
||||||
if self.freeze_vision_encoder:
|
if self.freeze_vision_encoder:
|
||||||
self.paligemma.vision_tower.eval()
|
self.paligemma.model.vision_tower.eval()
|
||||||
if self.train_expert_only:
|
if self.train_expert_only:
|
||||||
self.paligemma.eval()
|
self.paligemma.eval()
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
def embed_image(self, image: torch.Tensor):
|
||||||
return self.paligemma.model.get_image_features(image)
|
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
|
||||||
|
out_dtype = image.dtype
|
||||||
|
if image.dtype != torch.float32:
|
||||||
|
image = image.to(torch.float32)
|
||||||
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
|
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||||
|
if features.dtype != out_dtype:
|
||||||
|
features = features.to(out_dtype)
|
||||||
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -444,7 +460,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
if adarms_cond is None:
|
if adarms_cond is None:
|
||||||
adarms_cond = [None, None]
|
adarms_cond = [None, None]
|
||||||
if inputs_embeds[1] is None:
|
if inputs_embeds[1] is None:
|
||||||
prefix_output = self.paligemma.language_model.forward(
|
prefix_output = self.paligemma.model.language_model.forward(
|
||||||
inputs_embeds=inputs_embeds[0],
|
inputs_embeds=inputs_embeds[0],
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -468,7 +484,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
prefix_output = None
|
prefix_output = None
|
||||||
prefix_past_key_values = None
|
prefix_past_key_values = None
|
||||||
else:
|
else:
|
||||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||||
|
|
||||||
# Check if gradient checkpointing is enabled for any of the models
|
# Check if gradient checkpointing is enabled for any of the models
|
||||||
@@ -508,7 +524,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
|
|
||||||
@@ -573,29 +589,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Also compile the main forward pass used during training
|
# Also compile the main forward pass used during training
|
||||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||||
|
|
||||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.models.siglip import check
|
|
||||||
|
|
||||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
|
||||||
raise ValueError(msg)
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(msg) from None
|
|
||||||
|
|
||||||
def gradient_checkpointing_enable(self):
|
def gradient_checkpointing_enable(self):
|
||||||
"""Enable gradient checkpointing for memory optimization."""
|
"""Enable gradient checkpointing for memory optimization."""
|
||||||
self.gradient_checkpointing_enabled = True
|
self.gradient_checkpointing_enabled = True
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||||
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
||||||
|
|
||||||
def gradient_checkpointing_disable(self):
|
def gradient_checkpointing_disable(self):
|
||||||
"""Disable gradient checkpointing."""
|
"""Disable gradient checkpointing."""
|
||||||
self.gradient_checkpointing_enabled = False
|
self.gradient_checkpointing_enabled = False
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||||
|
|
||||||
@@ -737,7 +743,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -808,7 +814,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
_, past_key_values = self.paligemma_with_expert.forward(
|
_, past_key_values = self.paligemma_with_expert.forward(
|
||||||
attention_mask=prefix_att_2d_masks_4d,
|
attention_mask=prefix_att_2d_masks_4d,
|
||||||
@@ -880,6 +886,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
|
past_key_values = copy.deepcopy(past_key_values)
|
||||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=full_att_2d_masks_4d,
|
attention_mask=full_att_2d_masks_4d,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -969,14 +976,12 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
# Check if dataset_stats were provided in kwargs
|
# Check if dataset_stats were provided in kwargs
|
||||||
model = cls(config, **kwargs)
|
model = cls(config, **kwargs)
|
||||||
|
|
||||||
# Now manually load and remap the state dict
|
# Load state dict (expects keys with "model." prefix)
|
||||||
try:
|
try:
|
||||||
# Try to load the pytorch_model.bin or model.safetensors file
|
|
||||||
print(f"Loading model from: {pretrained_name_or_path}")
|
print(f"Loading model from: {pretrained_name_or_path}")
|
||||||
try:
|
try:
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
# Try safetensors first
|
|
||||||
resolved_file = cached_file(
|
resolved_file = cached_file(
|
||||||
pretrained_name_or_path,
|
pretrained_name_or_path,
|
||||||
"model.safetensors",
|
"model.safetensors",
|
||||||
@@ -984,7 +989,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
force_download=kwargs.get("force_download", False),
|
force_download=kwargs.get("force_download", False),
|
||||||
resume_download=kwargs.get("resume_download"),
|
resume_download=kwargs.get("resume_download"),
|
||||||
proxies=kwargs.get("proxies"),
|
proxies=kwargs.get("proxies"),
|
||||||
use_auth_token=kwargs.get("use_auth_token"),
|
token=kwargs.get("token"),
|
||||||
revision=kwargs.get("revision"),
|
revision=kwargs.get("revision"),
|
||||||
local_files_only=kwargs.get("local_files_only", False),
|
local_files_only=kwargs.get("local_files_only", False),
|
||||||
)
|
)
|
||||||
@@ -997,7 +1002,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
print("Returning model without loading pretrained weights")
|
print("Returning model without loading pretrained weights")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||||
|
|
||||||
# Then add "model." prefix for all keys that don't already have it
|
# Then add "model." prefix for all keys that don't already have it
|
||||||
@@ -1009,8 +1014,6 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
new_key = f"model.{key}"
|
new_key = f"model.{key}"
|
||||||
remapped_state_dict[new_key] = value
|
remapped_state_dict[new_key] = value
|
||||||
remap_count += 1
|
remap_count += 1
|
||||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
|
||||||
print(f"Remapped: {key} -> {new_key}")
|
|
||||||
else:
|
else:
|
||||||
remapped_state_dict[key] = value
|
remapped_state_dict[key] = value
|
||||||
|
|
||||||
@@ -1044,7 +1047,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
print("All keys loaded successfully!")
|
print("All keys loaded successfully!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Could not remap state dict keys: {e}")
|
print(f"Warning: Could not load state dict: {e}")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -1098,6 +1101,14 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
# Some checkpoints might have this, but current model expects different structure
|
# Some checkpoints might have this, but current model expects different structure
|
||||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||||
|
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||||
|
):
|
||||||
|
fixed_state_dict[
|
||||||
|
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||||
|
] = value.clone()
|
||||||
|
|
||||||
fixed_state_dict[new_key] = value
|
fixed_state_dict[new_key] = value
|
||||||
|
|
||||||
return fixed_state_dict
|
return fixed_state_dict
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import torch
|
|||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
@@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
|||||||
# TODO: check if this necessary
|
# TODO: check if this necessary
|
||||||
state = deepcopy(state)
|
state = deepcopy(state)
|
||||||
|
|
||||||
# Prepare state (pad to max_state_dim)
|
|
||||||
state = pad_vector(state, self.max_state_dim)
|
|
||||||
|
|
||||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||||
state_np = state.cpu().numpy()
|
state_np = state.cpu().numpy()
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||||
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
|
||||||
temperature: float = 0.0
|
temperature: float = 0.0
|
||||||
max_decoding_steps: int = 256
|
max_decoding_steps: int = 256
|
||||||
fast_skip_tokens: int = 128
|
fast_skip_tokens: int = 128
|
||||||
|
|||||||
@@ -38,11 +38,16 @@ else:
|
|||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
|
||||||
|
from lerobot.policies.pi_gemma import (
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||||
|
PiGemmaModel,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
CONFIG_MAPPING = None
|
CONFIG_MAPPING = None
|
||||||
PaliGemmaForConditionalGeneration = None
|
|
||||||
AutoTokenizer = None
|
AutoTokenizer = None
|
||||||
|
PiGemmaModel = None
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||||
@@ -121,7 +126,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
if images.dtype == torch.uint8:
|
if images.dtype == torch.uint8:
|
||||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||||
elif images.dtype == torch.float32:
|
elif images.dtype == torch.float32:
|
||||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
resized_images = resized_images.clamp(0.0, 1.0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||||
|
|
||||||
@@ -132,7 +137,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
pad_w1 = pad_w0 + remainder_w
|
pad_w1 = pad_w0 + remainder_w
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||||
padded_images = F.pad(
|
padded_images = F.pad(
|
||||||
resized_images,
|
resized_images,
|
||||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||||
@@ -206,16 +211,22 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
vlm_config_hf.text_config.dtype = "float32"
|
||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
vlm_config_hf.vision_config.dtype = "float32"
|
||||||
|
|
||||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||||
|
|
||||||
|
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
|
||||||
|
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||||
|
if use_adarms[0]:
|
||||||
|
text_config = self.paligemma.config.text_config
|
||||||
|
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
|
|
||||||
@@ -228,10 +239,11 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid precision: {precision}")
|
raise ValueError(f"Invalid precision: {precision}")
|
||||||
|
|
||||||
|
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||||
|
# "same dtype" error). Align with PI05.
|
||||||
params_to_keep_float32 = [
|
params_to_keep_float32 = [
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
"vision_tower",
|
||||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
"multi_modal_projector",
|
||||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
|
||||||
"input_layernorm",
|
"input_layernorm",
|
||||||
"post_attention_layernorm",
|
"post_attention_layernorm",
|
||||||
"model.norm",
|
"model.norm",
|
||||||
@@ -242,10 +254,18 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
param.data = param.data.to(dtype=torch.float32)
|
param.data = param.data.to(dtype=torch.float32)
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
def embed_image(self, image: torch.Tensor):
|
||||||
return self.paligemma.model.get_image_features(image)
|
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||||
|
out_dtype = image.dtype
|
||||||
|
if image.dtype != torch.float32:
|
||||||
|
image = image.to(torch.float32)
|
||||||
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
|
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||||
|
if features.dtype != out_dtype:
|
||||||
|
features = features.to(out_dtype)
|
||||||
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -259,7 +279,7 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
if adarms_cond is None:
|
if adarms_cond is None:
|
||||||
adarms_cond = [None, None]
|
adarms_cond = [None, None]
|
||||||
if inputs_embeds[1] is None:
|
if inputs_embeds[1] is None:
|
||||||
prefix_output = self.paligemma.language_model.forward(
|
prefix_output = self.paligemma.model.language_model.forward(
|
||||||
inputs_embeds=inputs_embeds[0],
|
inputs_embeds=inputs_embeds[0],
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -306,24 +326,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
|
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
|
||||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||||
|
|
||||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.models.siglip import check
|
|
||||||
|
|
||||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
|
||||||
raise ValueError(msg)
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(msg) from None
|
|
||||||
|
|
||||||
def gradient_checkpointing_enable(self):
|
def gradient_checkpointing_enable(self):
|
||||||
"""Enable gradient checkpointing for memory optimization."""
|
"""Enable gradient checkpointing for memory optimization."""
|
||||||
self.gradient_checkpointing_enabled = True
|
self.gradient_checkpointing_enabled = True
|
||||||
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
|
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
|
||||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||||
)
|
)
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
|
||||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||||
)
|
)
|
||||||
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
|
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
|
||||||
@@ -332,8 +342,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
"""Disable gradient checkpointing."""
|
"""Disable gradient checkpointing."""
|
||||||
self.gradient_checkpointing_enabled = False
|
self.gradient_checkpointing_enabled = False
|
||||||
# Call the proper gradient_checkpointing_disable() method
|
# Call the proper gradient_checkpointing_disable() method
|
||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
|
||||||
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
|
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
|
||||||
|
|
||||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||||
@@ -523,7 +533,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
# Convert embeddings to bfloat16 if needed
|
# Convert embeddings to bfloat16 if needed
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -616,7 +626,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -714,7 +724,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
# Ensure correct precision (bfloat16/float32)
|
# Ensure correct precision (bfloat16/float32)
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
@@ -897,14 +907,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
# Check if dataset_stats were provided in kwargs
|
# Check if dataset_stats were provided in kwargs
|
||||||
model = cls(config, **kwargs)
|
model = cls(config, **kwargs)
|
||||||
|
|
||||||
# Now manually load and remap the state dict
|
# Load state dict (expects keys with "model." prefix)
|
||||||
try:
|
try:
|
||||||
# Try to load the pytorch_model.bin or model.safetensors file
|
|
||||||
print(f"Loading model from: {pretrained_name_or_path}")
|
print(f"Loading model from: {pretrained_name_or_path}")
|
||||||
try:
|
try:
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
# Try safetensors first
|
|
||||||
resolved_file = cached_file(
|
resolved_file = cached_file(
|
||||||
pretrained_name_or_path,
|
pretrained_name_or_path,
|
||||||
"model.safetensors",
|
"model.safetensors",
|
||||||
@@ -912,7 +920,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
force_download=kwargs.get("force_download", False),
|
force_download=kwargs.get("force_download", False),
|
||||||
resume_download=kwargs.get("resume_download"),
|
resume_download=kwargs.get("resume_download"),
|
||||||
proxies=kwargs.get("proxies"),
|
proxies=kwargs.get("proxies"),
|
||||||
use_auth_token=kwargs.get("use_auth_token"),
|
token=kwargs.get("token"),
|
||||||
revision=kwargs.get("revision"),
|
revision=kwargs.get("revision"),
|
||||||
local_files_only=kwargs.get("local_files_only", False),
|
local_files_only=kwargs.get("local_files_only", False),
|
||||||
)
|
)
|
||||||
@@ -925,8 +933,9 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
print("Returning model without loading pretrained weights")
|
print("Returning model without loading pretrained weights")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||||
|
|
||||||
# Then add "model." prefix for all keys that don't already have it
|
# Then add "model." prefix for all keys that don't already have it
|
||||||
remapped_state_dict = {}
|
remapped_state_dict = {}
|
||||||
remap_count = 0
|
remap_count = 0
|
||||||
@@ -936,8 +945,6 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
new_key = f"model.{key}"
|
new_key = f"model.{key}"
|
||||||
remapped_state_dict[new_key] = value
|
remapped_state_dict[new_key] = value
|
||||||
remap_count += 1
|
remap_count += 1
|
||||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
|
||||||
print(f"Remapped: {key} -> {new_key}")
|
|
||||||
else:
|
else:
|
||||||
remapped_state_dict[key] = value
|
remapped_state_dict[key] = value
|
||||||
|
|
||||||
@@ -971,7 +978,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
print("All keys loaded successfully!")
|
print("All keys loaded successfully!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Could not remap state dict keys: {e}")
|
print(f"Warning: Could not load state dict: {e}")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import torch
|
|||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||||
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
ActionTokenizerProcessorStep,
|
ActionTokenizerProcessorStep,
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
|||||||
# TODO: check if this necessary
|
# TODO: check if this necessary
|
||||||
state = deepcopy(state)
|
state = deepcopy(state)
|
||||||
|
|
||||||
# Prepare state (pad to max_state_dim)
|
|
||||||
state = pad_vector(state, self.max_state_dim)
|
|
||||||
|
|
||||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||||
state_np = state.cpu().numpy()
|
state_np = state.cpu().numpy()
|
||||||
|
|||||||
@@ -0,0 +1,363 @@
|
|||||||
|
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.cache_utils import DynamicCache
|
||||||
|
from transformers.masking_utils import create_causal_mask
|
||||||
|
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.gemma.modeling_gemma import (
|
||||||
|
GemmaAttention,
|
||||||
|
GemmaConfig,
|
||||||
|
GemmaForCausalLM,
|
||||||
|
GemmaMLP,
|
||||||
|
GemmaModel,
|
||||||
|
)
|
||||||
|
from transformers.models.paligemma.modeling_paligemma import (
|
||||||
|
PaliGemmaForConditionalGeneration,
|
||||||
|
PaliGemmaModel,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
GemmaAttention = None
|
||||||
|
GemmaConfig = None
|
||||||
|
GemmaForCausalLM = None
|
||||||
|
GemmaMLP = None
|
||||||
|
GemmaModel = None
|
||||||
|
PaliGemmaModel = None
|
||||||
|
PaliGemmaForConditionalGeneration = None
|
||||||
|
DynamicCache = None
|
||||||
|
GradientCheckpointingLayer = None
|
||||||
|
BaseModelOutputWithPast = None
|
||||||
|
create_causal_mask = None
|
||||||
|
|
||||||
|
|
||||||
|
def _gated_residual(
|
||||||
|
x: torch.Tensor | None,
|
||||||
|
y: torch.Tensor | None,
|
||||||
|
gate: torch.Tensor | None,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
"""Gated residual: x + y when gate is None, else x + y * gate."""
|
||||||
|
if x is None and y is None:
|
||||||
|
return None
|
||||||
|
if x is None or y is None:
|
||||||
|
return x if x is not None else y
|
||||||
|
if gate is None:
|
||||||
|
return x + y
|
||||||
|
return x + y * gate
|
||||||
|
|
||||||
|
|
||||||
|
def layernorm_forward(
|
||||||
|
layernorm: nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cond: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
call layernorm and return hidden states and gate
|
||||||
|
if cond is not None, use conditional norm
|
||||||
|
otherwise, use normal gemma norm
|
||||||
|
"""
|
||||||
|
if cond is not None:
|
||||||
|
return layernorm(x, cond=cond)
|
||||||
|
else:
|
||||||
|
return layernorm(x)
|
||||||
|
|
||||||
|
|
||||||
|
class PiGemmaRMSNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
Adaptive RMSNorm for PI Gemma (AdaRMS).
|
||||||
|
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
|
||||||
|
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.dim = dim
|
||||||
|
self.cond_dim = cond_dim
|
||||||
|
if cond_dim is not None:
|
||||||
|
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
|
||||||
|
nn.init.zeros_(self.dense.weight)
|
||||||
|
else:
|
||||||
|
self.weight = nn.Parameter(torch.zeros(dim))
|
||||||
|
self.dense = None
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
# Compute variance in float32 (like the source implementation)
|
||||||
|
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
|
||||||
|
# Compute normalization in float32
|
||||||
|
normed_inputs = x * torch.rsqrt(var + self.eps)
|
||||||
|
return normed_inputs
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cond: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
dtype = x.dtype
|
||||||
|
normed = self._norm(x)
|
||||||
|
if cond is None or self.dense is None:
|
||||||
|
normed = normed * (1.0 + self.weight.float())
|
||||||
|
return normed.type_as(x), None
|
||||||
|
if cond.shape[-1] != self.cond_dim:
|
||||||
|
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
|
||||||
|
modulation = self.dense(cond)
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
modulation = modulation.unsqueeze(1)
|
||||||
|
scale, shift, gate = modulation.chunk(3, dim=-1)
|
||||||
|
normed = normed * (1 + scale.float()) + shift.float()
|
||||||
|
return normed.to(dtype), gate.to(dtype)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
if self.dense is not None:
|
||||||
|
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
|
||||||
|
return f"dim={self.dim}, eps={self.eps}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pi_gemma_decoder_layer_base():
|
||||||
|
"""base for PiGemmaDecoderLayer"""
|
||||||
|
|
||||||
|
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
|
||||||
|
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
|
||||||
|
|
||||||
|
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
||||||
|
self.mlp = GemmaMLP(config)
|
||||||
|
cond_dim = (
|
||||||
|
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||||
|
)
|
||||||
|
self.input_layernorm = PiGemmaRMSNorm(
|
||||||
|
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = PiGemmaRMSNorm(
|
||||||
|
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
position_ids: torch.LongTensor | None = None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: torch.LongTensor | None = None,
|
||||||
|
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||||
|
adarms_cond: torch.Tensor | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
|
||||||
|
hidden_states, _ = self.self_attn(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
return _PiGemmaDecoderLayerBase
|
||||||
|
|
||||||
|
|
||||||
|
class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||||
|
"""
|
||||||
|
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
|
super().__init__(config, **kwargs)
|
||||||
|
# if not getattr(config, "use_adarms", False):
|
||||||
|
# return
|
||||||
|
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||||
|
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
position_ids: torch.LongTensor | None = None,
|
||||||
|
past_key_values: DynamicCache | None = None,
|
||||||
|
inputs_embeds: torch.FloatTensor | None = None,
|
||||||
|
use_cache: bool | None = None,
|
||||||
|
output_attentions: bool | None = None,
|
||||||
|
output_hidden_states: bool | None = None,
|
||||||
|
cache_position: torch.LongTensor | None = None,
|
||||||
|
adarms_cond: torch.Tensor | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BaseModelOutputWithPast:
|
||||||
|
"""
|
||||||
|
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||||
|
Condition for ADARMS.
|
||||||
|
"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.warning(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if use_cache and past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = create_causal_mask(
|
||||||
|
config=self.config,
|
||||||
|
input_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cache_position=cache_position,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
# Convert to bfloat16 if the first layer uses bfloat16
|
||||||
|
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||||
|
hidden_states = hidden_states.to(torch.bfloat16)
|
||||||
|
|
||||||
|
# create position embeddings to be shared across the decoder layers
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
|
# normalized
|
||||||
|
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||||
|
# See https://github.com/huggingface/transformers/pull/29402
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
adarms_cond=adarms_cond,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, adarms_cond)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=past_key_values if use_cache else None,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
||||||
|
"""
|
||||||
|
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
|
||||||
|
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
|
super().__init__(config, **kwargs)
|
||||||
|
self.model = PiGemmaModel(config)
|
||||||
|
|
||||||
|
|
||||||
|
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
||||||
|
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.language_model = PiGemmaModel(config.text_config)
|
||||||
|
|
||||||
|
|
||||||
|
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
|
||||||
|
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = PaliGemmaModelWithPiGemma(config)
|
||||||
|
|
||||||
|
# Make modules available through conditional class for BC
|
||||||
|
@property
|
||||||
|
def language_model(self):
|
||||||
|
return self.model.language_model
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PiGemmaModel",
|
||||||
|
"PiGemmaForCausalLM",
|
||||||
|
"PiGemmaRMSNorm",
|
||||||
|
"_gated_residual",
|
||||||
|
"layernorm_forward",
|
||||||
|
"PaliGemmaModelWithPiGemma",
|
||||||
|
"PaliGemmaForConditionalGenerationWithPiGemma",
|
||||||
|
]
|
||||||
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
|
|||||||
latent_dim: int = 256
|
latent_dim: int = 256
|
||||||
image_embedding_pooling_dim: int = 8
|
image_embedding_pooling_dim: int = 8
|
||||||
dropout_rate: float = 0.1
|
dropout_rate: float = 0.1
|
||||||
model_name: str = "helper2424/resnet10"
|
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
model_type: str = "cnn" # "transformer" or "cnn"
|
model_type: str = "cnn" # "transformer" or "cnn"
|
||||||
num_cameras: int = 2
|
num_cameras: int = 2
|
||||||
|
|||||||
@@ -331,7 +331,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
force_download=kwargs.get("force_download", False),
|
force_download=kwargs.get("force_download", False),
|
||||||
resume_download=kwargs.get("resume_download"),
|
resume_download=kwargs.get("resume_download"),
|
||||||
proxies=kwargs.get("proxies"),
|
proxies=kwargs.get("proxies"),
|
||||||
use_auth_token=kwargs.get("use_auth_token"),
|
token=kwargs.get("token"),
|
||||||
revision=kwargs.get("revision"),
|
revision=kwargs.get("revision"),
|
||||||
local_files_only=kwargs.get("local_files_only", False),
|
local_files_only=kwargs.get("local_files_only", False),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from transformers.activations import ACT2FN
|
|||||||
from transformers.cache_utils import (
|
from transformers.cache_utils import (
|
||||||
Cache,
|
Cache,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
SlidingWindowCache,
|
|
||||||
StaticCache,
|
StaticCache,
|
||||||
)
|
)
|
||||||
from transformers.generation import GenerationMixin
|
from transformers.generation import GenerationMixin
|
||||||
@@ -31,6 +30,15 @@ from transformers.utils import (
|
|||||||
|
|
||||||
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
|
||||||
|
# always return False (which is the correct behavior when no sliding window cache is in use).
|
||||||
|
class _SlidingWindowCachePlaceholder:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
SlidingWindowCache = _SlidingWindowCachePlaceholder
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
from flash_attn.layers.rotary import apply_rotary_emb
|
from flash_attn.layers.rotary import apply_rotary_emb
|
||||||
|
|||||||
@@ -276,6 +276,8 @@ class Florence2LanguageConfig(PretrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ensure backward compatibility for BART CNN models
|
# ensure backward compatibility for BART CNN models
|
||||||
|
if not hasattr(self, "forced_bos_token_id"):
|
||||||
|
self.forced_bos_token_id = None
|
||||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||||
self.forced_bos_token_id = self.bos_token_id
|
self.forced_bos_token_id = self.bos_token_id
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|||||||
@@ -1951,7 +1951,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
||||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
_tied_weights_keys = {
|
||||||
|
"encoder.embed_tokens.weight": "shared.weight",
|
||||||
|
"decoder.embed_tokens.weight": "shared.weight",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, config: Florence2LanguageConfig):
|
def __init__(self, config: Florence2LanguageConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -2076,7 +2079,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
|||||||
|
|
||||||
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
|
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
_tied_weights_keys = {
|
||||||
|
"model.encoder.embed_tokens.weight": "model.shared.weight",
|
||||||
|
"model.decoder.embed_tokens.weight": "model.shared.weight",
|
||||||
|
}
|
||||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||||
|
|
||||||
def __init__(self, config: Florence2LanguageConfig):
|
def __init__(self, config: Florence2LanguageConfig):
|
||||||
@@ -2436,11 +2442,10 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
|
|||||||
FLORENCE2_START_DOCSTRING,
|
FLORENCE2_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
||||||
_tied_weights_keys = [
|
_tied_weights_keys = {
|
||||||
"language_model.encoder.embed_tokens.weight",
|
"language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
|
||||||
"language_model.decoder.embed_tokens.weight",
|
"language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
|
||||||
"language_model.lm_head.weight",
|
}
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, config: Florence2Config):
|
def __init__(self, config: Florence2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
|||||||
Requires the `transformers` library to be installed.
|
Requires the `transformers` library to be installed.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
|
||||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||||
|
|||||||
@@ -306,7 +306,7 @@ def train_fast_tokenizer(
|
|||||||
|
|
||||||
# download the tokenizer source code (not pretrained weights)
|
# download the tokenizer source code (not pretrained weights)
|
||||||
# we'll train a new tokenizer on our own data
|
# we'll train a new tokenizer on our own data
|
||||||
base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
base_tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True)
|
||||||
|
|
||||||
# convert action_chunks array to list of arrays (expected by .fit())
|
# convert action_chunks array to list of arrays (expected by .fit())
|
||||||
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
|
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
@@ -37,6 +38,9 @@ def test_classifier_output():
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||||
|
)
|
||||||
def test_binary_classifier_with_default_params():
|
def test_binary_classifier_with_default_params():
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
@@ -78,6 +82,9 @@ def test_binary_classifier_with_default_params():
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||||
|
)
|
||||||
def test_multiclass_classifier():
|
def test_multiclass_classifier():
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
@@ -117,6 +124,9 @@ def test_multiclass_classifier():
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||||
|
)
|
||||||
def test_default_device():
|
def test_default_device():
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
@@ -129,6 +139,9 @@ def test_default_device():
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||||
|
)
|
||||||
def test_explicit_device_setup():
|
def test_explicit_device_setup():
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
|
|||||||
@@ -49,19 +49,19 @@ IMAGE_HEIGHT = 224
|
|||||||
IMAGE_WIDTH = 224
|
IMAGE_WIDTH = 224
|
||||||
NUM_VIEWS = 2 # Number of camera views
|
NUM_VIEWS = 2 # Number of camera views
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
MODEL_PATH_LEROBOT = "lerobot/pi0fast-base"
|
MODEL_PATH_LEROBOT = "jadechoghari/pi0fast-base"
|
||||||
|
|
||||||
# Expected action token shape: (batch_size, max_decoding_steps)
|
# Expected action token shape: (batch_size, max_decoding_steps)
|
||||||
EXPECTED_ACTION_TOKENS_SHAPE = (1, 2)
|
EXPECTED_ACTION_TOKENS_SHAPE = (1, 2)
|
||||||
|
|
||||||
# Expected first 5 action tokens (for reproducibility check)
|
# Expected first 5 action tokens (for reproducibility check)
|
||||||
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362])
|
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255425])
|
||||||
|
|
||||||
# Expected actions after detokenization
|
# Expected actions after detokenization
|
||||||
EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim)
|
EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim)
|
||||||
EXPECTED_ACTIONS_MEAN = 0.04419417306780815
|
EXPECTED_ACTIONS_MEAN = 0.046403881162405014
|
||||||
EXPECTED_ACTIONS_STD = 0.26231569051742554
|
EXPECTED_ACTIONS_STD = 0.2607129216194153
|
||||||
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000])
|
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([-0.0707, 1.4849, 0.0000, 0.0000, 0.0000])
|
||||||
|
|
||||||
|
|
||||||
def set_seed_all(seed: int):
|
def set_seed_all(seed: int):
|
||||||
|
|||||||
@@ -305,6 +305,9 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di
|
|||||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||||
)
|
)
|
||||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||||
|
)
|
||||||
def test_sac_policy_with_pretrained_encoder(
|
def test_sac_policy_with_pretrained_encoder(
|
||||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user