Fix: full pi models support for transformer v5 (#2967)

* fix(pi): remove loss truncation

* fix(pi): remove state padding before tokenization

* fix(pi): fix image padding value

* fix from_pretrain

* add transformer v5 changes

* remove reference

* more fixes

* make it work

* add support for rest of pi family

* add pifast work

* more changes

* more changes

* more cleanup

* fix torch params

* dtype fix

* torch compile

* embed mismatch fix

* revert groot

* more nit fixes

* remove unused classes

* more fixes

* revert

* nit

* torch dtype warning fix

* but back dynamic renaming

* add tie embedding

---------

Co-authored-by: Yufei Sun <skieyfly@gmail.com>
This commit is contained in:
Jade Choghari
2026-02-23 22:44:13 +03:00
committed by GitHub
parent 753b996cda
commit 419305a4c2
13 changed files with 517 additions and 195 deletions
@@ -54,19 +54,19 @@ IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
NUM_VIEWS = 2 # Number of camera views
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH_LEROBOT = "lerobot/pi0fast-base"
MODEL_PATH_LEROBOT = "jadechoghari/pi0fast-base"
# Expected action token shape: (batch_size, max_decoding_steps)
EXPECTED_ACTION_TOKENS_SHAPE = (1, 2)
# Expected first 5 action tokens (for reproducibility check)
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362])
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255425])
# Expected actions after detokenization
EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim)
EXPECTED_ACTIONS_MEAN = 0.04419417306780815
EXPECTED_ACTIONS_STD = 0.26231569051742554
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000])
EXPECTED_ACTIONS_MEAN = 0.046403881162405014
EXPECTED_ACTIONS_STD = 0.2607129216194153
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([-0.0707, 1.4849, 0.0000, 0.0000, 0.0000])
def set_seed_all(seed: int):
+1 -1
View File
@@ -24,7 +24,7 @@ import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
reason="This test requires accepting the model license",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
+1 -1
View File
@@ -26,7 +26,7 @@ from lerobot.utils.random_utils import set_seed
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
reason="This test requires accepting the model license",
)
from lerobot.policies.factory import make_policy_config # noqa: E402