mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 70ad322676 |
@@ -73,6 +73,8 @@
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
- local: topreward
|
||||
title: TOPReward
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: inference
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
# TOPReward
|
||||
|
||||
TOPReward is a **zero-shot reward model** that extracts token log-probabilities from an off-the-shelf vision-language model (VLM) as a robotic reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood that the instruction is true — no fine-tuning required.
|
||||
|
||||
**Paper**: [TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics](https://arxiv.org/abs/2602.19313)
|
||||
**Project**: [topreward.github.io](https://topreward.github.io/webpage/)
|
||||
**Original code**: [github.com/TOPReward/TOPReward](https://github.com/TOPReward/TOPReward)
|
||||
**Default backbone**: [Qwen/Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Overview
|
||||
|
||||
TOPReward asks a generic VLM how likely a task instruction is, **conditioned on the video** of a robot trying to complete that task. Concretely, given:
|
||||
|
||||
- A trajectory video (a sequence of frames).
|
||||
- A task instruction (e.g. _"open the drawer"_).
|
||||
|
||||
it builds a chat prompt of the form
|
||||
|
||||
```text
|
||||
<video>
|
||||
"The above video shows a robot manipulation trajectory that completes the
|
||||
following task: <instruction> Decide whether the above statement is True
|
||||
or not. The answer is: True"
|
||||
```
|
||||
|
||||
forwards it through the VLM, label-masks everything except the very last token, and reads back the log-probability of that token — by default the literal `"True"` that closes the suffix template. The resulting `log P("True" | video + prompt + instruction)` is the reward, and answers the question "given this video, how strongly does the VLM agree that the instruction is satisfied?".
|
||||
|
||||
Because the method only depends on a frozen VLM, TOPReward is **zero-shot**: there are no fine-tuned weights to host. The "model" in LeRobot is a small wrapper around `transformers`' `Qwen3VLForConditionalGeneration` plus the prompt assembly + label-masking logic.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot following the [Installation Guide](./installation).
|
||||
2. Install the TOPReward optional extra:
|
||||
|
||||
```bash
|
||||
pip install -e ".[topreward]"
|
||||
```
|
||||
|
||||
or, with `uv` from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra topreward
|
||||
```
|
||||
|
||||
This pulls in `transformers` and `qwen-vl-utils`. The first time you run TOPReward, Hugging Face will also download the VLM weights from the Hub (~16 GB for Qwen3-VL-8B-Instruct). A GPU is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
TOPReward expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | --------------------------- | ----------------------------------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera observation used by TOPReward |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
|
||||
| `reward_model.max_frames` | `16` | Cap on frames per sample (compute_reward only; predict_curves bypasses) |
|
||||
| `reward_model.fps` | `2.0` | Metadata passed to the Qwen video processor |
|
||||
| `reward_model.vlm_name` | `Qwen/Qwen3-VL-8B-Instruct` | Hugging Face Hub id of the underlying VLM |
|
||||
|
||||
The model returns:
|
||||
|
||||
- `compute_reward(batch)`: one log-probability per sample. Higher = better task–video alignment. When `success_threshold` is finite, returns the binary thresholded value instead.
|
||||
- `predict_curves(batch, num_prefixes=None)`: per-frame progress curve in `[0, 1]` (min-max normalised log-probs over prefix lengths). `num_prefixes=None` is fully dense; `num_prefixes=15` matches the upstream sparse-dense default with linear interpolation between anchors.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the reward model directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.topreward import TOPRewardConfig, TOPRewardModel
|
||||
|
||||
cfg = TOPRewardConfig(
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
)
|
||||
reward_model = TOPRewardModel(cfg)
|
||||
```
|
||||
|
||||
There is no `from_pretrained` weight download for TOPReward itself — the VLM is fetched from the Hub on construction.
|
||||
|
||||
### Score a clip + task
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX
|
||||
|
||||
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
|
||||
# task: str
|
||||
batch = {
|
||||
f"{TOPREWARD_FEATURE_PREFIX}frames": [frames],
|
||||
f"{TOPREWARD_FEATURE_PREFIX}task": [task],
|
||||
}
|
||||
reward = reward_model.compute_reward(batch) # tensor of shape (1,)
|
||||
```
|
||||
|
||||
For a dense per-frame curve over the same clip:
|
||||
|
||||
```python
|
||||
out = reward_model.predict_curves(batch, num_prefixes=15)
|
||||
progress = out["progress"][0].numpy() # shape (T,), values in [0, 1]
|
||||
```
|
||||
|
||||
### Use the reward factory
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"topreward",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor writes normalised frames + task strings under the `observation.topreward.*` namespace; the model reads them in `compute_reward`.
|
||||
|
||||
### Offline dataset labeling
|
||||
|
||||
Mirror the SARM / Robometer RA-BC flow — write a `topreward_progress.parquet` once, then reuse it for training (RA-BC) and visualisation (overlay videos):
|
||||
|
||||
```bash
|
||||
# Fully dense per-frame labeling
|
||||
uv run python -m lerobot.rewards.topreward.compute_rabc_weights \
|
||||
--dataset-repo-id lerobot/libero_10_image \
|
||||
--device cuda
|
||||
|
||||
# Sparse-dense (15 anchors per episode, matches upstream)
|
||||
uv run python -m lerobot.rewards.topreward.compute_rabc_weights \
|
||||
--dataset-repo-id lerobot/libero_10_image \
|
||||
--num-prefixes 15 \
|
||||
--device cuda
|
||||
```
|
||||
|
||||
Then render the SARM-style progress overlay for any episode:
|
||||
|
||||
```bash
|
||||
uv run examples/dataset/create_progress_videos.py \
|
||||
--repo-id lerobot/libero_10_image \
|
||||
--episode 0 \
|
||||
--progress-file topreward_progress.parquet \
|
||||
--gif
|
||||
```
|
||||
|
||||
## Publishing a named TOPReward configuration
|
||||
|
||||
Because TOPReward stores no weights of its own, "publishing a TOPReward model" amounts to writing the LeRobot `config.json` (≈ 1 KB) that pins the VLM id, prompt and reduction:
|
||||
|
||||
```python
|
||||
from lerobot.rewards.topreward import TOPRewardConfig, TOPRewardModel
|
||||
|
||||
cfg = TOPRewardConfig(
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
reduction="mean",
|
||||
fps=2.0,
|
||||
)
|
||||
TOPRewardModel(cfg).save_pretrained("./topreward-qwen3vl-8b")
|
||||
# Push the directory to the Hub via `huggingface-cli` or `HfApi.upload_folder`.
|
||||
```
|
||||
|
||||
Reloading restores the same configuration (no weight download for TOPReward itself; the VLM is re-fetched via `vlm_name`):
|
||||
|
||||
```python
|
||||
reloaded = TOPRewardModel.from_pretrained("./topreward-qwen3vl-8b")
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [TOPReward project page](https://topreward.github.io/webpage/)
|
||||
- [TOPReward paper](https://arxiv.org/abs/2602.19313)
|
||||
- [Original TOPReward code](https://github.com/TOPReward/TOPReward)
|
||||
- [Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2026topreward,
|
||||
title={TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics},
|
||||
author={Chen, Shirui and Harrison, Cole and Lee, Ying-Chun and Yang, Angela Jin and
|
||||
Ren, Zhongzheng and Ratliff, Lillian J and Duan, Jiafei and Fox, Dieter and
|
||||
Krishna, Ranjay},
|
||||
journal={arXiv preprint arXiv:2602.19313},
|
||||
year={2026}
|
||||
}
|
||||
```
|
||||
@@ -209,6 +209,7 @@ groot = [
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
@@ -21,11 +21,13 @@ from .factory import (
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"SARMConfig",
|
||||
"TOPRewardConfig",
|
||||
# Base class
|
||||
"PreTrainedRewardModel",
|
||||
# Factory functions
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig
|
||||
|
||||
|
||||
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
@@ -37,7 +38,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm".
|
||||
"sarm", "topreward".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
@@ -53,6 +54,10 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "topreward":
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -69,7 +74,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm".
|
||||
"reward_classifier", "sarm", "topreward".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -82,6 +87,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -162,6 +169,14 @@ def make_reward_pre_post_processors(
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, TOPRewardConfig):
|
||||
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
return make_topreward_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_topreward import TOPRewardConfig
|
||||
from .modeling_topreward import TOPRewardModel
|
||||
from .processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
__all__ = ["TOPRewardConfig", "TOPRewardModel", "make_topreward_pre_post_processors"]
|
||||
@@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compute per-frame TOPReward progress curves for a LeRobot dataset.
|
||||
|
||||
This mirrors :mod:`lerobot.rewards.sarm.compute_rabc_weights` (and the
|
||||
ROBOMETER equivalent): it walks every episode in a dataset, runs the
|
||||
TOPReward zero-shot reward model, and writes a parquet file with one row
|
||||
per frame. The output uses the same schema SARM produces so existing
|
||||
consumers — :class:`lerobot.rewards.sarm.rabc.RABCWeights` (which reads
|
||||
``progress_sparse``) and the SARM-style overlay script in
|
||||
``examples/dataset/create_progress_videos.py`` — work without modification.
|
||||
|
||||
TOPReward is zero-shot: there is no fine-tuned checkpoint to load. The
|
||||
``--reward-model-path`` argument is therefore optional and only used when
|
||||
you want to load a TOPReward LeRobot config (e.g. one published on the Hub
|
||||
that pins ``vlm_name`` and prompt knobs). Otherwise the default
|
||||
:class:`TOPRewardConfig` is used, which points at
|
||||
``Qwen/Qwen3-VL-8B-Instruct`` — the VLM is re-downloaded from the HF Hub
|
||||
on every run unless cached.
|
||||
|
||||
Parquet schema:
|
||||
+--------------------+---------+----------------------------------------+
|
||||
| column | dtype | meaning |
|
||||
+====================+=========+========================================+
|
||||
| ``index`` | int64 | global frame index |
|
||||
| ``episode_index`` | int64 | episode id |
|
||||
| ``frame_index`` | int64 | local within-episode index |
|
||||
| ``progress_sparse``| float32 | per-frame TOPReward progress in [0, 1] |
|
||||
| | | (RA-BC + overlay read this column) |
|
||||
+--------------------+---------+----------------------------------------+
|
||||
|
||||
Usage:
|
||||
# Full computation (one VLM forward per frame, slowest but most accurate)
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image
|
||||
|
||||
# Sparse-dense mode: 15 anchor prefixes per episode, interpolated to
|
||||
# per-frame resolution. Matches upstream TOPReward ``num_samples=15``.
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--num-prefixes 15
|
||||
|
||||
# Use a different VLM backbone
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--vlm-name Qwen/Qwen3-VL-4B-Instruct
|
||||
|
||||
The output is written to the dataset's local cache directory as
|
||||
``topreward_progress.parquet`` (or to ``--output-path`` if provided).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX
|
||||
|
||||
DEFAULT_OUTPUT_FILENAME = "topreward_progress.parquet"
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read ``reward_model_path`` from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_task(sample: dict[str, Any], default: str) -> str:
|
||||
"""Best-effort task extraction from a dataset sample."""
|
||||
task = sample.get("task")
|
||||
if isinstance(task, str) and task:
|
||||
return task
|
||||
return default
|
||||
|
||||
|
||||
def _frames_to_uint8_hwc(video: torch.Tensor) -> np.ndarray:
|
||||
"""Convert a ``(T, C, H, W)`` or ``(T, H, W, C)`` tensor to ``(T, H, W, C) uint8``.
|
||||
|
||||
Inlined here (rather than reusing the processor) so the labeling script
|
||||
can side-step the ``max_frames`` tail-crop and feed full trajectories
|
||||
to :meth:`TOPRewardModel.predict_curves`.
|
||||
"""
|
||||
if video.shape[1] in (1, 3):
|
||||
video = video.permute(0, 2, 3, 1)
|
||||
elif video.shape[-1] not in (1, 3):
|
||||
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
|
||||
|
||||
array = video.detach().cpu().numpy()
|
||||
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
|
||||
array = array * 255.0
|
||||
return np.clip(array, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def compute_topreward_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str | None = None,
|
||||
vlm_name: str | None = None,
|
||||
output_path: str | None = None,
|
||||
device: str = "cuda",
|
||||
num_prefixes: int | None = None,
|
||||
fps: float | None = None,
|
||||
reduction: str | None = None,
|
||||
use_video_description: bool = False,
|
||||
) -> Path:
|
||||
"""Run TOPReward over a dataset and write per-frame progress.
|
||||
|
||||
Args:
|
||||
dataset_repo_id: Hugging Face dataset repo id or local path.
|
||||
reward_model_path: Optional TOPReward LeRobot config repo / dir to
|
||||
load (a tiny ``config.json``). When ``None`` (default), a
|
||||
fresh :class:`TOPRewardConfig` is constructed from the CLI
|
||||
overrides.
|
||||
vlm_name: Override the VLM backbone (HF Hub id).
|
||||
output_path: Where to write the parquet. Defaults to
|
||||
``<dataset_root>/topreward_progress.parquet``.
|
||||
device: Device for the VLM.
|
||||
num_prefixes: Number of evenly-spaced anchor prefixes per episode.
|
||||
``None`` (default) = fully dense (one VLM forward per frame).
|
||||
Set to ``15`` to match upstream TOPReward ``num_samples=15``.
|
||||
fps: Override the config's ``fps``.
|
||||
reduction: Override the config's ``reduction`` (``"mean"`` / ``"sum"``).
|
||||
use_video_description: Override the config's ``use_video_description``.
|
||||
|
||||
Returns:
|
||||
Path to the written parquet file.
|
||||
"""
|
||||
if reward_model_path is not None:
|
||||
logging.info(f"Loading TOPReward config from: {reward_model_path}")
|
||||
model = TOPRewardModel.from_pretrained(reward_model_path)
|
||||
config = model.config
|
||||
# Apply CLI overrides on top of the loaded config.
|
||||
if vlm_name is not None and vlm_name != config.vlm_name:
|
||||
logging.info(f"Overriding vlm_name from config: {config.vlm_name} -> {vlm_name}")
|
||||
# vlm_name affects the loaded weights; reload from scratch.
|
||||
config.vlm_name = vlm_name
|
||||
config.device = device
|
||||
model = TOPRewardModel(config)
|
||||
else:
|
||||
config_kwargs: dict[str, Any] = {"device": device}
|
||||
if vlm_name is not None:
|
||||
config_kwargs["vlm_name"] = vlm_name
|
||||
if fps is not None:
|
||||
config_kwargs["fps"] = fps
|
||||
if reduction is not None:
|
||||
config_kwargs["reduction"] = reduction
|
||||
if use_video_description:
|
||||
config_kwargs["use_video_description"] = True
|
||||
config = TOPRewardConfig(**config_kwargs)
|
||||
logging.info(f"Constructing TOPReward with VLM: {config.vlm_name}")
|
||||
model = TOPRewardModel(config)
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
image_key = config.image_key
|
||||
frames_key = f"{TOPREWARD_FEATURE_PREFIX}frames"
|
||||
task_batch_key = f"{TOPREWARD_FEATURE_PREFIX}task"
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
all_index: list[int] = []
|
||||
all_episode: list[int] = []
|
||||
all_frame: list[int] = []
|
||||
all_progress: list[float] = []
|
||||
|
||||
for episode_idx in tqdm(range(dataset.num_episodes), desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = int(ep["dataset_from_index"])
|
||||
ep_end = int(ep["dataset_to_index"])
|
||||
num_frames = ep_end - ep_start
|
||||
if num_frames <= 0:
|
||||
continue
|
||||
|
||||
first_sample = dataset[ep_start]
|
||||
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
|
||||
|
||||
# Read the whole episode into one (N, C, H, W) tensor and convert
|
||||
# to (N, H, W, C) uint8 — same format ``TOPREWARD_FEATURE_PREFIX.frames``
|
||||
# expects. We deliberately bypass the encoder step here so its
|
||||
# ``max_frames`` tail-crop doesn't clip the prefix sweep.
|
||||
ep_video = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
|
||||
ep_frames_uint8 = _frames_to_uint8_hwc(ep_video)
|
||||
|
||||
batch = {frames_key: [ep_frames_uint8], task_batch_key: [task]}
|
||||
out = model.predict_curves(batch, num_prefixes=num_prefixes)
|
||||
per_frame = out["progress"][0, :num_frames].cpu().numpy()
|
||||
|
||||
for local in range(num_frames):
|
||||
all_index.append(ep_start + local)
|
||||
all_episode.append(episode_idx)
|
||||
all_frame.append(local)
|
||||
all_progress.append(float(per_frame[local]))
|
||||
|
||||
if device.startswith("cuda"):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"index": np.asarray(all_index, dtype=np.int64),
|
||||
"episode_index": np.asarray(all_episode, dtype=np.int64),
|
||||
"frame_index": np.asarray(all_frame, dtype=np.int64),
|
||||
# Same column name SARM uses so RABCWeights + the overlay
|
||||
# script read TOPReward's output without per-model branching.
|
||||
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
# Persist provenance metadata: the LeRobot path (if any) and the VLM id.
|
||||
schema_metadata: dict[bytes, bytes] = {b"vlm_name": config.vlm_name.encode()}
|
||||
if reward_model_path is not None:
|
||||
schema_metadata[b"reward_model_path"] = reward_model_path.encode()
|
||||
table = table.replace_schema_metadata(schema_metadata)
|
||||
|
||||
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Saved {len(table)} frame values to {out}")
|
||||
|
||||
progress_arr = np.asarray(all_progress, dtype=np.float32)
|
||||
if progress_arr.size:
|
||||
logging.info(
|
||||
f"Progress: mean={float(progress_arr.mean()):.4f}, "
|
||||
f"std={float(progress_arr.std()):.4f}, "
|
||||
f"min={float(progress_arr.min()):.4f}, "
|
||||
f"max={float(progress_arr.max()):.4f}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute per-frame TOPReward progress curves for RA-BC weighting.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Full RA-BC computation with the default Qwen3-VL-8B-Instruct backbone
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image
|
||||
|
||||
# Sparse-dense mode (matches upstream TOPReward num_samples=15)
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--num-prefixes 15
|
||||
|
||||
# Use a smaller VLM
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--vlm-name Qwen/Qwen3-VL-4B-Instruct
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace dataset repo id or local path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional TOPReward LeRobot config (repo id or local dir). "
|
||||
"Falls back to a fresh TOPRewardConfig if unset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vlm-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override the VLM backbone (HF Hub id).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output parquet path. Defaults to <dataset_root>/topreward_progress.parquet.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="Device to use (default: cuda).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prefixes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Evenly-spaced anchor prefixes per episode. None = fully dense "
|
||||
"(one VLM forward per frame). 15 matches upstream TOPReward.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Override TOPRewardConfig.fps (frames per second for the Qwen video processor).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reduction",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["mean", "sum"],
|
||||
help="Override TOPRewardConfig.reduction.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-video-description",
|
||||
action="store_true",
|
||||
help="Generate an instruction-agnostic video description and prepend "
|
||||
"it as context before scoring (doubles VLM calls per prefix).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload the progress file to the dataset repo on HuggingFace Hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
output_path = compute_topreward_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=args.reward_model_path,
|
||||
vlm_name=args.vlm_name,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
num_prefixes=args.num_prefixes,
|
||||
fps=args.fps,
|
||||
reduction=args.reduction,
|
||||
use_video_description=args.use_video_description,
|
||||
)
|
||||
|
||||
print(f"\nTOPReward progress saved to: {output_path}")
|
||||
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = DEFAULT_OUTPUT_FILENAME
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
"Successfully uploaded to: "
|
||||
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# Default prompt scaffolding from the upstream TOPReward paper / reference
|
||||
# implementation (``QwenClient.compute_instruction_reward``). The prompt
|
||||
# computes the log-likelihood of the suffix ``f"{instruction} ... True"``
|
||||
# given the video, then reduces those token log-probs to a scalar reward.
|
||||
DEFAULT_PROMPT_PREFIX = (
|
||||
"The above video shows a robot manipulation trajectory that completes the following task: "
|
||||
)
|
||||
DEFAULT_PROMPT_SUFFIX_TEMPLATE = (
|
||||
"{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||
)
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("topreward")
|
||||
@dataclass
|
||||
class TOPRewardConfig(RewardModelConfig):
|
||||
"""Configuration for the TOPReward zero-shot reward model.
|
||||
|
||||
TOPReward is **zero-shot**: it has no learnable parameters of its own.
|
||||
The "model" is a generic vision-language model (default
|
||||
``Qwen/Qwen3-VL-8B-Instruct``) used with a fixed prompt to extract
|
||||
token log-probabilities as a reward signal. There is therefore no
|
||||
fine-tuned checkpoint to host: ``pretrained_path`` is unused at
|
||||
runtime — the model identity is :attr:`vlm_name` (an HF Hub id).
|
||||
|
||||
Args:
|
||||
vlm_name: Hugging Face Hub id of the underlying VLM. Must be a
|
||||
Qwen3-VL family model (the only client implemented in this
|
||||
LeRobot port).
|
||||
torch_dtype: Torch dtype name passed to the VLM loader
|
||||
(``"auto"``, ``"bfloat16"``, ``"float16"``, ...).
|
||||
attn_implementation: ``transformers`` attention implementation
|
||||
(e.g. ``"flash_attention_2"``, ``"sdpa"``). Defaults to
|
||||
``None`` so the upstream picks the best available.
|
||||
image_key: Observation key that holds the trajectory frames.
|
||||
task_key: Complementary-data key that holds the task instruction.
|
||||
default_task: Fallback instruction when ``task_key`` is absent.
|
||||
max_frames: Cap on the number of frames fed to the VLM per
|
||||
sample. ``None`` = use all frames.
|
||||
fps: Frames-per-second metadata for the Qwen video processor.
|
||||
prompt_prefix: Text shown to the VLM right after the video and
|
||||
before the suffix template.
|
||||
prompt_suffix_template: Suffix appended after ``prompt_prefix``.
|
||||
Must contain ``{instruction}``; the VLM scores the
|
||||
log-likelihood of the tokens that follow the prefix.
|
||||
add_chat_template: If ``True``, wrap the full prompt with the
|
||||
tokenizer's chat template before tokenisation (matches
|
||||
upstream ``add_chat_template=True``).
|
||||
use_video_description: If ``True``, make an extra VLM call to
|
||||
produce an instruction-agnostic video description and prepend
|
||||
it as additional context. Doubles inference cost but avoids
|
||||
circular grounding when the instruction names objects shown
|
||||
in frames.
|
||||
reduction: Reduction over per-token log-probs of the suffix
|
||||
tokens (``"mean"`` or ``"sum"``).
|
||||
success_threshold: Optional log-prob threshold. If finite,
|
||||
:meth:`TOPRewardModel.compute_reward` returns
|
||||
``(reward > success_threshold).float()`` instead of the raw
|
||||
log-prob.
|
||||
max_input_length: Hard limit on the total tokenized input length;
|
||||
samples that exceed it raise a ``ValueError``.
|
||||
"""
|
||||
|
||||
# Path to a local LeRobot dir or HF repo that holds a ``config.json``
|
||||
# snapshot of this TOPRewardConfig. The VLM weights themselves are
|
||||
# always identified by ``vlm_name``.
|
||||
pretrained_path: str | None = None
|
||||
|
||||
vlm_name: str = "Qwen/Qwen3-VL-8B-Instruct"
|
||||
torch_dtype: str = "auto"
|
||||
attn_implementation: str | None = None
|
||||
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 16
|
||||
fps: float = 2.0
|
||||
|
||||
prompt_prefix: str = DEFAULT_PROMPT_PREFIX
|
||||
prompt_suffix_template: str = DEFAULT_PROMPT_SUFFIX_TEMPLATE
|
||||
add_chat_template: bool = False
|
||||
use_video_description: bool = False
|
||||
|
||||
reduction: str = "mean"
|
||||
success_threshold: float = float("-inf")
|
||||
max_input_length: int = 32768
|
||||
|
||||
license: str | None = "mit" # matches upstream TOPReward
|
||||
tags: list[str] | None = field(
|
||||
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
|
||||
)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.reduction not in {"mean", "sum"}:
|
||||
raise ValueError(f"reduction must be 'mean' or 'sum', got {self.reduction!r}")
|
||||
if self.max_frames is not None and self.max_frames < 1:
|
||||
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be > 0, got {self.fps}")
|
||||
if "{instruction}" not in self.prompt_suffix_template:
|
||||
raise ValueError(
|
||||
"prompt_suffix_template must contain `{instruction}` so the model "
|
||||
"scores the log-likelihood of the task suffix."
|
||||
)
|
||||
if self.max_input_length <= 0:
|
||||
raise ValueError(f"max_input_length must be > 0, got {self.max_input_length}")
|
||||
|
||||
if self.image_key not in self.input_features:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
|
||||
self.output_features.setdefault("reward", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.image_key not in self.input_features:
|
||||
raise ValueError(f"TOPReward requires image input feature {self.image_key!r}")
|
||||
@@ -0,0 +1,563 @@
|
||||
# Copyright 2026 Shirui Chen, Cole Harrison, Ying-Chun Lee, Angela Jin Yang,
|
||||
# Zhongzheng Ren, Lillian J. Ratliff, Jiafei Duan, Dieter Fox, Ranjay Krishna
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics.
|
||||
|
||||
Paper: https://arxiv.org/abs/2602.19313
|
||||
Project: https://topreward.github.io/webpage/
|
||||
Original code: https://github.com/TOPReward/TOPReward
|
||||
Backbone: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct (default)
|
||||
|
||||
TOPReward is a **zero-shot** reward model: it has no fine-tuned weights of
|
||||
its own. Given a video trajectory and a task instruction, it asks an
|
||||
off-the-shelf VLM how likely the instruction is, conditioned on the video,
|
||||
and returns that log-likelihood as the reward signal.
|
||||
|
||||
Inference recipe:
|
||||
|
||||
1. Build a chat-style prompt:
|
||||
``[video(frames, fps), text=prompt_prefix, text="{instruction} ... True"]``
|
||||
2. Forward the full token sequence through the VLM.
|
||||
3. Mask all but the final token with ``-100`` (``prompt_length = input_len - 1``,
|
||||
mirrored from upstream). After the standard causal-LM next-token shift, this
|
||||
isolates the single position where the model predicts the literal ``"True"``
|
||||
that ends the prompt — the binary "is the instruction true given the video?"
|
||||
answer.
|
||||
4. Read that token's log-probability from the logits and reduce it (mean or sum
|
||||
— equivalent for a single token, kept for API parity with upstream) into a
|
||||
scalar reward.
|
||||
|
||||
This LeRobot port is **inference-only and not trainable** — :meth:`forward`
|
||||
is intentionally inherited from :class:`PreTrainedRewardModel` and raises
|
||||
``NotImplementedError``, making :attr:`PreTrainedRewardModel.is_trainable`
|
||||
return ``False``.
|
||||
|
||||
Because the VLM weights live on the Hugging Face Hub under their canonical
|
||||
id (``Qwen/Qwen3-VL-8B-Instruct`` etc.) and TOPReward never modifies them,
|
||||
:meth:`_save_pretrained` and :meth:`from_pretrained` are overridden so a
|
||||
TOPReward LeRobot "checkpoint" is a single ``config.json`` (the VLM is
|
||||
re-fetched from the Hub at load time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
AutoProcessor = None # type: ignore[assignment]
|
||||
Qwen3VLForConditionalGeneration = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="TOPRewardModel")
|
||||
|
||||
_TRUE_ANSWER = "True"
|
||||
|
||||
|
||||
def _torch_dtype(name: str) -> torch.dtype | str:
|
||||
"""Resolve a torch dtype name; ``"auto"`` is passed through verbatim."""
|
||||
if name == "auto":
|
||||
return "auto"
|
||||
dtype = getattr(torch, name, None)
|
||||
if isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
raise ValueError(f"Unknown torch dtype: {name!r}")
|
||||
|
||||
|
||||
def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]:
|
||||
"""Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images."""
|
||||
if frames.ndim != 4:
|
||||
raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}")
|
||||
if frames.dtype != np.uint8:
|
||||
frames = np.clip(frames, 0, 255).astype(np.uint8)
|
||||
return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
|
||||
|
||||
|
||||
def minmax_normalize_rewards(rewards: list[float] | np.ndarray) -> np.ndarray:
|
||||
"""Min-max normalise raw log-prob rewards into ``[0, 1]``.
|
||||
|
||||
Matches upstream ``QwenClient.normalize_rewards(rewards, method="minmax")``:
|
||||
a single-element input maps to ``[1.0]`` (no information to scale), and a
|
||||
flat input (``max == min``) maps to all-ones.
|
||||
"""
|
||||
rewards_arr = np.asarray(rewards, dtype=np.float64)
|
||||
if rewards_arr.size == 0:
|
||||
return rewards_arr.astype(np.float32)
|
||||
if rewards_arr.size == 1:
|
||||
return np.array([1.0], dtype=np.float32)
|
||||
r_min, r_max = rewards_arr.min(), rewards_arr.max()
|
||||
if r_max == r_min:
|
||||
return np.ones_like(rewards_arr, dtype=np.float32)
|
||||
return ((rewards_arr - r_min) / (r_max - r_min)).astype(np.float32)
|
||||
|
||||
|
||||
class TOPRewardModel(PreTrainedRewardModel):
|
||||
"""TOPReward zero-shot reward model."""
|
||||
|
||||
name = "topreward"
|
||||
config_class = TOPRewardConfig
|
||||
|
||||
def __init__(self, config: TOPRewardConfig) -> None:
|
||||
require_package("transformers", extra="topreward")
|
||||
require_package("qwen-vl-utils", extra="topreward", import_name="qwen_vl_utils")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
torch_dtype = _torch_dtype(config.torch_dtype)
|
||||
model_kwargs: dict[str, Any] = {"dtype": torch_dtype, "trust_remote_code": True}
|
||||
if config.attn_implementation is not None:
|
||||
model_kwargs["attn_implementation"] = config.attn_implementation
|
||||
|
||||
# TOPReward is zero-shot: load the VLM as-is from the Hub. No
|
||||
# weights of our own, no embedding resize, no head wiring.
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(config.vlm_name, **model_kwargs)
|
||||
self.processor = AutoProcessor.from_pretrained(config.vlm_name, trust_remote_code=True)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Any]) -> Tensor:
|
||||
"""Return one log-prob reward per sample in the batch.
|
||||
|
||||
Expects a batch produced by :class:`TOPRewardEncoderProcessorStep`:
|
||||
``observation[f"{TOPREWARD_FEATURE_PREFIX}frames"]`` is a list of
|
||||
``(T, H, W, C) uint8`` numpy arrays (one per sample) and
|
||||
``observation[f"{TOPREWARD_FEATURE_PREFIX}task"]`` is a list of
|
||||
task strings of the same length.
|
||||
"""
|
||||
frames_per_sample, tasks = self._unpack_batch(batch)
|
||||
rewards = [
|
||||
self._compute_log_prob_reward(frames, task)
|
||||
for frames, task in zip(frames_per_sample, tasks, strict=True)
|
||||
]
|
||||
out = torch.as_tensor(rewards, dtype=torch.float32)
|
||||
if np.isfinite(self.config.success_threshold):
|
||||
out = (out > self.config.success_threshold).float()
|
||||
return out.to(self.config.device or "cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_curves(
|
||||
self,
|
||||
batch: dict[str, Any],
|
||||
*,
|
||||
num_prefixes: int | None = None,
|
||||
) -> dict[str, Tensor]:
|
||||
"""Per-sample dense progress curves over prefixes ``[0, t]``.
|
||||
|
||||
Mirrors upstream ``compute_instruction_rewards_for_prefixes``: for
|
||||
each sample we run one VLM forward per prefix length and read the
|
||||
log-prob reward at that prefix. Raw log-probs are then min-max
|
||||
normalised per-trajectory to ``[0, 1]``. Because trajectories
|
||||
within a batch can have different lengths, the returned
|
||||
``progress`` tensor is right-padded with ``NaN`` to the longest
|
||||
trajectory in the batch.
|
||||
|
||||
Args:
|
||||
batch: Same input as :meth:`compute_reward`.
|
||||
num_prefixes: How many evenly-spaced prefix lengths to score
|
||||
per trajectory. ``None`` (default) uses every prefix
|
||||
length ``[1, N]`` → fully dense, ``N`` VLM forwards per
|
||||
trajectory. Pass a smaller integer (e.g. ``15``, the
|
||||
upstream default) for sparse-dense scoring with linear
|
||||
interpolation between anchors.
|
||||
|
||||
Returns:
|
||||
Dict with one float32 CPU tensor:
|
||||
|
||||
- ``progress``: ``(B, T_max)`` — per-frame progress in
|
||||
``[0, 1]`` (min-max normalised log-prob curve), padded with
|
||||
``NaN``.
|
||||
"""
|
||||
if num_prefixes is not None and num_prefixes < 1:
|
||||
raise ValueError(f"num_prefixes must be >= 1 or None, got {num_prefixes}")
|
||||
|
||||
frames_per_sample, tasks = self._unpack_batch(batch)
|
||||
curves: list[np.ndarray] = []
|
||||
max_len = 0
|
||||
for frames, task in zip(frames_per_sample, tasks, strict=True):
|
||||
num_frames = int(frames.shape[0])
|
||||
if num_frames == 0:
|
||||
curves.append(np.zeros(0, dtype=np.float32))
|
||||
continue
|
||||
|
||||
if num_prefixes is None or num_prefixes >= num_frames:
|
||||
anchor_lengths = np.arange(1, num_frames + 1, dtype=np.int64)
|
||||
else:
|
||||
# Match upstream: linspace from 1 to N, dedupe (rounding
|
||||
# collisions for short trajectories), sort ascending.
|
||||
anchor_lengths = np.unique(np.linspace(1, num_frames, num_prefixes).round().astype(np.int64))
|
||||
|
||||
raw_rewards = [self._compute_log_prob_reward(frames[:length], task) for length in anchor_lengths]
|
||||
normalized_at_anchors = minmax_normalize_rewards(raw_rewards)
|
||||
|
||||
# Linear interpolation back to per-frame resolution when
|
||||
# `num_prefixes < num_frames`.
|
||||
if anchor_lengths.shape[0] == num_frames:
|
||||
per_frame = normalized_at_anchors
|
||||
else:
|
||||
per_frame = np.interp(
|
||||
np.arange(1, num_frames + 1, dtype=np.float64),
|
||||
anchor_lengths.astype(np.float64),
|
||||
normalized_at_anchors.astype(np.float64),
|
||||
).astype(np.float32)
|
||||
|
||||
curves.append(per_frame)
|
||||
max_len = max(max_len, num_frames)
|
||||
|
||||
padded = np.full((len(curves), max_len), np.nan, dtype=np.float32)
|
||||
for i, curve in enumerate(curves):
|
||||
padded[i, : curve.shape[0]] = curve
|
||||
return {"progress": torch.from_numpy(padded)}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Save / load — VLM weights are not stored in our checkpoint
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
"""Save ``config.json`` only.
|
||||
|
||||
TOPReward has no fine-tuned weights of its own — the VLM is
|
||||
identified by :attr:`TOPRewardConfig.vlm_name` and lives on the
|
||||
Hugging Face Hub under that id. Writing the VLM into a
|
||||
``model.safetensors`` here would just duplicate ~16 GB of Qwen
|
||||
weights under our org for no benefit.
|
||||
"""
|
||||
self.config._save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: RewardModelConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False, # accepted for API parity; unused
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""Load a TOPReward configuration and instantiate the wrapped VLM.
|
||||
|
||||
Two modes:
|
||||
|
||||
- Local directory containing ``config.json``: read the config and
|
||||
rebuild the model. The VLM is re-fetched from the Hub via
|
||||
:attr:`TOPRewardConfig.vlm_name`.
|
||||
- HF Hub repo id: download just ``config.json``, same as above.
|
||||
"""
|
||||
del strict # TOPReward has no weights of its own to (strictly) load.
|
||||
if config is None:
|
||||
config = RewardModelConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
if not isinstance(config, TOPRewardConfig):
|
||||
raise TypeError(
|
||||
f"Expected a TOPRewardConfig, got {type(config).__name__}. Make sure "
|
||||
f"`pretrained_name_or_path={pretrained_name_or_path!r}` points at a "
|
||||
"TOPReward checkpoint."
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
if not os.path.isdir(model_id):
|
||||
# Validate that the remote repo at least contains a TOPReward config.json
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
instance = cls(config, **kwargs)
|
||||
instance.to(config.device)
|
||||
instance.eval()
|
||||
return instance
|
||||
|
||||
def push_model_to_hub(self, cfg: TrainPipelineConfig):
|
||||
"""Push the TOPReward ``config.json`` + model card to the Hub.
|
||||
|
||||
Skips the safetensors upload — the wrapped VLM is identified by
|
||||
``vlm_name`` and we never modify it.
|
||||
"""
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
saved_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.config._save_pretrained(saved_path)
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
cfg.save_pretrained(saved_path)
|
||||
|
||||
commit_info = api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
folder_path=saved_path,
|
||||
commit_message="Upload TOPReward config and readme",
|
||||
allow_patterns=["*.json", "*.yaml", "*.md"],
|
||||
ignore_patterns=["*.tmp", "*.log", "*.safetensors"],
|
||||
)
|
||||
|
||||
logger.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
|
||||
def _unpack_batch(self, batch: dict[str, Any]) -> tuple[list[np.ndarray], list[str]]:
|
||||
frames_key = f"{TOPREWARD_FEATURE_PREFIX}frames"
|
||||
task_key = f"{TOPREWARD_FEATURE_PREFIX}task"
|
||||
if frames_key not in batch or task_key not in batch:
|
||||
raise KeyError(
|
||||
"TOPReward batch missing pre-encoded inputs (expected "
|
||||
f"`{frames_key}` and `{task_key}`). Make sure the "
|
||||
"TOPRewardEncoderProcessorStep ran before `compute_reward`."
|
||||
)
|
||||
frames_per_sample = list(batch[frames_key])
|
||||
tasks = list(batch[task_key])
|
||||
if len(frames_per_sample) != len(tasks):
|
||||
raise ValueError(
|
||||
f"frames batch size ({len(frames_per_sample)}) does not match task batch size ({len(tasks)})"
|
||||
)
|
||||
return frames_per_sample, tasks
|
||||
|
||||
@torch.no_grad()
|
||||
def _compute_log_prob_reward(self, frames: np.ndarray, instruction: str) -> float:
|
||||
"""Compute the log-likelihood of the final answer token given the prompt.
|
||||
|
||||
Port of ``QwenClient.compute_instruction_reward`` (the upstream
|
||||
TOPReward implementation), stripped of the
|
||||
:class:`InstructionRewardResult` metadata wrapper we don't need.
|
||||
Returns ``log P(final_token | video + prompt + instruction)`` — by
|
||||
default the final token is the literal ``"True"`` that closes the
|
||||
suffix template, which is the binary "is the instruction satisfied"
|
||||
signal the paper describes.
|
||||
"""
|
||||
device = next(self.model.parameters()).device
|
||||
pil_frames = _frames_to_pil(frames)
|
||||
|
||||
if self.config.use_video_description:
|
||||
description = self._generate_object_state_reasoning(pil_frames)
|
||||
prompt_text = (
|
||||
f"{description} Therefore given the above description and the "
|
||||
"video, the video shows a robot manipulation trajectory that "
|
||||
"**completes** the following instruction: "
|
||||
)
|
||||
else:
|
||||
prompt_text = self.config.prompt_prefix
|
||||
|
||||
eos_token = self.processor.tokenizer.eos_token
|
||||
instruction_suffix = self.config.prompt_suffix_template.format(instruction=instruction)
|
||||
|
||||
# Two prompt assembly modes match the upstream:
|
||||
#
|
||||
# - ``add_chat_template=True``: wrap the FULL prompt (including
|
||||
# instruction) with the chat template, then append the literal
|
||||
# ``"True"`` token outside the template.
|
||||
# - ``add_chat_template=False``: apply the chat template to the
|
||||
# video+prefix only (no generation prompt), strip the trailing
|
||||
# EOS, then concatenate the literal instruction suffix.
|
||||
if self.config.add_chat_template:
|
||||
# Suffix excluding the trailing "True" — we want "True" to be
|
||||
# the scored token, not part of the template's user turn.
|
||||
suffix_for_template = instruction_suffix.removesuffix(_TRUE_ANSWER).rstrip()
|
||||
templated_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": pil_frames, "fps": self.config.fps},
|
||||
{"type": "text", "text": f"{prompt_text}{suffix_for_template}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self.processor.apply_chat_template(
|
||||
templated_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
full_text = f"{prompt_chat}{_TRUE_ANSWER}"
|
||||
image_inputs, video_inputs = self._process_vision_info(templated_messages)
|
||||
else:
|
||||
user_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": pil_frames, "fps": self.config.fps},
|
||||
{"type": "text", "text": prompt_text},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self.processor.apply_chat_template(
|
||||
user_messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
if eos_token is not None:
|
||||
prompt_chat = prompt_chat.split(eos_token)[0]
|
||||
full_text = f"{prompt_chat}{instruction_suffix}"
|
||||
image_inputs, video_inputs = self._process_vision_info(user_messages)
|
||||
|
||||
inputs = self.processor(
|
||||
text=[full_text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
input_len = int(inputs["input_ids"].shape[-1])
|
||||
if input_len > self.config.max_input_length:
|
||||
raise ValueError(
|
||||
f"TOPReward input length {input_len} exceeds max_input_length "
|
||||
f"{self.config.max_input_length}; lower `max_frames` or raise `max_input_length`."
|
||||
)
|
||||
|
||||
labels = inputs["input_ids"].clone()
|
||||
# Mask everything except the very last token. ``prompt_length = input_len - 1``
|
||||
# mirrors upstream ``QwenClient.compute_instruction_reward``; after the
|
||||
# causal-LM next-token shift below this isolates exactly one position —
|
||||
# the prediction of the literal ``"True"`` that closes ``prompt_suffix_template``.
|
||||
# The resulting reward is therefore ``log P("True" | video + prompt + instruction)``.
|
||||
prompt_length = input_len - 1
|
||||
labels[:, :prompt_length] = -100
|
||||
if "attention_mask" in inputs:
|
||||
labels = labels.masked_fill(inputs["attention_mask"] == 0, -100)
|
||||
|
||||
self.model.eval()
|
||||
outputs = self.model(**inputs, labels=labels)
|
||||
|
||||
logits = outputs.logits[:, :-1, :]
|
||||
target_labels = labels[:, 1:]
|
||||
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
mask = target_labels != -100
|
||||
safe_targets = target_labels.masked_fill(~mask, 0)
|
||||
token_log_probs = log_probs.gather(-1, safe_targets.unsqueeze(-1)).squeeze(-1)
|
||||
masked_log_probs = token_log_probs[mask]
|
||||
if masked_log_probs.numel() == 0:
|
||||
raise RuntimeError(
|
||||
"TOPReward could not isolate any suffix tokens to score. Check that "
|
||||
"`prompt_suffix_template` produces at least one tokenised character."
|
||||
)
|
||||
|
||||
# ``mean`` vs ``sum`` are equivalent for a single scored token but the
|
||||
# knob is kept for API parity with upstream (and for forward-compat with
|
||||
# any future variant that scores more than the final answer token).
|
||||
if self.config.reduction == "sum":
|
||||
reward = masked_log_probs.sum().item()
|
||||
else: # mean
|
||||
reward = masked_log_probs.mean().item()
|
||||
return float(reward)
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_object_state_reasoning(self, pil_frames: list[Image.Image]) -> str:
|
||||
"""Instruction-agnostic trajectory description (upstream
|
||||
``QwenClient.generate_object_state_reasoning``). Used when
|
||||
:attr:`TOPRewardConfig.use_video_description` is ``True``.
|
||||
"""
|
||||
device = next(self.model.parameters()).device
|
||||
user_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": pil_frames, "fps": self.config.fps},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe the robot manipulation trajectory in this video:",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self.processor.apply_chat_template(
|
||||
user_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = self._process_vision_info(user_messages)
|
||||
inputs = self.processor(
|
||||
text=[prompt_chat],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
self.model.eval()
|
||||
output_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=256,
|
||||
do_sample=False,
|
||||
)
|
||||
response = self.processor.batch_decode(
|
||||
output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
prompt_decoded = self.processor.batch_decode(
|
||||
inputs["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
if response.startswith(prompt_decoded):
|
||||
return response[len(prompt_decoded) :].strip()
|
||||
return response.strip()
|
||||
|
||||
@staticmethod
|
||||
def _process_vision_info(messages: list[dict[str, Any]]) -> tuple[Any, Any]:
|
||||
"""Thin wrapper around ``qwen_vl_utils.process_vision_info``.
|
||||
|
||||
Kept as a method so tests can monkey-patch it without depending on
|
||||
the import-time presence of ``qwen_vl_utils``.
|
||||
"""
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
return cast(tuple[Any, Any], process_vision_info(messages))
|
||||
@@ -0,0 +1,200 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TOPReward pre/post processing pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
policy_action_to_transition,
|
||||
)
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
OBS_PREFIX,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
# Namespace for TOPReward's pre-encoded observation tensors written by the
|
||||
# processor and consumed by the model. Keys: ``frames`` (one ``(T,H,W,C)``
|
||||
# uint8 numpy array per sample) and ``task`` (one string per sample).
|
||||
TOPREWARD_FEATURE_PREFIX = f"{OBS_PREFIX}topreward."
|
||||
|
||||
|
||||
def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray:
|
||||
"""Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array.
|
||||
|
||||
Mirrors the Robometer helper: accepts ``(T, C, H, W)`` or ``(T, H, W, C)``
|
||||
layouts, rescales floats in ``[0, 1]`` to ``[0, 255]``, clips values
|
||||
outside the uint8 range and tail-crops to ``max_frames``.
|
||||
"""
|
||||
if max_frames is not None:
|
||||
video = video[-max_frames:]
|
||||
if video.shape[1] in (1, 3):
|
||||
video = video.permute(0, 2, 3, 1)
|
||||
elif video.shape[-1] not in (1, 3):
|
||||
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
|
||||
|
||||
array = video.detach().cpu().numpy()
|
||||
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
|
||||
array = array * 255.0
|
||||
return np.clip(array, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
|
||||
if task is None:
|
||||
task = default
|
||||
if task is None:
|
||||
raise KeyError("TOPReward expected a task description in complementary data")
|
||||
if isinstance(task, str):
|
||||
return [task] * batch_size
|
||||
if isinstance(task, tuple):
|
||||
task = list(task)
|
||||
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
|
||||
raise TypeError(f"TOPReward task must be a string or list of strings, got {type(task)}")
|
||||
if len(task) == 1 and batch_size > 1:
|
||||
return task * batch_size
|
||||
if len(task) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
|
||||
return task
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="topreward_encoder")
|
||||
class TOPRewardEncoderProcessorStep(ProcessorStep):
|
||||
"""Normalise raw frames + task into TOPReward-namespaced observation entries.
|
||||
|
||||
At call time the step reads:
|
||||
|
||||
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
|
||||
- ``complementary_data[task_key]``: a string or list of strings.
|
||||
|
||||
and writes:
|
||||
|
||||
- ``observation[f"{TOPREWARD_FEATURE_PREFIX}frames"]``: list of
|
||||
``(T, H, W, C) uint8`` numpy arrays, one per sample.
|
||||
- ``observation[f"{TOPREWARD_FEATURE_PREFIX}task"]``: list of strings,
|
||||
one per sample.
|
||||
|
||||
The actual chat-template / tokenisation happens model-side because
|
||||
TOPReward's reward extraction needs the tokenizer to know the
|
||||
prompt/suffix split (label masking on suffix tokens only).
|
||||
"""
|
||||
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 16
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("TOPRewardEncoderProcessorStep requires an observation dict")
|
||||
|
||||
if self.image_key not in observation:
|
||||
raise KeyError(f"TOPReward expected image key {self.image_key!r} in observation")
|
||||
|
||||
frames = observation[self.image_key]
|
||||
tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
|
||||
if tensor.ndim == 4:
|
||||
tensor = tensor.unsqueeze(1)
|
||||
elif tensor.ndim != 5:
|
||||
raise ValueError(
|
||||
f"Expected TOPReward frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}"
|
||||
)
|
||||
|
||||
batch_size = tensor.shape[0]
|
||||
tasks = _expand_tasks(
|
||||
complementary.get(self.task_key, self.default_task),
|
||||
batch_size=batch_size,
|
||||
default=self.default_task,
|
||||
)
|
||||
|
||||
frames_per_sample = [
|
||||
_video_to_numpy(tensor[i], max_frames=self.max_frames) for i in range(batch_size)
|
||||
]
|
||||
|
||||
new_observation = dict(observation)
|
||||
new_observation[f"{TOPREWARD_FEATURE_PREFIX}frames"] = frames_per_sample
|
||||
new_observation[f"{TOPREWARD_FEATURE_PREFIX}task"] = list(tasks)
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"image_key": self.image_key,
|
||||
"task_key": self.task_key,
|
||||
"default_task": self.default_task,
|
||||
"max_frames": self.max_frames,
|
||||
}
|
||||
|
||||
|
||||
def make_topreward_pre_post_processors(
|
||||
config: TOPRewardConfig,
|
||||
dataset_stats: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Pipeline that normalises frames + task for the TOPReward model.
|
||||
|
||||
The preprocessor adds a batch dimension if needed, runs TOPReward's
|
||||
encoder, and moves any tensor entries to the configured device. The
|
||||
postprocessor is the identity since TOPReward outputs a single reward
|
||||
tensor.
|
||||
"""
|
||||
del dataset_stats # TOPReward's VLM handles its own normalisation.
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
TOPRewardEncoderProcessorStep(
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=config.max_frames,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -13,6 +13,8 @@
|
||||
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
|
||||
{% elif model_name == "sarm" %}
|
||||
A Success-Aware Reward Model (SARM) predicts a dense reward signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable.
|
||||
{% elif model_name == "topreward" %}
|
||||
TOPReward is a **zero-shot** reward model that extracts token log-probabilities from an off-the-shelf vision-language model (default Qwen3-VL) as a reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood of the instruction being true, with no fine-tuning required.
|
||||
{% else %}
|
||||
_Reward model type not recognized — please update this template._
|
||||
{% endif %}
|
||||
|
||||
@@ -0,0 +1,421 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the TOPReward reward model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.rewards.factory import get_reward_model_class, make_reward_model_config
|
||||
from lerobot.rewards.topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.modeling_topreward import minmax_normalize_rewards
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
class _FakeTokenizer:
|
||||
"""Minimal tokenizer surface used by ``TOPRewardModel._compute_log_prob_reward``."""
|
||||
|
||||
eos_token = "<|endoftext|>"
|
||||
|
||||
|
||||
class _FakeProcessor:
|
||||
"""Stand-in for the Qwen ``AutoProcessor`` returned by ``from_pretrained``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = _FakeTokenizer()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
|
||||
return cls()
|
||||
|
||||
|
||||
class _FakeQwenModel(torch.nn.Module):
|
||||
"""Stand-in for ``Qwen3VLForConditionalGeneration``.
|
||||
|
||||
Provides the minimum surface ``TOPRewardModel`` touches at construction
|
||||
time (a ``parameters()`` iterator for device inference). Actual
|
||||
``_compute_log_prob_reward`` calls are bypassed by monkey-patching the
|
||||
method directly in the tests, so we never invoke ``self.model(...)``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._param = torch.nn.Parameter(torch.zeros(1))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
|
||||
return cls()
|
||||
|
||||
|
||||
def _patch_build(monkeypatch) -> None:
|
||||
"""Stub out HF AutoX so TOPReward construction is cheap and offline."""
|
||||
from lerobot.rewards.topreward import modeling_topreward
|
||||
|
||||
monkeypatch.setattr(modeling_topreward, "Qwen3VLForConditionalGeneration", _FakeQwenModel)
|
||||
monkeypatch.setattr(modeling_topreward, "AutoProcessor", _FakeProcessor)
|
||||
|
||||
|
||||
def _make_batch(frames: list[np.ndarray], tasks: list[str]) -> dict[str, list]:
|
||||
"""Build a ``compute_reward``-ready batch using TOPReward's namespaced keys."""
|
||||
return {
|
||||
f"{TOPREWARD_FEATURE_PREFIX}frames": frames,
|
||||
f"{TOPREWARD_FEATURE_PREFIX}task": tasks,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry + factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_topreward_config_registered():
|
||||
assert "topreward" in RewardModelConfig.get_known_choices()
|
||||
assert RewardModelConfig.get_choice_class("topreward") is TOPRewardConfig
|
||||
assert isinstance(make_reward_model_config("topreward", device="cpu"), TOPRewardConfig)
|
||||
|
||||
|
||||
def test_topreward_factory_returns_in_tree_class():
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
assert get_reward_model_class("topreward") is TOPRewardModel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_topreward_config_rejects_bad_reduction():
|
||||
with pytest.raises(ValueError, match="reduction must be"):
|
||||
TOPRewardConfig(device="cpu", reduction="median")
|
||||
|
||||
|
||||
def test_topreward_config_rejects_zero_max_frames():
|
||||
with pytest.raises(ValueError, match="max_frames must be >= 1"):
|
||||
TOPRewardConfig(device="cpu", max_frames=0)
|
||||
|
||||
|
||||
def test_topreward_config_rejects_non_positive_fps():
|
||||
with pytest.raises(ValueError, match="fps must be > 0"):
|
||||
TOPRewardConfig(device="cpu", fps=0.0)
|
||||
|
||||
|
||||
def test_topreward_config_rejects_suffix_without_instruction_placeholder():
|
||||
with pytest.raises(ValueError, match=r"\{instruction\}"):
|
||||
TOPRewardConfig(device="cpu", prompt_suffix_template="no placeholder here")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# minmax_normalize_rewards — pure math helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_minmax_normalize_rewards_maps_min_and_max_to_zero_and_one():
|
||||
values = minmax_normalize_rewards([-3.0, -1.0, 0.0, -2.0])
|
||||
assert values.shape == (4,)
|
||||
assert values[0] == pytest.approx(0.0)
|
||||
assert values[2] == pytest.approx(1.0)
|
||||
# Monotonicity preserved within the input range.
|
||||
assert values[3] == pytest.approx(1.0 / 3.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_minmax_normalize_rewards_handles_singleton_and_flat_inputs():
|
||||
# Single element -> mapped to 1.0 (no information to scale).
|
||||
assert minmax_normalize_rewards([42.0]).tolist() == [1.0]
|
||||
# All-equal values -> all ones (avoid divide-by-zero).
|
||||
assert minmax_normalize_rewards([0.5, 0.5, 0.5]).tolist() == [1.0, 1.0, 1.0]
|
||||
|
||||
|
||||
def test_minmax_normalize_rewards_empty_input_returns_empty_array():
|
||||
out = minmax_normalize_rewards([])
|
||||
assert out.shape == (0,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_reward
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_returns_one_scalar_per_sample(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
captured = []
|
||||
|
||||
def fake_log_prob(self, frames, instruction): # noqa: ARG002
|
||||
captured.append((frames.shape, instruction))
|
||||
return -1.5
|
||||
|
||||
monkeypatch.setattr(TOPRewardModel, "_compute_log_prob_reward", fake_log_prob)
|
||||
|
||||
frames_a = np.zeros((4, 8, 8, 3), dtype=np.uint8)
|
||||
frames_b = np.zeros((6, 8, 8, 3), dtype=np.uint8)
|
||||
batch = _make_batch([frames_a, frames_b], ["pick the cube", "open the drawer"])
|
||||
|
||||
rewards = model.compute_reward(batch)
|
||||
|
||||
assert rewards.shape == (2,)
|
||||
assert rewards.dtype == torch.float32
|
||||
assert torch.allclose(rewards, torch.tensor([-1.5, -1.5]))
|
||||
# `_compute_log_prob_reward` was called once per sample with the right tasks.
|
||||
assert [task for _, task in captured] == ["pick the cube", "open the drawer"]
|
||||
assert [shape[0] for shape, _ in captured] == [4, 6]
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_applies_success_threshold(monkeypatch):
|
||||
"""When ``success_threshold`` is finite, the model returns binary success
|
||||
instead of the raw log-prob — useful as a drop-in success detector."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu", success_threshold=-2.0)
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
rewards_in = iter([-1.5, -3.0]) # first above threshold, second below
|
||||
monkeypatch.setattr(
|
||||
TOPRewardModel,
|
||||
"_compute_log_prob_reward",
|
||||
lambda _self, _frames, _instr: next(rewards_in),
|
||||
)
|
||||
|
||||
frames = [np.zeros((2, 8, 8, 3), dtype=np.uint8), np.zeros((2, 8, 8, 3), dtype=np.uint8)]
|
||||
rewards = model.compute_reward(_make_batch(frames, ["task", "task"]))
|
||||
|
||||
assert torch.equal(rewards, torch.tensor([1.0, 0.0]))
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_errors_when_inputs_missing(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
with pytest.raises(KeyError, match=r"observation\.topreward\."):
|
||||
model.compute_reward({})
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_errors_when_batch_sizes_mismatch(monkeypatch):
|
||||
"""frames and task lists must have matching lengths — a stale processor
|
||||
that produces only one task for a multi-sample batch should surface as
|
||||
an explicit error, not a silent zip truncation."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
monkeypatch.setattr(
|
||||
TOPRewardModel,
|
||||
"_compute_log_prob_reward",
|
||||
lambda _self, _frames, _instr: 0.0,
|
||||
)
|
||||
|
||||
frames = [np.zeros((2, 8, 8, 3), dtype=np.uint8), np.zeros((2, 8, 8, 3), dtype=np.uint8)]
|
||||
with pytest.raises(ValueError, match="task batch size"):
|
||||
model.compute_reward(_make_batch(frames, ["only one task"]))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# predict_curves
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_predict_curves_runs_one_forward_per_prefix(monkeypatch):
|
||||
"""``predict_curves`` must call the VLM once per prefix length per
|
||||
trajectory and write min-max-normalised values back into the curve."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
# Simulate a strictly increasing log-prob curve as the prefix grows.
|
||||
call_log: list[int] = []
|
||||
|
||||
def fake_log_prob(self, frames, instruction): # noqa: ARG002
|
||||
call_log.append(int(frames.shape[0]))
|
||||
return float(frames.shape[0]) # log-prob = prefix length
|
||||
|
||||
monkeypatch.setattr(TOPRewardModel, "_compute_log_prob_reward", fake_log_prob)
|
||||
|
||||
frames = np.zeros((5, 8, 8, 3), dtype=np.uint8)
|
||||
batch = _make_batch([frames], ["lift the cup"])
|
||||
out = model.predict_curves(batch)
|
||||
|
||||
# One forward per prefix length, in order.
|
||||
assert call_log == [1, 2, 3, 4, 5]
|
||||
# (B, T_max) shape, padded with NaN beyond each trajectory's length.
|
||||
assert out["progress"].shape == (1, 5)
|
||||
# Strictly increasing raw rewards -> min-max-normalised to [0, 1] linearly.
|
||||
expected = torch.tensor([[0.0, 0.25, 0.5, 0.75, 1.0]])
|
||||
assert torch.allclose(out["progress"], expected, atol=1e-6)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_predict_curves_sparse_dense_interpolates_to_full_resolution(monkeypatch):
|
||||
"""With ``num_prefixes < N`` the model should score only the requested
|
||||
number of anchor prefixes and linearly interpolate between them — the
|
||||
upstream sparse-dense pattern (``num_samples=15``)."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
call_log: list[int] = []
|
||||
|
||||
def fake_log_prob(self, frames, instruction): # noqa: ARG002
|
||||
call_log.append(int(frames.shape[0]))
|
||||
return float(frames.shape[0])
|
||||
|
||||
monkeypatch.setattr(TOPRewardModel, "_compute_log_prob_reward", fake_log_prob)
|
||||
|
||||
frames = np.zeros((9, 8, 8, 3), dtype=np.uint8)
|
||||
out = model.predict_curves(_make_batch([frames], ["lift the cup"]), num_prefixes=3)
|
||||
|
||||
# 3 anchors at linspace(1, 9, 3) -> [1, 5, 9] -> 3 VLM forwards instead of 9.
|
||||
assert call_log == [1, 5, 9]
|
||||
# Returned curve is full resolution (9 frames) and monotone in [0, 1].
|
||||
assert out["progress"].shape == (1, 9)
|
||||
curve = out["progress"][0].numpy()
|
||||
assert curve[0] == pytest.approx(0.0)
|
||||
assert curve[-1] == pytest.approx(1.0)
|
||||
assert np.all(np.diff(curve) >= 0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_predict_curves_rejects_invalid_num_prefixes(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
model = TOPRewardModel(TOPRewardConfig(device="cpu"))
|
||||
batch = _make_batch([np.zeros((3, 8, 8, 3), dtype=np.uint8)], ["task"])
|
||||
with pytest.raises(ValueError, match="num_prefixes must be"):
|
||||
model.predict_curves(batch, num_prefixes=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_predict_curves_right_pads_with_nan_for_variable_lengths(monkeypatch):
|
||||
"""Trajectories of different lengths in the same batch are right-padded
|
||||
with ``NaN`` so the output is a regular ``(B, T_max)`` tensor."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
monkeypatch.setattr(
|
||||
TOPRewardModel,
|
||||
"_compute_log_prob_reward",
|
||||
lambda _self, frames, _instr: float(frames.shape[0]),
|
||||
)
|
||||
|
||||
frames_short = np.zeros((2, 8, 8, 3), dtype=np.uint8)
|
||||
frames_long = np.zeros((4, 8, 8, 3), dtype=np.uint8)
|
||||
out = model.predict_curves(_make_batch([frames_short, frames_long], ["a", "b"]))
|
||||
|
||||
assert out["progress"].shape == (2, 4)
|
||||
# Trailing entries for the shorter trajectory are NaN.
|
||||
assert torch.isnan(out["progress"][0, 2:]).all()
|
||||
# The longer trajectory has no NaNs.
|
||||
assert not torch.isnan(out["progress"][1]).any()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Save / load — config-only checkpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_save_pretrained_writes_only_config_json(monkeypatch, tmp_path):
|
||||
"""A TOPReward "checkpoint" is just ``config.json``. Writing
|
||||
``model.safetensors`` would only duplicate ~16 GB of Qwen weights for
|
||||
no benefit, so :meth:`_save_pretrained` must skip it entirely.
|
||||
"""
|
||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
||||
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(
|
||||
device="cpu",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
reduction="sum",
|
||||
fps=4.0,
|
||||
image_key="observation.images.front",
|
||||
)
|
||||
model = TOPRewardModel(cfg)
|
||||
model.save_pretrained(str(tmp_path))
|
||||
|
||||
assert (tmp_path / CONFIG_NAME).exists()
|
||||
# Zero-shot model: no safetensors written by `_save_pretrained`.
|
||||
assert not (tmp_path / SAFETENSORS_SINGLE_FILE).exists()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_from_pretrained_local_dir_roundtrips_config(monkeypatch, tmp_path):
|
||||
"""Save a TOPRewardConfig locally and reload it — user knobs must survive."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(
|
||||
device="cpu",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
reduction="sum",
|
||||
fps=4.0,
|
||||
image_key="observation.images.front",
|
||||
use_video_description=True,
|
||||
add_chat_template=True,
|
||||
success_threshold=-1.5,
|
||||
)
|
||||
TOPRewardModel(cfg).save_pretrained(str(tmp_path))
|
||||
|
||||
reloaded = TOPRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
assert isinstance(reloaded.config, TOPRewardConfig)
|
||||
assert reloaded.config.vlm_name == "Qwen/Qwen3-VL-8B-Instruct"
|
||||
assert reloaded.config.reduction == "sum"
|
||||
assert reloaded.config.fps == 4.0
|
||||
assert reloaded.config.image_key == "observation.images.front"
|
||||
assert reloaded.config.use_video_description is True
|
||||
assert reloaded.config.add_chat_template is True
|
||||
assert reloaded.config.success_threshold == -1.5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_is_not_trainable(monkeypatch):
|
||||
"""The whole point of TOPReward is that it is zero-shot.
|
||||
``is_trainable`` must therefore be ``False`` and ``forward(...)`` must
|
||||
raise the base-class ``NotImplementedError``."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
assert model.is_trainable is False
|
||||
with pytest.raises(NotImplementedError, match="not trainable"):
|
||||
model.forward({"x": torch.zeros(1)})
|
||||
@@ -0,0 +1,253 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for TOPReward's pre-processing helpers and encoder step."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.rewards.topreward.processor_topreward import (
|
||||
TOPREWARD_FEATURE_PREFIX,
|
||||
TOPRewardEncoderProcessorStep,
|
||||
_expand_tasks,
|
||||
_video_to_numpy,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _video_to_numpy — pure (T, C, H, W) -> (T, H, W, C) uint8 conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_video_to_numpy_chw_float_is_converted_to_thwc_uint8():
|
||||
video = torch.rand(4, 3, 8, 8) # (T, C, H, W) floats in [0, 1]
|
||||
array = _video_to_numpy(video, max_frames=None)
|
||||
|
||||
assert array.shape == (4, 8, 8, 3)
|
||||
assert array.dtype == np.uint8
|
||||
assert array.min() >= 0 and array.max() <= 255
|
||||
|
||||
|
||||
def test_video_to_numpy_already_thwc_uint8_passes_through():
|
||||
video = torch.randint(0, 256, (3, 8, 8, 3), dtype=torch.uint8)
|
||||
array = _video_to_numpy(video, max_frames=None)
|
||||
|
||||
assert array.shape == (3, 8, 8, 3)
|
||||
assert array.dtype == np.uint8
|
||||
|
||||
|
||||
def test_video_to_numpy_max_frames_tail_crops_recent_frames():
|
||||
"""``max_frames`` should keep the **last** K frames (most recent)."""
|
||||
video = torch.zeros(10, 3, 4, 4)
|
||||
for t in range(10):
|
||||
video[t] = t / 9.0
|
||||
|
||||
array = _video_to_numpy(video, max_frames=3)
|
||||
|
||||
assert array.shape == (3, 4, 4, 3)
|
||||
assert int(array[0, 0, 0, 0]) == int(round(7 / 9 * 255))
|
||||
assert int(array[-1, 0, 0, 0]) == 255
|
||||
|
||||
|
||||
def test_video_to_numpy_rejects_3d_input():
|
||||
with pytest.raises(ValueError, match="Expected channel dim"):
|
||||
_video_to_numpy(torch.zeros(4, 8, 8), max_frames=None)
|
||||
|
||||
|
||||
def test_video_to_numpy_floats_above_one_pass_through_without_rescaling():
|
||||
"""If ``array.max() > 1`` the helper assumes the tensor is already in the
|
||||
uint8 range; values pass through unchanged (but are still clipped to 255)."""
|
||||
video = torch.full((1, 3, 2, 2), 5.0)
|
||||
array = _video_to_numpy(video, max_frames=None)
|
||||
|
||||
assert array.shape == (1, 2, 2, 3)
|
||||
assert int(array.max()) == 5
|
||||
|
||||
|
||||
def test_video_to_numpy_clips_very_large_floats_to_uint8_max():
|
||||
video = torch.full((1, 3, 2, 2), 300.0)
|
||||
array = _video_to_numpy(video, max_frames=None)
|
||||
|
||||
assert int(array.max()) == 255
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _expand_tasks — string / list / tuple broadcasting to batch size
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_expand_tasks_string_is_broadcast_to_batch_size():
|
||||
assert _expand_tasks("pick up", batch_size=3, default=None) == ["pick up", "pick up", "pick up"]
|
||||
|
||||
|
||||
def test_expand_tasks_list_of_matching_size_passes_through():
|
||||
assert _expand_tasks(["a", "b", "c"], batch_size=3, default=None) == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_expand_tasks_tuple_is_normalised_to_list():
|
||||
assert _expand_tasks(("a", "b"), batch_size=2, default=None) == ["a", "b"]
|
||||
|
||||
|
||||
def test_expand_tasks_single_element_list_is_broadcast():
|
||||
assert _expand_tasks(["only one"], batch_size=3, default=None) == ["only one"] * 3
|
||||
|
||||
|
||||
def test_expand_tasks_size_mismatch_raises():
|
||||
with pytest.raises(ValueError, match="Expected 3 tasks"):
|
||||
_expand_tasks(["a", "b"], batch_size=3, default=None)
|
||||
|
||||
|
||||
def test_expand_tasks_missing_uses_default():
|
||||
assert _expand_tasks(None, batch_size=2, default="fallback") == ["fallback", "fallback"]
|
||||
|
||||
|
||||
def test_expand_tasks_missing_without_default_raises():
|
||||
with pytest.raises(KeyError, match="task description"):
|
||||
_expand_tasks(None, batch_size=1, default=None)
|
||||
|
||||
|
||||
def test_expand_tasks_wrong_type_raises():
|
||||
with pytest.raises(TypeError, match="must be a string or list"):
|
||||
_expand_tasks(42, batch_size=1, default=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder step — input/output shapes + dataclass surface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_transition(observation: dict, complementary: dict | None = None) -> dict:
|
||||
"""Build a tiny ``EnvTransition`` dict for the encoder step."""
|
||||
transition: dict = {TransitionKey.OBSERVATION: observation}
|
||||
if complementary is not None:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary
|
||||
return transition
|
||||
|
||||
|
||||
def test_encoder_step_writes_namespaced_frames_and_task():
|
||||
"""The encoder step's output is the contract the model reads from. It
|
||||
must populate exactly two namespaced keys: ``frames`` and ``task``."""
|
||||
step = TOPRewardEncoderProcessorStep(
|
||||
image_key="observation.images.top",
|
||||
task_key="task",
|
||||
max_frames=None,
|
||||
)
|
||||
|
||||
frames_batch = torch.zeros(2, 4, 3, 8, 8) # (B=2, T=4, C, H, W)
|
||||
out = step(
|
||||
_make_transition(
|
||||
observation={"observation.images.top": frames_batch},
|
||||
complementary={"task": ["pick", "place"]},
|
||||
)
|
||||
)
|
||||
|
||||
obs_out = out[TransitionKey.OBSERVATION]
|
||||
frames_out = obs_out[f"{TOPREWARD_FEATURE_PREFIX}frames"]
|
||||
tasks_out = obs_out[f"{TOPREWARD_FEATURE_PREFIX}task"]
|
||||
|
||||
assert len(frames_out) == 2
|
||||
assert all(arr.shape == (4, 8, 8, 3) and arr.dtype == np.uint8 for arr in frames_out)
|
||||
assert tasks_out == ["pick", "place"]
|
||||
|
||||
|
||||
def test_encoder_step_adds_singleton_time_dim_for_4d_input():
|
||||
"""A ``(B, C, H, W)`` observation is the single-frame case; the encoder
|
||||
must unsqueeze the time dim so the model still sees a video."""
|
||||
step = TOPRewardEncoderProcessorStep(image_key="observation.images.top", max_frames=None)
|
||||
|
||||
frames_batch = torch.zeros(1, 3, 8, 8) # (B=1, C, H, W) — no time dim
|
||||
out = step(
|
||||
_make_transition(
|
||||
observation={"observation.images.top": frames_batch},
|
||||
complementary={"task": "pick"},
|
||||
)
|
||||
)
|
||||
|
||||
frames_out = out[TransitionKey.OBSERVATION][f"{TOPREWARD_FEATURE_PREFIX}frames"]
|
||||
assert len(frames_out) == 1
|
||||
assert frames_out[0].shape == (1, 8, 8, 3) # (T=1, H, W, C)
|
||||
|
||||
|
||||
def test_encoder_step_uses_default_task_when_complementary_is_missing():
|
||||
step = TOPRewardEncoderProcessorStep(
|
||||
image_key="observation.images.top",
|
||||
default_task="perform the task",
|
||||
)
|
||||
|
||||
frames_batch = torch.zeros(1, 2, 3, 4, 4)
|
||||
out = step(_make_transition(observation={"observation.images.top": frames_batch}))
|
||||
|
||||
tasks_out = out[TransitionKey.OBSERVATION][f"{TOPREWARD_FEATURE_PREFIX}task"]
|
||||
assert tasks_out == ["perform the task"]
|
||||
|
||||
|
||||
def test_encoder_step_rejects_missing_image_key():
|
||||
step = TOPRewardEncoderProcessorStep(image_key="observation.images.top")
|
||||
with pytest.raises(KeyError, match="image key"):
|
||||
step(_make_transition(observation={}, complementary={"task": "pick"}))
|
||||
|
||||
|
||||
def test_encoder_step_rejects_non_dict_observation():
|
||||
step = TOPRewardEncoderProcessorStep()
|
||||
with pytest.raises(ValueError, match="observation dict"):
|
||||
step({TransitionKey.OBSERVATION: torch.zeros(1, 3, 8, 8)})
|
||||
|
||||
|
||||
def test_encoder_step_rejects_3d_or_6d_input():
|
||||
"""The encoder accepts ``(B,C,H,W)`` or ``(B,T,C,H,W)`` only."""
|
||||
step = TOPRewardEncoderProcessorStep(image_key="observation.images.top")
|
||||
with pytest.raises(ValueError, match=r"\(B,C,H,W\)"):
|
||||
step(
|
||||
_make_transition(
|
||||
observation={"observation.images.top": torch.zeros(8, 8, 3)},
|
||||
complementary={"task": "pick"},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_encoder_step_get_config_roundtrips_user_fields():
|
||||
"""``get_config`` must serialise every user-tunable field — these are
|
||||
what the processor pipeline saves under ``preprocessor_config.json``."""
|
||||
step = TOPRewardEncoderProcessorStep(
|
||||
image_key="observation.images.cam_top",
|
||||
task_key="task",
|
||||
default_task="do the thing",
|
||||
max_frames=8,
|
||||
)
|
||||
|
||||
assert step.get_config() == {
|
||||
"image_key": "observation.images.cam_top",
|
||||
"task_key": "task",
|
||||
"default_task": "do the thing",
|
||||
"max_frames": 8,
|
||||
}
|
||||
|
||||
|
||||
def test_encoder_step_transform_features_is_identity():
|
||||
"""The encoder writes plain Python objects (numpy arrays / strings)
|
||||
into ``observation`` at call time but does NOT advertise new typed
|
||||
features at pipeline-build time — the model reads them via the
|
||||
``TOPREWARD_FEATURE_PREFIX`` namespace, not via the typed feature map.
|
||||
"""
|
||||
step = TOPRewardEncoderProcessorStep()
|
||||
features = {
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.images.top": PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL),
|
||||
}
|
||||
}
|
||||
assert step.transform_features(features) == features
|
||||
@@ -3009,6 +3009,10 @@ test = [
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-timeout" },
|
||||
]
|
||||
topreward = [
|
||||
{ name = "qwen-vl-utils" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
training = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "av" },
|
||||
@@ -3154,6 +3158,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'topreward'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
|
||||
@@ -3177,6 +3182,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'topreward'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
|
||||
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
|
||||
@@ -3244,7 +3250,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
|
||||
Reference in New Issue
Block a user