# Scraped source-view header (kept as comments so the module stays importable):
#   Files
#   lerobot/tests/policies/vla_jepa/test_configuration.py
#   T
#   2026-05-13 15:55:04 +02:00
#   64 lines, 2.3 KiB, Python

#!/usr/bin/env python
from __future__ import annotations
import pytest
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
def test_delta_indices() -> None:
    """Delta indices cover exactly the video frames and the action horizon."""
    cfg = make_config()
    expected_observation = list(range(NUM_VIDEO_FRAMES))
    expected_action = list(range(ACTION_HORIZON))
    assert cfg.observation_delta_indices == expected_observation
    assert cfg.action_delta_indices == expected_action
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
    """Requesting more executed steps than one chunk holds is rejected."""
    # n_action_steps (8) is larger than chunk_size (4).
    bad_kwargs = dict(chunk_size=4, n_action_steps=8, future_action_window_size=3)
    with pytest.raises(ValueError, match="n_action_steps"):
        VLAJEPAConfig(**bad_kwargs)
def test_future_window_exceeds_chunk_size_raises() -> None:
    """A future action window too large for the chunk is rejected."""
    # NOTE(review): presumably the predicted horizon is the future window plus
    # the current step, so a window of 4 overflows a chunk of 4 (other tests in
    # this module use window = chunk_size - 1 as a valid setting) — confirm.
    bad_kwargs = dict(chunk_size=4, n_action_steps=4, future_action_window_size=4)
    with pytest.raises(ValueError, match="predicted action horizon"):
        VLAJEPAConfig(**bad_kwargs)
def test_too_few_video_frames_raises() -> None:
    """Too few video frames for the configured JEPA tubelet size is rejected."""
    action_kwargs = dict(chunk_size=16, n_action_steps=16, future_action_window_size=15)
    # A tubelet size of 2 needs >= 4 frames (2 for current, 2 for future) to
    # leave a window of size > 0; only 2 frames are supplied here.
    video_kwargs = dict(num_video_frames=2, jepa_tubelet_size=2)
    with pytest.raises(ValueError, match="video_horizon"):
        VLAJEPAConfig(**action_kwargs, **video_kwargs)
def test_validate_features_no_image_raises() -> None:
    """Feature validation fails when the only input is a state vector."""
    state_only_inputs = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))}
    action_outputs = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
    # Build the config outside the raises-context so only validate_features()
    # is expected to raise.
    config = VLAJEPAConfig(input_features=state_only_inputs, output_features=action_outputs)
    with pytest.raises(ValueError, match="at least one visual input feature"):
        config.validate_features()
def test_validate_features_no_action_raises() -> None:
    """Feature validation fails when there are no output features at all."""
    camera_key = f"{OBS_IMAGES}.cam"
    camera_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE))
    # Build the config outside the raises-context so only validate_features()
    # is expected to raise.
    config = VLAJEPAConfig(input_features={camera_key: camera_feature}, output_features={})
    with pytest.raises(ValueError, match="action output feature"):
        config.validate_features()
def test_validate_features_sets_action_dim_from_feature() -> None:
    """The config's action/state dims match those requested from ``make_config``."""
    # NOTE(review): this relies on make_config (conftest) building features with
    # these dims and deriving config.action_dim/state_dim from them; if it set
    # the dims directly the test would be tautological — confirm in conftest.
    expected_action_dim, expected_state_dim = 6, 10
    config = make_config(action_dim=expected_action_dim, state_dim=expected_state_dim)
    assert config.action_dim == expected_action_dim
    assert config.state_dim == expected_state_dim