mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(ci): skip HF log in (and tests) in forks and community PRs (#3097)
* fix(ci): skip HF log in (and tests) in forks and community PRs * chore(test): remove comment about test meant to be only run locally * fix(tests): no hf log in decorator for xvla * fix(test): no decorator in yield
This commit is contained in:
@@ -91,6 +91,7 @@ jobs:
|
||||
run: uv sync --extra "test"
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
@@ -89,6 +89,7 @@ jobs:
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
@@ -181,6 +182,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
|
||||
@@ -132,6 +132,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
@@ -164,6 +165,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
@@ -197,6 +199,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
|
||||
@@ -81,6 +81,7 @@ jobs:
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
@@ -154,6 +155,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
|
||||
@@ -40,7 +40,7 @@ from lerobot.utils.constants import (
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
) # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||
|
||||
# Constants
|
||||
DUMMY_ACTION_DIM = 7
|
||||
@@ -65,6 +65,7 @@ EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 0.3536, 0.0707, 0.0000, 0.0000]
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def set_seed_all(seed: int):
|
||||
"""Set random seed for all RNG sources to ensure reproducibility."""
|
||||
random.seed(seed)
|
||||
@@ -82,6 +83,7 @@ def set_seed_all(seed: int):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def instantiate_lerobot_pi0_fast(
|
||||
from_pretrained: bool = False,
|
||||
model_path: str = MODEL_PATH_LEROBOT,
|
||||
@@ -125,6 +127,7 @@ def instantiate_lerobot_pi0_fast(
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def create_dummy_data(device=DEVICE):
|
||||
"""Create dummy data for testing both implementations."""
|
||||
batch_size = 1
|
||||
@@ -157,6 +160,7 @@ def create_dummy_data(device=DEVICE):
|
||||
# Pytest fixtures
|
||||
@pytest.fixture(scope="module")
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def pi0_fast_components():
|
||||
"""Fixture to instantiate and provide all PI0Fast components for tests."""
|
||||
print(f"\nTesting with DEVICE='{DEVICE}'")
|
||||
@@ -168,6 +172,7 @@ def pi0_fast_components():
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def policy(pi0_fast_components):
|
||||
"""Fixture to provide the PI0Fast policy for tests."""
|
||||
return pi0_fast_components[0]
|
||||
@@ -175,12 +180,14 @@ def policy(pi0_fast_components):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def preprocessor(pi0_fast_components):
|
||||
"""Fixture to provide the PI0Fast preprocessor for tests."""
|
||||
return pi0_fast_components[1]
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_preprocessor_alignment(policy, preprocessor):
|
||||
"""Test that LeRobot PI0Fast preprocessor produces expected outputs."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -228,6 +235,7 @@ def test_pi0_fast_preprocessor_alignment(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_action_generation(policy, preprocessor):
|
||||
"""Test PI0Fast LeRobot implementation generates expected actions."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -306,6 +314,7 @@ def test_pi0_fast_action_generation(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_inference_reproducibility(policy, preprocessor):
|
||||
"""Test that PI0Fast inference is reproducible with the same seed."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -347,6 +356,7 @@ def test_pi0_fast_inference_reproducibility(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_forward_pass_logits(policy, preprocessor):
|
||||
"""Test PI0Fast forward pass and compare logits against expected values."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -396,6 +406,7 @@ def test_pi0_fast_forward_pass_logits(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_action_token_sampling(policy, preprocessor):
|
||||
"""Test PI0Fast action token sampling (autoregressive decoding)."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -452,6 +463,7 @@ def test_pi0_fast_action_token_sampling(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_detokenization(policy, preprocessor):
|
||||
"""Test PI0Fast action detokenization (FAST decoding)."""
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
@@ -14,10 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""
|
||||
"""Test script to verify PI0 policy integration with LeRobot"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi0 import ( # noqa: E402
|
||||
PI0Config,
|
||||
@@ -25,10 +28,11 @@ from lerobot.policies.pi0 import ( # noqa: E402
|
||||
make_pi0_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
@@ -105,6 +109,7 @@ def test_policy_instantiation():
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
|
||||
@@ -14,10 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""
|
||||
"""Test script to verify PI0.5 (pi05) support in PI0 policy"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi05 import ( # noqa: E402
|
||||
PI05Config,
|
||||
@@ -25,10 +28,11 @@ from lerobot.policies.pi05 import ( # noqa: E402
|
||||
make_pi05_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
@@ -141,6 +145,7 @@ def test_policy_instantiation():
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
|
||||
"""Test script to verify Wall-X policy integration with LeRobot"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -29,10 +29,11 @@ from lerobot.policies.wall_x import WallXConfig # noqa: E402
|
||||
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402
|
||||
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
@@ -118,6 +119,7 @@ def test_policy_instantiation():
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation"""
|
||||
# ruff: noqa: E402
|
||||
|
||||
import random
|
||||
|
||||
@@ -108,6 +108,22 @@ def require_cuda(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_hf_token(func):
|
||||
"""
|
||||
Decorator that skips the test if no Hugging Face Hub token is available.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
from huggingface_hub import get_token
|
||||
|
||||
if get_token() is None:
|
||||
pytest.skip("requires HF token for gated model access")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_env(func):
|
||||
"""
|
||||
Decorator that skips the test if the required environment package is not installed.
|
||||
|
||||
Reference in New Issue
Block a user