refactor(rewards): clean up TOPReward processor/model

This commit is contained in:
Khalil Meftah
2026-05-20 17:39:21 +02:00
parent 70ad322676
commit f6ecb7b955
7 changed files with 568 additions and 928 deletions
+106 -95
View File
@@ -23,11 +23,11 @@ import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.rewards.topreward.processor_topreward import (
TOPREWARD_FEATURE_PREFIX,
TOPRewardEncoderProcessorStep,
_expand_tasks,
_video_to_numpy,
)
from lerobot.types import TransitionKey
from tests.utils import skip_if_package_missing
# ---------------------------------------------------------------------------
# _video_to_numpy — pure (T, C, H, W) -> (T, H, W, C) uint8 conversion
@@ -35,7 +35,7 @@ from lerobot.types import TransitionKey
def test_video_to_numpy_chw_float_is_converted_to_thwc_uint8():
video = torch.rand(4, 3, 8, 8) # (T, C, H, W) floats in [0, 1]
video = torch.rand(4, 3, 8, 8)
array = _video_to_numpy(video, max_frames=None)
assert array.shape == (4, 8, 8, 3)
@@ -52,7 +52,6 @@ def test_video_to_numpy_already_thwc_uint8_passes_through():
def test_video_to_numpy_max_frames_tail_crops_recent_frames():
"""``max_frames`` should keep the **last** K frames (most recent)."""
video = torch.zeros(10, 3, 4, 4)
for t in range(10):
video[t] = t / 9.0
@@ -70,8 +69,6 @@ def test_video_to_numpy_rejects_3d_input():
def test_video_to_numpy_floats_above_one_pass_through_without_rescaling():
"""If ``array.max() > 1`` the helper assumes the tensor is already in the
uint8 range; values pass through unchanged (but are still clipped to 255)."""
video = torch.full((1, 3, 2, 2), 5.0)
array = _video_to_numpy(video, max_frames=None)
@@ -127,50 +124,80 @@ def test_expand_tasks_wrong_type_raises():
# ---------------------------------------------------------------------------
# Encoder step — input/output shapes + dataclass surface
# Encoder step — stubbed AutoProcessor + process_vision_info
# ---------------------------------------------------------------------------
def _skip_if_topreward_extras_missing(func):
func = skip_if_package_missing("qwen-vl-utils", import_name="qwen_vl_utils")(func)
func = skip_if_package_missing("transformers")(func)
return func
class _FakeTokenizer:
eos_token = "<|endoftext|>"
pad_token = "<|endoftext|>"
def __call__(self, *args, **kwargs):
return {"input_ids": torch.zeros(1, 10, dtype=torch.long)}
class _FakeAutoProcessor:
def __init__(self) -> None:
self.tokenizer = _FakeTokenizer()
@classmethod
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
return cls()
def apply_chat_template(self, messages, **kwargs): # noqa: ARG002
return "fake_prompt_text"
def __call__(self, text=None, images=None, videos=None, **kwargs): # noqa: ARG002
seq_len = 10
return {
"input_ids": torch.randint(0, 100, (1, seq_len)),
"attention_mask": torch.ones(1, seq_len, dtype=torch.long),
}
def _build_step(monkeypatch, **overrides):
import importlib
import sys
import types
from lerobot.rewards.topreward import processor_topreward
from lerobot.utils import import_utils
monkeypatch.setattr(processor_topreward, "AutoProcessor", _FakeAutoProcessor)
# Stub qwen_vl_utils as a real module object (not MagicMock) so
# ``require_package`` / ``find_spec`` don't choke on a missing ``__spec__``.
fake_qwen_vl = types.ModuleType("qwen_vl_utils")
fake_qwen_vl.process_vision_info = lambda messages: (None, None) # type: ignore[attr-defined]
fake_qwen_vl.__spec__ = importlib.machinery.ModuleSpec("qwen_vl_utils", None)
monkeypatch.setitem(sys.modules, "qwen_vl_utils", fake_qwen_vl)
# Clear the require_package cache so the stub is picked up.
import_utils._require_package_cache.pop("qwen_vl_utils", None)
return processor_topreward.TOPRewardEncoderProcessorStep(**overrides)
def _make_transition(observation: dict, complementary: dict | None = None) -> dict:
"""Build a tiny ``EnvTransition`` dict for the encoder step."""
transition: dict = {TransitionKey.OBSERVATION: observation}
if complementary is not None:
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary
return transition
def test_encoder_step_writes_namespaced_frames_and_task():
"""The encoder step's output is the contract the model reads from. It
must populate exactly two namespaced keys: ``frames`` and ``task``."""
step = TOPRewardEncoderProcessorStep(
image_key="observation.images.top",
task_key="task",
max_frames=None,
)
@_skip_if_topreward_extras_missing
def test_encoder_step_emits_input_ids_and_prompt_length(monkeypatch):
"""The processor must emit Qwen-VL tensors including ``input_ids`` and
``prompt_length`` under the ``observation.topreward.*`` namespace."""
step = _build_step(monkeypatch)
frames_batch = torch.zeros(2, 4, 3, 8, 8) # (B=2, T=4, C, H, W)
out = step(
_make_transition(
observation={"observation.images.top": frames_batch},
complementary={"task": ["pick", "place"]},
)
)
obs_out = out[TransitionKey.OBSERVATION]
frames_out = obs_out[f"{TOPREWARD_FEATURE_PREFIX}frames"]
tasks_out = obs_out[f"{TOPREWARD_FEATURE_PREFIX}task"]
assert len(frames_out) == 2
assert all(arr.shape == (4, 8, 8, 3) and arr.dtype == np.uint8 for arr in frames_out)
assert tasks_out == ["pick", "place"]
def test_encoder_step_adds_singleton_time_dim_for_4d_input():
"""A ``(B, C, H, W)`` observation is the single-frame case; the encoder
must unsqueeze the time dim so the model still sees a video."""
step = TOPRewardEncoderProcessorStep(image_key="observation.images.top", max_frames=None)
frames_batch = torch.zeros(1, 3, 8, 8) # (B=1, C, H, W) — no time dim
frames_batch = torch.zeros(1, 4, 3, 8, 8)
out = step(
_make_transition(
observation={"observation.images.top": frames_batch},
@@ -178,76 +205,60 @@ def test_encoder_step_adds_singleton_time_dim_for_4d_input():
)
)
frames_out = out[TransitionKey.OBSERVATION][f"{TOPREWARD_FEATURE_PREFIX}frames"]
assert len(frames_out) == 1
assert frames_out[0].shape == (1, 8, 8, 3) # (T=1, H, W, C)
obs_out = out[TransitionKey.OBSERVATION]
assert f"{TOPREWARD_FEATURE_PREFIX}input_ids" in obs_out
assert f"{TOPREWARD_FEATURE_PREFIX}attention_mask" in obs_out
assert f"{TOPREWARD_FEATURE_PREFIX}prompt_length" in obs_out
prompt_length = obs_out[f"{TOPREWARD_FEATURE_PREFIX}prompt_length"]
assert prompt_length.dtype == torch.long
assert prompt_length.shape == (1,)
def test_encoder_step_uses_default_task_when_complementary_is_missing():
step = TOPRewardEncoderProcessorStep(
image_key="observation.images.top",
default_task="perform the task",
)
frames_batch = torch.zeros(1, 2, 3, 4, 4)
out = step(_make_transition(observation={"observation.images.top": frames_batch}))
tasks_out = out[TransitionKey.OBSERVATION][f"{TOPREWARD_FEATURE_PREFIX}task"]
assert tasks_out == ["perform the task"]
def test_encoder_step_rejects_missing_image_key():
step = TOPRewardEncoderProcessorStep(image_key="observation.images.top")
with pytest.raises(KeyError, match="image key"):
step(_make_transition(observation={}, complementary={"task": "pick"}))
def test_encoder_step_rejects_non_dict_observation():
step = TOPRewardEncoderProcessorStep()
with pytest.raises(ValueError, match="observation dict"):
step({TransitionKey.OBSERVATION: torch.zeros(1, 3, 8, 8)})
def test_encoder_step_rejects_3d_or_6d_input():
"""The encoder accepts ``(B,C,H,W)`` or ``(B,T,C,H,W)`` only."""
step = TOPRewardEncoderProcessorStep(image_key="observation.images.top")
with pytest.raises(ValueError, match=r"\(B,C,H,W\)"):
step(
_make_transition(
observation={"observation.images.top": torch.zeros(8, 8, 3)},
complementary={"task": "pick"},
)
)
def test_encoder_step_get_config_roundtrips_user_fields():
"""``get_config`` must serialise every user-tunable field — these are
what the processor pipeline saves under ``preprocessor_config.json``."""
step = TOPRewardEncoderProcessorStep(
@_skip_if_topreward_extras_missing
def test_encoder_step_get_config_roundtrips_user_fields(monkeypatch):
step = _build_step(
monkeypatch,
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
image_key="observation.images.cam_top",
task_key="task",
default_task="do the thing",
max_frames=8,
fps=4.0,
add_chat_template=True,
max_length=2048,
)
assert step.get_config() == {
"image_key": "observation.images.cam_top",
"task_key": "task",
"default_task": "do the thing",
"max_frames": 8,
}
cfg = step.get_config()
assert cfg["vlm_name"] == "Qwen/Qwen3-VL-8B-Instruct"
assert cfg["image_key"] == "observation.images.cam_top"
assert cfg["default_task"] == "do the thing"
assert cfg["max_frames"] == 8
assert cfg["fps"] == 4.0
assert cfg["add_chat_template"] is True
assert cfg["max_length"] == 2048
def test_encoder_step_transform_features_is_identity():
"""The encoder writes plain Python objects (numpy arrays / strings)
into ``observation`` at call time but does NOT advertise new typed
features at pipeline-build time — the model reads them via the
``TOPREWARD_FEATURE_PREFIX`` namespace, not via the typed feature map.
"""
step = TOPRewardEncoderProcessorStep()
@_skip_if_topreward_extras_missing
def test_encoder_step_transform_features_is_identity(monkeypatch):
step = _build_step(monkeypatch)
features = {
PipelineFeatureType.OBSERVATION: {
"observation.images.top": PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL),
}
}
assert step.transform_features(features) == features
@_skip_if_topreward_extras_missing
def test_encoder_step_rejects_missing_image_key(monkeypatch):
step = _build_step(monkeypatch, image_key="observation.images.top")
with pytest.raises(KeyError, match="image key"):
step(_make_transition(observation={}, complementary={"task": "pick"}))
@_skip_if_topreward_extras_missing
def test_encoder_step_rejects_non_dict_observation(monkeypatch):
step = _build_step(monkeypatch)
with pytest.raises(ValueError, match="observation dict"):
step({TransitionKey.OBSERVATION: torch.zeros(1, 3, 8, 8)})