From ed6978fd18b7d91ea59e90ee73d75545138f9ec6 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Mon, 18 May 2026 18:14:13 +0200 Subject: [PATCH] removing missleading future_action_window_size to just use chunk_size --- src/lerobot/policies/vla_jepa/action_head.py | 2 +- src/lerobot/policies/vla_jepa/configuration_vla_jepa.py | 5 +---- .../policies/vla_jepa/convert_vla_jepa_checkpoints.py | 1 - src/lerobot/policies/vla_jepa/modeling_vla_jepa.py | 2 +- tests/policies/vla_jepa/conftest.py | 1 - tests/policies/vla_jepa/test_configuration.py | 8 +------- 6 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py index fa8f90508..430c9cfe9 100644 --- a/src/lerobot/policies/vla_jepa/action_head.py +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -210,7 +210,7 @@ class VLAJEPAActionHead(nn.Module): inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768 self.input_embedding_dim = inner_dim - self.action_horizon = config.future_action_window_size + 1 + self.action_horizon = config.chunk_size self.num_inference_timesteps = config.num_inference_timesteps self.model = DiT( diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index c3c2cd2f0..9bcff66ea 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -36,8 +36,7 @@ class VLAJEPAConfig(PreTrainedConfig): action_dim: int = 7 state_dim: int = 8 - future_action_window_size: int = 6 - past_action_window_size: int = 0 + num_action_tokens_per_timestep: int = 8 num_embodied_action_tokens_per_instruction: int = 32 num_inference_timesteps: int = 4 @@ -82,8 +81,6 @@ class VLAJEPAConfig(PreTrainedConfig): self.enable_world_model = False if self.n_action_steps > self.chunk_size: raise ValueError("`n_action_steps` must be <= `chunk_size`.") - if self.future_action_window_size + 1 > self.chunk_size: - raise ValueError("`chunk_size` must cover the predicted action horizon.") if self.num_video_frames < 2 * self.jepa_tubelet_size: raise ValueError( f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` " diff --git a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py index 5f60c97b1..d6f444645 100644 --- a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py +++ b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py @@ -62,7 +62,6 @@ _ARCH = { "qwen_model_name": "Qwen/Qwen3-VL-2B-Instruct", # 2B, NOT the default 4B "chunk_size": 7, "n_action_steps": 7, - "future_action_window_size": 6, "num_video_frames": 8, "jepa_tubelet_size": 2, "num_action_tokens_per_timestep": 8, diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 41183c7b6..5f2cf8a9d 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -236,7 +236,7 @@ class VLAJEPAModel(nn.Module): actions_tensor = torch.tensor( np.array(actions), device=last_hidden.device, dtype=torch.float32 ) # [B, T_full, action_dim] - action_horizon = self.config.future_action_window_size + 1 + action_horizon = self.config.chunk_size actions_target = actions_tensor[:, -action_horizon:, :] state_tensor = None diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py index a1e9f9960..f5703fd82 100644 --- a/tests/policies/vla_jepa/conftest.py +++ b/tests/policies/vla_jepa/conftest.py @@ -62,7 +62,6 @@ def make_config( device="cpu", chunk_size=action_horizon, n_action_steps=min(N_ACTION_STEPS, action_horizon), - future_action_window_size=action_horizon - 1, action_dim=action_dim, state_dim=state_dim, num_video_frames=num_video_frames, diff --git a/tests/policies/vla_jepa/test_configuration.py b/tests/policies/vla_jepa/test_configuration.py index 34e9bcff8..2eda08ad3 100644 --- a/tests/policies/vla_jepa/test_configuration.py +++ b/tests/policies/vla_jepa/test_configuration.py @@ -18,12 +18,7 @@ def test_delta_indices() -> None: def test_n_action_steps_exceeds_chunk_size_raises() -> None: with pytest.raises(ValueError, match="n_action_steps"): - VLAJEPAConfig(chunk_size=4, n_action_steps=8, future_action_window_size=3) - - -def test_future_window_exceeds_chunk_size_raises() -> None: - with pytest.raises(ValueError, match="predicted action horizon"): - VLAJEPAConfig(chunk_size=4, n_action_steps=4, future_action_window_size=4) + VLAJEPAConfig(chunk_size=4, n_action_steps=8) def test_too_few_video_frames_raises() -> None: @@ -31,7 +26,6 @@ def test_too_few_video_frames_raises() -> None: VLAJEPAConfig( chunk_size=16, n_action_steps=16, - future_action_window_size=15, num_video_frames=2, jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0 )