mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
64 lines
2.3 KiB
Python
64 lines
2.3 KiB
Python
#!/usr/bin/env python
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
|
|
|
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
|
|
|
|
|
def test_delta_indices() -> None:
    """Check the observation and action delta-index windows of the default config.

    Observation deltas must be ``0 .. NUM_VIDEO_FRAMES - 1`` and action deltas
    must be ``0 .. ACTION_HORIZON - 1``.
    """
    cfg = make_config()
    expected_obs = list(range(NUM_VIDEO_FRAMES))
    expected_act = list(range(ACTION_HORIZON))

    assert cfg.observation_delta_indices == expected_obs
    assert cfg.action_delta_indices == expected_act
|
|
|
|
|
|
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
    """Reject ``n_action_steps`` larger than ``chunk_size``.

    The constructor must raise a ``ValueError`` whose message mentions
    ``n_action_steps``.
    """
    oversized = {"chunk_size": 4, "n_action_steps": 8, "future_action_window_size": 3}
    with pytest.raises(ValueError, match="n_action_steps"):
        VLAJEPAConfig(**oversized)
|
|
|
|
|
|
def test_future_window_exceeds_chunk_size_raises() -> None:
    """Reject a future action window equal to ``chunk_size``.

    With ``future_action_window_size == chunk_size == 4`` the constructor must
    raise a ``ValueError`` mentioning the "predicted action horizon".
    NOTE(review): this presumably means the predicted horizon is
    window + 1 and so exceeds the chunk. TODO confirm in VLAJEPAConfig.
    """
    chunk = 4
    with pytest.raises(ValueError, match="predicted action horizon"):
        VLAJEPAConfig(chunk_size=chunk, n_action_steps=chunk, future_action_window_size=chunk)
|
|
|
|
|
|
def test_too_few_video_frames_raises() -> None:
    """Reject a clip too short for the JEPA tubelet size.

    Two video frames with ``jepa_tubelet_size=2`` must raise a ``ValueError``
    mentioning ``video_horizon``.
    """
    # Per the original author's note: a tubelet size of 2 needs >= 4 frames
    # (2 for current, 2 for future) to leave a window of size > 0.
    short_clip = dict(
        chunk_size=16,
        n_action_steps=16,
        future_action_window_size=15,
        num_video_frames=2,
        jepa_tubelet_size=2,
    )
    with pytest.raises(ValueError, match="video_horizon"):
        VLAJEPAConfig(**short_clip)
|
|
|
|
|
|
def test_validate_features_no_image_raises() -> None:
    """Make ``validate_features`` reject a config with only a state input.

    The config is built with no visual input. It must construct successfully,
    but ``validate_features()`` must raise a ``ValueError`` asking for at
    least one visual input feature.
    """
    state_only = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))}
    action_out = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
    cfg = VLAJEPAConfig(input_features=state_only, output_features=action_out)

    with pytest.raises(ValueError, match="at least one visual input feature"):
        cfg.validate_features()
|
|
|
|
|
|
def test_validate_features_no_action_raises() -> None:
    """Make ``validate_features`` reject a config with no outputs.

    The config has a single camera input and an empty output map. It must
    construct successfully, but ``validate_features()`` must raise a
    ``ValueError`` mentioning the missing action output feature.
    """
    camera_key = f"{OBS_IMAGES}.cam"
    image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE))
    cfg = VLAJEPAConfig(input_features={camera_key: image_feature}, output_features={})

    with pytest.raises(ValueError, match="action output feature"):
        cfg.validate_features()
|
|
|
|
|
|
def test_validate_features_sets_action_dim_from_feature() -> None:
    """Check that the config reports the action and state dims it was built with.

    NOTE(review): the test name implies that ``make_config`` (from conftest)
    builds features of these sizes and that ``validate_features`` derives the
    dims from them. Confirm in conftest, because if ``make_config`` assigns the
    dims directly this check is trivial.
    """
    expected_action_dim, expected_state_dim = 6, 10
    cfg = make_config(action_dim=expected_action_dim, state_dim=expected_state_dim)

    assert cfg.action_dim == expected_action_dim
    assert cfg.state_dim == expected_state_dim
|