mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
Revert "fix(deps): breaking change from transformers 5.4.0" (#3249)
* Revert "fix(deps): breaking change from transformers 5.4.0 (#3231)"
This reverts commit 07502868e5.
* chore(dependecies): pin transformers to 5.3.0 temporarily
This commit is contained in:
+1
-1
@@ -99,7 +99,7 @@ dependencies = [
|
|||||||
# Common
|
# Common
|
||||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||||
transformers-dep = ["transformers>=5.4.0,<6.0.0"]
|
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -110,6 +110,7 @@ class MultiEmbodimentActionEncoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -173,6 +173,7 @@ N_COLOR_CHANNELS = 3
|
|||||||
|
|
||||||
|
|
||||||
# config
|
# config
|
||||||
|
@dataclass
|
||||||
class GR00TN15Config(PretrainedConfig):
|
class GR00TN15Config(PretrainedConfig):
|
||||||
model_type = "gr00t_n1_5"
|
model_type = "gr00t_n1_5"
|
||||||
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
|
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from transformers.utils import (
|
|||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
is_flash_attn_greater_or_equal,
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
@@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
|||||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from transformers.utils import (
|
|||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
is_flash_attn_greater_or_equal,
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
|||||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||||
|
|||||||
Reference in New Issue
Block a user