This commit is contained in:
Maximellerbach
2026-05-11 16:49:18 +02:00
parent 3144029814
commit 2e9ba42e1b
4 changed files with 24 additions and 34 deletions
+3 -11
View File
@@ -17,7 +17,6 @@ from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
pytestmark = pytest.mark.filterwarnings(
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
)
@@ -81,10 +80,7 @@ class _FakeQwenInterface(nn.Module):
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
action_tokens = [
self.config.special_action_token.format(idx)
for idx in range(max_action_tokens)
]
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
action_token_ids = list(range(1000, 1000 + max_action_tokens))
return action_tokens, action_token_ids, 2000
@@ -226,9 +222,7 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
loss.backward()
assert any(
param.grad is not None
for param in policy.model.action_model.parameters()
if param.requires_grad
param.grad is not None for param in policy.model.action_model.parameters() if param.requires_grad
)
assert set(batch) == set(batch_before)
for key, value in batch.items():
@@ -299,9 +293,7 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
local_files_only=True,
)
except LocalEntryNotFoundError:
pytest.skip(
f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache."
)
pytest.skip(f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache.")
try:
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)