optmize topreward input processing (#3660)

This commit is contained in:
Haoming Song
2026-05-25 22:07:45 +08:00
committed by GitHub
parent 616663cd9f
commit 3b5b94dbd6
10 changed files with 300 additions and 281 deletions
+73 -26
View File
@@ -24,7 +24,7 @@ import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.rewards.factory import get_reward_model_class, make_reward_model_config
from lerobot.rewards.topreward import TOPRewardConfig
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX, TOPREWARD_INPUT_KEYS
from tests.utils import skip_if_package_missing
@@ -45,20 +45,23 @@ class _FakeQwenModel(torch.nn.Module):
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
return cls()
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): # noqa: ARG002
def forward( # noqa: ARG002
self, input_ids, attention_mask=None, labels=None, logits_to_keep=0, **kwargs
):
batch_size, seq_len = input_ids.shape
vocab_size = 1000
logits = torch.zeros(batch_size, seq_len, vocab_size)
# Place a controlled log-prob at the target token position so the
# model returns a predictable reward value.
# The label-masked suffix is the last token (prompt_length = seq_len - 1).
# The label-masked suffix is the last token.
# After the causal-LM shift (logits[:, :-1], labels[:, 1:]) the scored
# position is logits[:, -2, :] predicting labels[:, -1].
# We set logits so that log_softmax at the target token ≈ _reward_value.
if labels is not None:
for i in range(batch_size):
target_idx = int(input_ids[i, -1].item())
logits[i, -2, target_idx] = self._reward_value * -10 # high logit -> high log-prob
for i in range(batch_size):
target_idx = int(input_ids[i, -1].item())
logits[i, -2, target_idx] = self._reward_value * -10 # high logit -> high log-prob
if logits_to_keep:
logits = logits[:, -logits_to_keep:, :]
return SimpleNamespace(logits=logits)
@@ -72,17 +75,39 @@ def _patch_build(monkeypatch) -> None:
def _make_batch(
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
prompt_length: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
*,
omit: str | None = None,
) -> dict[str, torch.Tensor]:
"""Build a ``compute_reward``-ready batch using TOPReward's namespaced keys."""
batch: dict[str, torch.Tensor] = {f"{TOPREWARD_FEATURE_PREFIX}input_ids": input_ids}
if attention_mask is not None:
batch[f"{TOPREWARD_FEATURE_PREFIX}attention_mask"] = attention_mask
if prompt_length is not None:
batch[f"{TOPREWARD_FEATURE_PREFIX}prompt_length"] = prompt_length
batch_size, seq_len = input_ids.shape
if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
batch: dict[str, torch.Tensor] = {}
if labels is not None:
batch[f"{TOPREWARD_FEATURE_PREFIX}labels"] = labels
batch.update(
{
f"{TOPREWARD_FEATURE_PREFIX}input_ids": input_ids,
f"{TOPREWARD_FEATURE_PREFIX}attention_mask": attention_mask,
f"{TOPREWARD_FEATURE_PREFIX}pixel_values_videos": torch.zeros(
batch_size, 1536, dtype=torch.float32
),
f"{TOPREWARD_FEATURE_PREFIX}video_grid_thw": torch.ones(batch_size, 3, dtype=torch.long),
f"{TOPREWARD_FEATURE_PREFIX}mm_token_type_ids": torch.zeros_like(input_ids),
}
)
if omit is not None:
batch.pop(f"{TOPREWARD_FEATURE_PREFIX}{omit}", None)
return batch
def _terminal_labels(input_ids: torch.Tensor) -> torch.Tensor:
labels = torch.full_like(input_ids, -100)
labels[:, -1] = input_ids[:, -1]
return labels
# ---------------------------------------------------------------------------
# Registry + factory
# ---------------------------------------------------------------------------
@@ -105,11 +130,6 @@ def test_topreward_factory_returns_in_tree_class():
# ---------------------------------------------------------------------------
def test_topreward_config_rejects_bad_reduction():
with pytest.raises(ValueError, match="reduction must be"):
TOPRewardConfig(device="cpu", reduction="median")
def test_topreward_config_rejects_zero_max_frames():
with pytest.raises(ValueError, match="max_frames must be >= 1"):
TOPRewardConfig(device="cpu", max_frames=0)
@@ -142,9 +162,9 @@ def test_topreward_compute_reward_returns_one_scalar_per_sample(monkeypatch):
input_ids = torch.randint(0, 100, (2, 10))
attention_mask = torch.ones(2, 10, dtype=torch.long)
prompt_length = torch.tensor([9, 9]) # unmask only the last token
labels = _terminal_labels(input_ids)
batch = _make_batch(input_ids, attention_mask, prompt_length)
batch = _make_batch(input_ids, attention_mask, labels)
rewards = model.compute_reward(batch)
assert rewards.shape == (2,)
@@ -162,9 +182,9 @@ def test_topreward_compute_reward_applies_success_threshold(monkeypatch):
input_ids = torch.randint(0, 100, (2, 10))
attention_mask = torch.ones(2, 10, dtype=torch.long)
prompt_length = torch.tensor([9, 9])
labels = _terminal_labels(input_ids)
batch = _make_batch(input_ids, attention_mask, prompt_length)
batch = _make_batch(input_ids, attention_mask, labels)
rewards = model.compute_reward(batch)
assert rewards.shape == (2,)
@@ -180,7 +200,37 @@ def test_topreward_compute_reward_errors_when_inputs_missing(monkeypatch):
model = TOPRewardModel(cfg)
with pytest.raises(KeyError, match=r"observation\.topreward\.input_ids"):
model.compute_reward({})
model.compute_reward(_make_batch(torch.randint(0, 100, (1, 10)), omit="input_ids"))
@skip_if_package_missing("transformers")
def test_topreward_compute_reward_errors_when_labels_missing(monkeypatch):
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
_patch_build(monkeypatch)
cfg = TOPRewardConfig(device="cpu")
model = TOPRewardModel(cfg)
input_ids = torch.randint(0, 100, (1, 10))
with pytest.raises(KeyError, match=r"observation\.topreward\.labels"):
model.compute_reward(_make_batch(input_ids, labels=None))
@skip_if_package_missing("transformers")
def test_topreward_compute_reward_requires_all_encoder_keys(monkeypatch):
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
_patch_build(monkeypatch)
cfg = TOPRewardConfig(device="cpu")
model = TOPRewardModel(cfg)
input_ids = torch.randint(0, 100, (1, 10))
labels = _terminal_labels(input_ids)
required_encoder_keys = set(TOPREWARD_INPUT_KEYS) - {"input_ids", "labels"}
for key in required_encoder_keys:
with pytest.raises(KeyError, match=rf"observation\.topreward\.{key}"):
model.compute_reward(_make_batch(input_ids, labels=labels, omit=key))
# ---------------------------------------------------------------------------
@@ -198,7 +248,6 @@ def test_topreward_save_pretrained_writes_only_config_json(monkeypatch, tmp_path
cfg = TOPRewardConfig(
device="cpu",
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
reduction="sum",
fps=4.0,
image_key="observation.images.front",
)
@@ -217,7 +266,6 @@ def test_topreward_from_pretrained_local_dir_roundtrips_config(monkeypatch, tmp_
cfg = TOPRewardConfig(
device="cpu",
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
reduction="sum",
fps=4.0,
image_key="observation.images.front",
add_chat_template=True,
@@ -229,7 +277,6 @@ def test_topreward_from_pretrained_local_dir_roundtrips_config(monkeypatch, tmp_
assert isinstance(reloaded.config, TOPRewardConfig)
assert reloaded.config.vlm_name == "Qwen/Qwen3-VL-8B-Instruct"
assert reloaded.config.reduction == "sum"
assert reloaded.config.fps == 4.0
assert reloaded.config.image_key == "observation.images.front"
assert reloaded.config.add_chat_template is True
+80
View File
@@ -0,0 +1,80 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end TOPReward smoke test with the real Qwen3-VL model."""
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig # noqa: E402
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel # noqa: E402
from lerobot.rewards.topreward.processor_topreward import ( # noqa: E402
TOPREWARD_FEATURE_PREFIX,
TOPREWARD_INPUT_KEYS,
make_topreward_pre_post_processors,
)
from tests.utils import require_cuda # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires downloading and loading Qwen3-VL and is not meant for CI",
)
def _make_dummy_topreward_batch(image_key: str, task_key: str) -> dict[str, object]:
num_frames = 4
image_size = 64
frames = torch.zeros(1, num_frames, 3, image_size, image_size, dtype=torch.uint8)
for frame_idx in range(num_frames):
frames[0, frame_idx, 0].fill_(min(frame_idx * 48, 255))
frames[0, frame_idx, 1].fill_(96)
frames[0, frame_idx, 2].fill_(192)
return {
image_key: frames,
task_key: ["pick up the red cube"],
}
@require_cuda
def test_topreward_full_qwen3vl_preprocessor_to_compute_reward():
cfg = TOPRewardConfig(
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
device="cuda",
max_frames=4,
fps=2.0,
max_input_length=4096,
)
preprocessor, _ = make_topreward_pre_post_processors(cfg)
encoded_batch = preprocessor(_make_dummy_topreward_batch(cfg.image_key, cfg.task_key))
for key in TOPREWARD_INPUT_KEYS:
assert f"{TOPREWARD_FEATURE_PREFIX}{key}" in encoded_batch
model = TOPRewardModel(cfg)
try:
model.to(cfg.device)
model.eval()
rewards = model.compute_reward(encoded_batch)
finally:
del model
torch.cuda.empty_cache()
assert rewards.shape == (1,)
assert rewards.dtype == torch.float32
assert torch.isfinite(rewards).all()
+52 -70
View File
@@ -16,71 +16,71 @@
from __future__ import annotations
import numpy as np
import pytest
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.rewards.topreward.processor_topreward import (
TOPREWARD_FEATURE_PREFIX,
TOPREWARD_INPUT_KEYS,
_expand_tasks,
_video_to_numpy,
_prepare_video_batch,
)
from lerobot.types import TransitionKey
from tests.utils import skip_if_package_missing
# ---------------------------------------------------------------------------
# _video_to_numpy — pure (T, C, H, W) -> (T, H, W, C) uint8 conversion
# _prepare_video_batch — raw image/video batch -> (B, T, C, H, W) uint8
# ---------------------------------------------------------------------------
def test_video_to_numpy_chw_float_is_converted_to_thwc_uint8():
video = torch.rand(4, 3, 8, 8)
array = _video_to_numpy(video, max_frames=None)
def test_prepare_video_batch_batched_chw_float_is_converted_to_uint8():
video = torch.rand(2, 4, 3, 8, 8)
tensor = _prepare_video_batch(video, max_frames=None)
assert array.shape == (4, 8, 8, 3)
assert array.dtype == np.uint8
assert array.min() >= 0 and array.max() <= 255
assert tensor.shape == (2, 4, 3, 8, 8)
assert tensor.dtype == torch.uint8
assert tensor.min() >= 0 and tensor.max() <= 255
def test_video_to_numpy_already_thwc_uint8_passes_through():
video = torch.randint(0, 256, (3, 8, 8, 3), dtype=torch.uint8)
array = _video_to_numpy(video, max_frames=None)
def test_prepare_video_batch_batched_thwc_uint8_is_permuted_to_channel_first():
video = torch.randint(0, 256, (2, 3, 8, 8, 3), dtype=torch.uint8)
tensor = _prepare_video_batch(video, max_frames=None)
assert array.shape == (3, 8, 8, 3)
assert array.dtype == np.uint8
assert tensor.shape == (2, 3, 3, 8, 8)
assert tensor.dtype == torch.uint8
def test_video_to_numpy_max_frames_tail_crops_recent_frames():
video = torch.zeros(10, 3, 4, 4)
def test_prepare_video_batch_max_frames_tail_crops_recent_frames():
video = torch.zeros(1, 10, 3, 4, 4)
for t in range(10):
video[t] = t / 9.0
video[:, t] = t / 9.0
array = _video_to_numpy(video, max_frames=3)
tensor = _prepare_video_batch(video, max_frames=3)
assert array.shape == (3, 4, 4, 3)
assert int(array[0, 0, 0, 0]) == int(round(7 / 9 * 255))
assert int(array[-1, 0, 0, 0]) == 255
assert tensor.shape == (1, 3, 3, 4, 4)
assert int(tensor[0, 0, 0, 0, 0]) == int(7 / 9 * 255)
assert int(tensor[0, -1, 0, 0, 0]) == 255
def test_video_to_numpy_rejects_3d_input():
with pytest.raises(ValueError, match="Expected channel dim"):
_video_to_numpy(torch.zeros(4, 8, 8), max_frames=None)
def test_prepare_video_batch_rejects_3d_input():
with pytest.raises(ValueError, match="Expected TOPReward frames"):
_prepare_video_batch(torch.zeros(4, 8, 8), max_frames=None)
def test_video_to_numpy_floats_above_one_pass_through_without_rescaling():
video = torch.full((1, 3, 2, 2), 5.0)
array = _video_to_numpy(video, max_frames=None)
def test_prepare_video_batch_floats_above_one_are_rescaled_and_clipped():
video = torch.full((1, 1, 3, 2, 2), 5.0)
tensor = _prepare_video_batch(video, max_frames=None)
assert array.shape == (1, 2, 2, 3)
assert int(array.max()) == 5
assert tensor.shape == (1, 1, 3, 2, 2)
assert int(tensor.max()) == 255
def test_video_to_numpy_clips_very_large_floats_to_uint8_max():
video = torch.full((1, 3, 2, 2), 300.0)
array = _video_to_numpy(video, max_frames=None)
def test_prepare_video_batch_clips_very_large_floats_to_uint8_max():
video = torch.full((1, 1, 3, 2, 2), 300.0)
tensor = _prepare_video_batch(video, max_frames=None)
assert int(array.max()) == 255
assert int(tensor.max()) == 255
# ---------------------------------------------------------------------------
@@ -124,12 +124,11 @@ def test_expand_tasks_wrong_type_raises():
# ---------------------------------------------------------------------------
# Encoder step — stubbed AutoProcessor + process_vision_info
# Encoder step — stubbed AutoProcessor
# ---------------------------------------------------------------------------
def _skip_if_topreward_extras_missing(func):
func = skip_if_package_missing("qwen-vl-utils", import_name="qwen_vl_utils")(func)
func = skip_if_package_missing("transformers")(func)
return func
@@ -155,32 +154,20 @@ class _FakeAutoProcessor:
def __call__(self, text=None, images=None, videos=None, **kwargs): # noqa: ARG002
seq_len = 10
batch_size = len(text) if isinstance(text, list) else 1
return {
"input_ids": torch.randint(0, 100, (1, seq_len)),
"attention_mask": torch.ones(1, seq_len, dtype=torch.long),
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
"pixel_values_videos": torch.zeros(batch_size, 1536, dtype=torch.float32),
"video_grid_thw": torch.ones(batch_size, 3, dtype=torch.long),
"mm_token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.long),
}
def _build_step(monkeypatch, **overrides):
import importlib
import sys
import types
from lerobot.rewards.topreward import processor_topreward
from lerobot.utils import import_utils
monkeypatch.setattr(processor_topreward, "AutoProcessor", _FakeAutoProcessor)
# Stub qwen_vl_utils as a real module object (not MagicMock) so
# ``require_package`` / ``find_spec`` don't choke on a missing ``__spec__``.
fake_qwen_vl = types.ModuleType("qwen_vl_utils")
fake_qwen_vl.process_vision_info = lambda messages: (None, None) # type: ignore[attr-defined]
fake_qwen_vl.__spec__ = importlib.machinery.ModuleSpec("qwen_vl_utils", None)
monkeypatch.setitem(sys.modules, "qwen_vl_utils", fake_qwen_vl)
# Clear the require_package cache so the stub is picked up.
import_utils._require_package_cache.pop("qwen_vl_utils", None)
return processor_topreward.TOPRewardEncoderProcessorStep(**overrides)
@@ -192,27 +179,29 @@ def _make_transition(observation: dict, complementary: dict | None = None) -> di
@_skip_if_topreward_extras_missing
def test_encoder_step_emits_input_ids_and_prompt_length(monkeypatch):
def test_encoder_step_emits_input_ids_and_labels(monkeypatch):
"""The processor must emit Qwen-VL tensors including ``input_ids`` and
``prompt_length`` under the ``observation.topreward.*`` namespace."""
``labels`` under the ``observation.topreward.*`` namespace."""
step = _build_step(monkeypatch)
frames_batch = torch.zeros(1, 4, 3, 8, 8)
frames_batch = torch.zeros(2, 4, 3, 8, 8)
out = step(
_make_transition(
observation={"observation.images.top": frames_batch},
complementary={"task": "pick"},
complementary={"task": ["pick", "place"]},
)
)
obs_out = out[TransitionKey.OBSERVATION]
assert f"{TOPREWARD_FEATURE_PREFIX}input_ids" in obs_out
assert f"{TOPREWARD_FEATURE_PREFIX}attention_mask" in obs_out
assert f"{TOPREWARD_FEATURE_PREFIX}prompt_length" in obs_out
for key in TOPREWARD_INPUT_KEYS:
assert f"{TOPREWARD_FEATURE_PREFIX}{key}" in obs_out
prompt_length = obs_out[f"{TOPREWARD_FEATURE_PREFIX}prompt_length"]
assert prompt_length.dtype == torch.long
assert prompt_length.shape == (1,)
input_ids = obs_out[f"{TOPREWARD_FEATURE_PREFIX}input_ids"]
labels = obs_out[f"{TOPREWARD_FEATURE_PREFIX}labels"]
assert labels.dtype == torch.long
assert labels.shape == (2, 10)
assert labels[:, :-1].eq(-100).all()
assert labels[:, -1].equal(input_ids[:, -1])
@_skip_if_topreward_extras_missing
@@ -255,10 +244,3 @@ def test_encoder_step_rejects_missing_image_key(monkeypatch):
step = _build_step(monkeypatch, image_key="observation.images.top")
with pytest.raises(KeyError, match="image key"):
step(_make_transition(observation={}, complementary={"task": "pick"}))
@_skip_if_topreward_extras_missing
def test_encoder_step_rejects_non_dict_observation(monkeypatch):
step = _build_step(monkeypatch)
with pytest.raises(ValueError, match="observation dict"):
step({TransitionKey.OBSERVATION: torch.zeros(1, 3, 8, 8)})