Refactor to use pre- and post-processors

This commit is contained in:
Maxime Ellerbach
2026-05-21 11:15:52 +00:00
committed by Maximellerbach
parent 51e57789ba
commit 47f8a50fa0
7 changed files with 268 additions and 114 deletions
+39 -1
View File
@@ -111,6 +111,41 @@ def make_inference_batch(
# ---------------------------------------------------------------------------
class _FakeLanguageLayer(nn.Module):
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self._hidden_size = hidden_size
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
return (hidden,)
class _FakeLanguageModel(nn.Module):
    """Minimal language-model stub holding a single ``_FakeLanguageLayer``."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self._hidden_size = hidden_size
        fake_layers = [_FakeLanguageLayer(hidden_size)]
        self.layers = nn.ModuleList(fake_layers)

    def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
        """Push an all-zero hidden tensor through the last layer.

        The zeros have shape ``(batch, seq_len, hidden_size)`` and live on the
        device of ``input_ids``. The layer output is discarded and an empty
        ``SimpleNamespace`` is returned; extra kwargs are ignored.
        """
        batch_size, seq_len = input_ids.shape
        zeros = torch.zeros(
            batch_size,
            seq_len,
            self._hidden_size,
            device=input_ids.device,
        )
        # Call through the module (not .forward) so registered hooks fire.
        final_layer = self.layers[-1]
        final_layer(zeros)
        return SimpleNamespace()
class _FakeQwenInnerModel(nn.Module):
    """Stub for the ``.model.model`` level that ``_qwen_last_decoder_hidden`` walks into."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.language_model = _FakeLanguageModel(hidden_size)

    def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
        """Delegate to ``language_model`` with ``input_ids`` only; ``kwargs`` are not forwarded."""
        result = self.language_model(input_ids)
        return result
class _FakeQwenBackbone(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
@@ -119,6 +154,7 @@ class _FakeQwenBackbone(nn.Module):
hidden_size=hidden_size,
text_config=SimpleNamespace(hidden_size=hidden_size),
)
self.model = _FakeQwenInnerModel(hidden_size)
@property
def device(self) -> torch.device:
@@ -189,7 +225,9 @@ class _FakeVideoEncoder(nn.Module):
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
@property
def device(self) -> torch.device:
+135 -20
View File
@@ -5,6 +5,7 @@ from __future__ import annotations
import os
from copy import deepcopy
import numpy as np
import pytest
import torch
from torch import Tensor
@@ -206,12 +207,11 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None
# ---------------------------------------------------------------------------
def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None) -> None:
import numpy as np
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
from PIL import Image
policy = VLAJEPAPolicy(make_config())
examples = policy._lerobot_to_native(make_train_batch())
examples = policy._prepare_model_inputs(make_train_batch())
assert len(examples) == BATCH_SIZE
for ex in examples:
@@ -222,44 +222,35 @@ def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None)
assert ex["state"].shape == (1, STATE_DIM)
def test_lerobot_to_native_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
for ex in policy._lerobot_to_native(make_inference_batch()):
for ex in policy._prepare_model_inputs(make_inference_batch()):
assert "action" not in ex
assert "image" in ex and "video" in ex and "lang" in ex
def test_lerobot_to_native_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch["task"]
examples = policy._lerobot_to_native(batch)
examples = policy._prepare_model_inputs(batch)
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
def test_lerobot_to_native_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
batch["task"] = "open the drawer"
assert all(ex["lang"] == "open the drawer" for ex in policy._lerobot_to_native(batch))
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
def test_lerobot_to_native_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
from lerobot.utils.constants import OBS_STATE
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch[OBS_STATE]
assert all("state" not in ex for ex in policy._lerobot_to_native(batch))
def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
loss, logs = policy._native_to_lerobot({"action_loss": torch.tensor(0.5), "wm_loss": torch.tensor(0.1)})
assert torch.isfinite(loss)
assert set(logs) == {"action_loss", "wm_loss", "loss"}
assert logs["action_loss"] == pytest.approx(0.5, abs=1e-5)
assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5)
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
# ---------------------------------------------------------------------------
@@ -355,3 +346,127 @@ def test_hub_libero_inference_shape() -> None:
batch = _make_hub_inference_batch(policy)
action = policy.select_action(batch)
assert action.shape[-1] == policy.config.action_dim
# ---------------------------------------------------------------------------
# Postprocessor unnormalization tests
#
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
# ---------------------------------------------------------------------------
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
    """Build sample dataset stats where action dim ``i`` spans ``[i, i + 10]``."""
    from lerobot.utils.constants import ACTION

    # 0, 1, ..., action_dim - 1 as float32 lower bounds.
    lower = torch.arange(action_dim, dtype=torch.float32)
    upper = lower + 10.0
    return {ACTION: {"min": lower, "max": upper}}
