mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
Fix policy testing for tv5 (#3032)
* fix ci logger * other fix * fix mypy * change logits to torch2.10 * skip wallx| * remove logging --------- Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -60,13 +60,13 @@ MODEL_PATH_LEROBOT = "jadechoghari/pi0fast-base"
|
|||||||
EXPECTED_ACTION_TOKENS_SHAPE = (1, 2)
|
EXPECTED_ACTION_TOKENS_SHAPE = (1, 2)
|
||||||
|
|
||||||
# Expected first 5 action tokens (for reproducibility check)
|
# Expected first 5 action tokens (for reproducibility check)
|
||||||
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255425])
|
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255020, 255589])
|
||||||
|
|
||||||
# Expected actions after detokenization
|
# Expected actions after detokenization
|
||||||
EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim)
|
EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim)
|
||||||
EXPECTED_ACTIONS_MEAN = 0.046403881162405014
|
EXPECTED_ACTIONS_MEAN = 0.046403881162405014
|
||||||
EXPECTED_ACTIONS_STD = 0.2607129216194153
|
EXPECTED_ACTIONS_STD = 0.2607129216194153
|
||||||
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([-0.0707, 1.4849, 0.0000, 0.0000, 0.0000])
|
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 0.3536, 0.0707, 0.0000, 0.0000])
|
||||||
|
|
||||||
|
|
||||||
def set_seed_all(seed: int):
|
def set_seed_all(seed: int):
|
||||||
|
|||||||
@@ -16,9 +16,15 @@
|
|||||||
|
|
||||||
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
|
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||||
|
reason="This test exceeds available memory in CI environments.",
|
||||||
|
)
|
||||||
# Skip if required dependencies are not available
|
# Skip if required dependencies are not available
|
||||||
pytest.importorskip("peft")
|
pytest.importorskip("peft")
|
||||||
pytest.importorskip("transformers")
|
pytest.importorskip("transformers")
|
||||||
|
|||||||
Reference in New Issue
Block a user