@torch.no_grad()
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
    """Check that MIN_MAX unnormalization inverts MIN_MAX normalization.

    Random actions in [-1, 1] are passed through ``UnnormalizerProcessorStep``
    and compared against the closed-form mapping onto ``[min, max]``.
    """
    from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
    from lerobot.processor import UnnormalizerProcessorStep
    from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
    from lerobot.utils.constants import ACTION

    stats = _make_dataset_stats()
    low = stats[ACTION]["min"].numpy()
    high = stats[ACTION]["max"].numpy()

    # Fixed seed keeps the sampled actions reproducible.
    generator = np.random.default_rng(7)
    normalized = generator.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)

    # Reference result: map [-1, 1] linearly onto [low, high].
    expected = (normalized + 1.0) / 2.0 * (high - low) + low

    step = UnnormalizerProcessorStep(
        features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
        norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
        stats=stats,
    )

    transition = policy_action_to_transition(torch.from_numpy(normalized))
    unnormalized = transition_to_policy_action(step(transition))

    np.testing.assert_allclose(unnormalized.numpy(), expected, rtol=1e-5, atol=1e-6)
@torch.no_grad()
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
    """Check that clipping to [-1, 1] happens before unnormalization.

    Out-of-range actions go through ``ClipActionsProcessorStep`` and then
    ``UnnormalizerProcessorStep``; the result must equal unnormalizing the
    NumPy-clipped inputs.
    """
    from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
    from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
    from lerobot.processor import UnnormalizerProcessorStep
    from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
    from lerobot.utils.constants import ACTION

    stats = _make_dataset_stats()
    low = stats[ACTION]["min"].numpy()
    high = stats[ACTION]["max"].numpy()

    # One batch entry with two steps: all above +1, then all below -1.
    too_high = [2.0] * ACTION_DIM
    too_low = [-3.0] * ACTION_DIM
    out_of_range = np.array([[too_high, too_low]], dtype=np.float32)

    # Reference result: clamp first, then map [-1, 1] onto [low, high].
    clamped = np.clip(out_of_range, -1.0, 1.0)
    expected = (clamped + 1.0) / 2.0 * (high - low) + low

    clipper = ClipActionsProcessorStep()
    unnormalizer = UnnormalizerProcessorStep(
        features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
        norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
        stats=stats,
    )

    transition = clipper(policy_action_to_transition(torch.from_numpy(out_of_range)))
    actual = transition_to_policy_action(unnormalizer(transition)).numpy()

    np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-6)
@torch.no_grad()
def test_postprocessor_applied_after_predict_action_chunk(
    patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
) -> None:
    """predict_action_chunk returns raw actions; the postprocessor applies unnormalization.

    The model's ``predict_action`` is patched to return all zeros. The test
    first checks that ``predict_action_chunk`` passes those zeros through
    untouched, then that the postprocessor maps the first element to the
    midpoint of its ``[min, max]`` range.
    """
    from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
    from lerobot.utils.constants import ACTION

    # Disable clipping and gripper binarization on the config.
    cfg = make_config()
    cfg.clip_normalized_actions = False
    cfg.binarize_gripper_action = False

    policy = VLAJEPAPolicy(cfg)
    policy.eval()

    raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
    # Each call gets a fresh copy so the stub's return value can't be mutated in place.
    monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())

    dataset_stats = _make_dataset_stats()
    _, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)

    chunk = policy.predict_action_chunk(make_inference_batch())

    # predict_action_chunk returns raw (normalized) actions
    assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
        "predict_action_chunk should return raw actions without unnormalization applied."
    )

    # Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
    unnormed = postprocessor(chunk)
    a_min = dataset_stats[ACTION]["min"].numpy()
    a_max = dataset_stats[ACTION]["max"].numpy()
    expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
    assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
+14 -9
View File
@@ -15,10 +15,15 @@ _ACTION_EMBED_DIM = 8
def _make_predictor(
embed_dim: int = 8,
action_embed_dim: int = _ACTION_EMBED_DIM,
predictor_embed_dim: int = 16,
predictor_embed_dim: int = 24,
num_action_tokens: int = 2,
tokens_per_frame: int = 1,
) -> ActionConditionedVideoPredictor:
return ActionConditionedVideoPredictor(
num_frames=1,
img_size=(1, tokens_per_frame),
patch_size=1,
tubelet_size=1,
embed_dim=embed_dim,
action_embed_dim=action_embed_dim,
predictor_embed_dim=predictor_embed_dim,
@@ -38,16 +43,16 @@ def _make_predictor(
],
)
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM)
frame_tokens = torch.randn(batch, num_steps, tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps, 2, _ACTION_EMBED_DIM)
predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame)
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
out = predictor(frame_tokens, action_tokens)
assert tuple(out.shape) == (batch, num_steps, tokens_per_frame, embed_dim)
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
assert torch.isfinite(out).all()
def test_predictor_step_mismatch_raises() -> None:
predictor = _make_predictor()
frame_tokens = torch.randn(2, 3, 4, 8)
with pytest.raises(ValueError, match="Expected 3 action steps"):
predictor(frame_tokens, torch.randn(2, 2, 2, 8))
predictor = _make_predictor(tokens_per_frame=4)
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
with pytest.raises(RuntimeError):
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch