From d1d218a56c05d18c993c2bf587423a2e77ecfdc1 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 18 May 2026 17:17:29 +0200 Subject: [PATCH] feat/add ROBOMETER reward model --- docs/source/_toctree.yml | 2 + docs/source/robometer.mdx | 185 +++++++ pyproject.toml | 1 + src/lerobot/rewards/__init__.py | 2 + src/lerobot/rewards/factory.py | 18 +- src/lerobot/rewards/robometer/__init__.py | 19 + .../robometer/configuration_robometer.py | 158 ++++++ .../rewards/robometer/modeling_robometer.py | 480 ++++++++++++++++++ .../rewards/robometer/processor_robometer.py | 338 ++++++++++++ .../lerobot_rewardmodel_modelcard_template.md | 2 + tests/rewards/test_modeling_robometer.py | 340 +++++++++++++ tests/rewards/test_robometer_processor.py | 354 +++++++++++++ uv.lock | 14 +- 13 files changed, 1908 insertions(+), 5 deletions(-) create mode 100644 docs/source/robometer.mdx create mode 100644 src/lerobot/rewards/robometer/__init__.py create mode 100644 src/lerobot/rewards/robometer/configuration_robometer.py create mode 100644 src/lerobot/rewards/robometer/modeling_robometer.py create mode 100644 src/lerobot/rewards/robometer/processor_robometer.py create mode 100644 tests/rewards/test_modeling_robometer.py create mode 100644 tests/rewards/test_robometer_processor.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f1dfe9aae..786d92ad5 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -71,6 +71,8 @@ - sections: - local: sarm title: SARM + - local: robometer + title: ROBOMETER title: "Reward Models" - sections: - local: inference diff --git a/docs/source/robometer.mdx b/docs/source/robometer.mdx new file mode 100644 index 000000000..3af822588 --- /dev/null +++ b/docs/source/robometer.mdx @@ -0,0 +1,185 @@ +# ROBOMETER + +ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description. + +**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115) +**Project**: [robometer.github.io](https://robometer.github.io/) +**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer) +**Checkpoint**: [robometer/Robometer-4B](https://huggingface.co/robometer/Robometer-4B) + +## Overview + +ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads: + +- **Progress head**: predicts per-frame task progress in `[0, 1]`. +- **Success head**: predicts per-frame task success probability. +- **Preference head**: predicts which of two trajectories better completes the task during training. + +The paper trains ROBOMETER with a composite objective: + +```text +L = L_pref + L_prog + L_succ +``` + +The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only. + +## What the LeRobot Integration Covers + +- Standard `reward_model.type=robometer` configuration through LeRobot. +- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`. +- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`. +- Dense, frame-level progress and success predictions internally. +- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage. + +This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration. + +## Installation Requirements + +1. Install LeRobot by following the [Installation Guide](./installation). +2. Install the ROBOMETER dependencies: + +```bash +pip install -e ".[robometer]" +``` + +If you use `uv` directly from a source checkout: + +```bash +uv sync --extra robometer +``` + +ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended. + +## Model Inputs and Outputs + +ROBOMETER expects: + +- A trajectory video or sequence of frames. +- A natural-language task description. + +In LeRobot datasets, the preprocessor reads: + +| Config field | Default | Meaning | +| ------------------------- | ------------------------ | ----------------------------------------------------- | +| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER | +| `reward_model.task_key` | `task` | Key in complementary data that stores the task string | +| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER | + +The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample: + +- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`. +- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`. + +## Usage + +### Load the Reward Model Directly + +```python +from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel + +cfg = RobometerConfig( + pretrained_path="lilkm/Robometer-4B", + device="cuda", + reward_output="progress", +) +reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg) +``` + +### Encode Frames and Compute a Reward + +For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string: + +```python +from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX +from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep + +# frames: np.ndarray, shape (T, H, W, C), dtype uint8 +# task: str +encoder = RobometerEncoderProcessorStep( + base_model_id=cfg.base_model_id, + use_multi_image=cfg.use_multi_image, + use_per_frame_progress_token=cfg.use_per_frame_progress_token, + max_frames=cfg.max_frames, +) + +encoded = encoder.encode_samples([(frames, task)]) +batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()} + +reward = reward_model.compute_reward(batch) +``` + +`reward` is a tensor of shape `(batch_size,)`. + +### Use the Reward Factory + +You can also instantiate ROBOMETER through the reward factory: + +```python +from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors + +cfg = make_reward_model_config( + "robometer", + pretrained_path="lilkm/Robometer-4B", + device="cuda", + image_key="observation.images.top", +) +reward_model = make_reward_model(cfg) +preprocessor, postprocessor = make_reward_pre_post_processors(cfg) +``` + +The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors. + +## Configuration Notes + +### Backbone and Vocabulary + +The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order: + +```text +<|split_token|> +<|reward_token|> +<|pref_token|> +<|sim_token|> +<|prog_token|> +``` + +`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility. + +The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`. + +### Progress Prediction + +In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation. + +### Success Prediction + +The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`. + +## Limitations + +- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training. +- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally. +- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses. + +## References + +- [ROBOMETER project](https://robometer.github.io/) +- [ROBOMETER paper](https://arxiv.org/abs/2603.02115) +- [Original ROBOMETER code](https://github.com/aliang8/robometer) +- [Published ROBOMETER-4B checkpoint](https://huggingface.co/robometer/Robometer-4B) +- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct) + +## Citation + +```bibtex +@article{liang2026robometer, + title={ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons}, + author={Liang, Anthony and Korkmaz, Yigit and Zhang, Jiahui and Hwang, Minyoung and Anwar, Abrar and Kaushik, Sidhant and Shah, Aditya and Huang, Alex S. and Zettlemoyer, Luke and Fox, Dieter and Xiang, Yu and Li, Anqi and Bobu, Andreea and Gupta, Abhishek and Tu, Stephen and Biyik, Erdem and Zhang, Jesse}, + journal={arXiv preprint arXiv:2603.02115}, + year={2026} +} +``` + +## License + +This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints. diff --git a/pyproject.toml b/pyproject.toml index f983134ab..c0703ae59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,6 +204,7 @@ groot = [ "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" ] sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] +robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] diff --git a/src/lerobot/rewards/__init__.py b/src/lerobot/rewards/__init__.py index 203fe2ee1..a6a98c3c6 100644 --- a/src/lerobot/rewards/__init__.py +++ b/src/lerobot/rewards/__init__.py @@ -20,11 +20,13 @@ from .factory import ( make_reward_pre_post_processors as make_reward_pre_post_processors, ) from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel +from .robometer.configuration_robometer import RobometerConfig as RobometerConfig from .sarm.configuration_sarm import SARMConfig as SARMConfig __all__ = [ # Configuration classes "RewardClassifierConfig", + "RobometerConfig", "SARMConfig", # Base class "PreTrainedRewardModel", diff --git a/src/lerobot/rewards/factory.py b/src/lerobot/rewards/factory.py index c173f44a5..38a269cb9 100644 --- a/src/lerobot/rewards/factory.py +++ b/src/lerobot/rewards/factory.py @@ -25,6 +25,7 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline from .classifier.configuration_classifier import RewardClassifierConfig from .pretrained import PreTrainedRewardModel +from .robometer.configuration_robometer import RobometerConfig from .sarm.configuration_sarm import SARMConfig @@ -37,7 +38,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]: Args: name: The name of the reward model. Supported names are "reward_classifier", - "sarm". + "sarm", "robometer". Returns: The reward model class corresponding to the given name. @@ -53,6 +54,10 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]: from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel return SARMRewardModel + elif name == "robometer": + from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel + + return RobometerRewardModel else: try: return _get_reward_model_cls_from_name(name=name) @@ -69,7 +74,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig: Args: reward_type: The type of the reward model. Supported types include - "reward_classifier", "sarm". + "reward_classifier", "sarm", "robometer". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -82,6 +87,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig: return RewardClassifierConfig(**kwargs) elif reward_type == "sarm": return SARMConfig(**kwargs) + elif reward_type == "robometer": + return RobometerConfig(**kwargs) else: try: config_cls = RewardModelConfig.get_choice_class(reward_type) @@ -161,6 +168,13 @@ def make_reward_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), dataset_meta=kwargs.get("dataset_meta"), ) + elif isinstance(reward_cfg, RobometerConfig): + from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors + + return make_robometer_pre_post_processors( + config=reward_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) else: try: diff --git a/src/lerobot/rewards/robometer/__init__.py b/src/lerobot/rewards/robometer/__init__.py new file mode 100644 index 000000000..d20d92d37 --- /dev/null +++ b/src/lerobot/rewards/robometer/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_robometer import RobometerConfig +from .modeling_robometer import RobometerRewardModel +from .processor_robometer import make_robometer_pre_post_processors + +__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"] diff --git a/src/lerobot/rewards/robometer/configuration_robometer.py b/src/lerobot/rewards/robometer/configuration_robometer.py new file mode 100644 index 000000000..063c4a9be --- /dev/null +++ b/src/lerobot/rewards/robometer/configuration_robometer.py @@ -0,0 +1,158 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.rewards import RewardModelConfig +from lerobot.utils.constants import OBS_IMAGES +from lerobot.utils.import_utils import _transformers_available, require_package + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoConfig, AutoTokenizer +else: + AutoConfig = None # type: ignore[assignment] + AutoTokenizer = None # type: ignore[assignment] + + +# Special tokens Robometer adds to the Qwen-VL tokenizer at construction time. +# The order is part of the data contract: upstream resized ``embed_tokens`` +# after adding these tokens in this exact order, so changing the set or order +# would silently misalign the saved embedding rows with their token ids. +# ``<|reward_token|>`` and ``<|sim_token|>`` are leftover from earlier upstream +# heads (never read at inference) but still occupy rows the checkpoint expects. +ROBOMETER_SPECIAL_TOKENS = ( + "<|split_token|>", + "<|reward_token|>", + "<|pref_token|>", + "<|sim_token|>", + "<|prog_token|>", +) + + +@RewardModelConfig.register_subclass("robometer") +@dataclass +class RobometerConfig(RewardModelConfig): + """Configuration for the Robometer reward model.""" + + pretrained_path: str | None = "lilkm/Robometer-4B" + image_key: str = OBS_IMAGES + ".top" + task_key: str = "task" + default_task: str | None = None + + max_frames: int | None = 8 + reward_output: str = "progress" # "progress" or "success" + success_threshold: float = 0.5 + + license: str | None = "apache-2.0" + tags: list[str] | None = field( + default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"] + ) + + base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct" + torch_dtype: str = "bfloat16" + use_multi_image: bool = True + use_per_frame_progress_token: bool = True + average_temporal_patches: bool = True + frame_pooling: str = "mean" # "mean" | "boundary" | "attention" + frame_pooling_attn_temperature: float = 1.0 + progress_loss_type: str = "discrete" # "l1" | "l2" | "discrete" + progress_discrete_bins: int = 10 + + # Serialised Qwen backbone config (post-resize). Always populated by + # ``__post_init__`` from ``base_model_id`` + ``len(tokenizer) + 5``, so it + # is non-empty after construction. Saved into ``config.json`` automatically + # by the base ``_save_pretrained``. + vlm_config: dict[str, Any] = field(default_factory=dict) + + input_features: dict[str, PolicyFeature] = field(default_factory=dict) + output_features: dict[str, PolicyFeature] = field(default_factory=dict) + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "REWARD": NormalizationMode.IDENTITY, + } + ) + + def __post_init__(self) -> None: + super().__post_init__() + if self.reward_output not in {"progress", "success"}: + raise ValueError(f"reward_output must be 'progress' or 'success', got {self.reward_output!r}") + if self.max_frames is not None and self.max_frames < 1: + raise ValueError(f"max_frames must be >= 1, got {self.max_frames}") + if self.frame_pooling not in {"mean", "boundary", "attention"}: + raise ValueError(f"frame_pooling must be mean/boundary/attention; got {self.frame_pooling!r}") + if self.frame_pooling_attn_temperature <= 0: + raise ValueError("frame_pooling_attn_temperature must be > 0") + if self.progress_loss_type not in {"l1", "l2", "discrete"}: + raise ValueError(f"progress_loss_type must be l1/l2/discrete; got {self.progress_loss_type!r}") + if self.use_per_frame_progress_token and not self.use_multi_image: + raise ValueError("use_per_frame_progress_token=True requires use_multi_image=True") + + if self.image_key not in self.input_features: + self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL) + self.output_features.setdefault("progress", PolicyFeature(shape=(1,), type=FeatureType.REWARD)) + self.output_features.setdefault("success", PolicyFeature(shape=(1,), type=FeatureType.REWARD)) + + # Deterministically populate ``vlm_config`` so it is non-empty after + # construction. For ``Qwen/Qwen3-VL-4B-Instruct`` this gives + # ``len(tokenizer) + 5 = 151,669 + 5 = 151,674`` — the exact post-resize + # vocab the published ``Robometer-4B`` checkpoint was saved with. + if not self.vlm_config: + require_package("transformers", extra="robometer") + vlm = AutoConfig.from_pretrained(self.base_model_id).to_dict() + tokenizer = AutoTokenizer.from_pretrained(self.base_model_id) + text_config = vlm.get("text_config") + if not isinstance(text_config, dict): + raise ValueError( + f"Backbone config for {self.base_model_id!r} has no nested `text_config`; " + "Robometer expects a Qwen-VL-style config." + ) + text_config["vocab_size"] = len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS) + self.vlm_config = vlm + + @property + def use_discrete_progress(self) -> bool: + """Whether the progress head outputs distribution logits over bins.""" + return self.progress_loss_type.lower() == "discrete" + + @property + def vlm_backbone_config(self): + """Reconstruct the Qwen backbone config from :attr:`vlm_config`.""" + require_package("transformers", extra="robometer") + config_dict = deepcopy(self.vlm_config) + model_type = config_dict.pop("model_type", None) + if model_type is None: + raise ValueError("vlm_config must include `model_type` to reconstruct the backbone config") + return AutoConfig.for_model(model_type, **config_dict) + + @property + def observation_delta_indices(self) -> list[int] | None: + return None + + @property + def action_delta_indices(self) -> None: + return None + + @property + def reward_delta_indices(self) -> None: + return None + + def validate_features(self) -> None: + if self.image_key not in self.input_features: + raise ValueError(f"Robometer requires image input feature {self.image_key!r}") diff --git a/src/lerobot/rewards/robometer/modeling_robometer.py b/src/lerobot/rewards/robometer/modeling_robometer.py new file mode 100644 index 000000000..6f462de85 --- /dev/null +++ b/src/lerobot/rewards/robometer/modeling_robometer.py @@ -0,0 +1,480 @@ +# Copyright 2026 Anthony Liang, Yigit Korkmaz, Stephen Tu, Erdem Bıyık, Jesse Zhang +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons. + +Paper: https://arxiv.org/abs/2603.02115 +Project: https://robometer.github.io +Original code: https://github.com/aliang8/robometer +Model: https://huggingface.co/robometer/Robometer-4B + +Robometer is a general-purpose, video-language-input reward model built on +``Qwen/Qwen3-VL-4B-Instruct``. It is trained with a dual reward-prediction +objective: + +- A frame-level progress loss anchoring reward magnitude on expert data. +- A trajectory-comparison preference loss imposing global ordering constraints + across trajectories sharing the same instruction. + +To support downstream RL it also predicts a frame-level binary success. The +training prompt inserts three learnable tokens: + +- ``<|prog_token|>`` after each frame to read per-frame progress and success. +- ``<|pref_token|>`` at the end to read pairwise preference (training-only). +- ``<|split_token|>`` between two trajectories in preference samples + (training-only). + +Progress is modeled as a categorical distribution over ``progress_discrete_bins`` +uniformly-spaced centers in ``[0, 1]`` (C51-style), and the continuous estimate +is recovered as the softmax-weighted mean of those centers — see +:func:`convert_bins_to_continuous`. + +This LeRobot port is **inference-only**: the preference head is preserved in +the state dict for byte-equivalence with the published ``Robometer-4B`` +checkpoint but is not queried by :meth:`RobometerRewardModel.compute_reward`, +which returns the last-frame progress (clamped to ``[0, 1]``) or sigmoid'd +success probability depending on :attr:`RobometerConfig.reward_output`. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +import torch +from torch import Tensor, nn + +from lerobot.rewards.pretrained import PreTrainedRewardModel +from lerobot.rewards.robometer.configuration_robometer import RobometerConfig +from lerobot.utils.constants import OBS_PREFIX +from lerobot.utils.import_utils import _transformers_available, require_package + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoModelForImageTextToText +else: + AutoModelForImageTextToText = None # type: ignore[assignment] + +logger = logging.getLogger(__name__) + +# Namespace for Robometer's pre-encoded Qwen-VL observation tensors. +ROBOMETER_FEATURE_PREFIX = f"{OBS_PREFIX}robometer." +ROBOMETER_QWEN_INPUT_KEYS = ( + "input_ids", + "attention_mask", + "pixel_values", + "pixel_values_videos", + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", +) +ROBOMETER_METADATA_KEYS = ( + "prog_token_id", + "vision_start_token_id", + "vision_end_token_id", + "video_merge_size", +) +ROBOMETER_INPUT_KEYS = ROBOMETER_QWEN_INPUT_KEYS + ROBOMETER_METADATA_KEYS + + +def convert_bins_to_continuous(bin_logits: Tensor) -> Tensor: + """Collapse per-bin logits into a single value in ``[0, 1]``. + + The discrete progress head outputs ``num_bins`` logits per frame. Bins are + evenly spaced centers in ``[0, 1]``; the continuous prediction is the + softmax-weighted mean of those centers. + """ + bin_probs = torch.softmax(bin_logits, dim=-1) + num_bins = bin_logits.shape[-1] + bin_centers = torch.linspace(0.0, 1.0, num_bins, device=bin_logits.device, dtype=bin_logits.dtype) + return (bin_probs * bin_centers).sum(dim=-1) + + +def _squeeze_last_safe(x: Tensor) -> Tensor: + """Drop a trailing singleton dim only when present.""" + return x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x + + +def _torch_dtype(name: str) -> torch.dtype: + dtype = getattr(torch, name, None) + if isinstance(dtype, torch.dtype): + return dtype + raise ValueError(f"Unknown torch dtype: {name!r}") + + +class RobometerPredictionHead(nn.Sequential): + """Small MLP head used for Robometer's progress / success / preference outputs.""" + + def __init__(self, hidden_dim: int, output_size: int, *, dropout: float, with_sigmoid: bool) -> None: + layers: list[nn.Module] = [ + nn.Linear(hidden_dim, hidden_dim // 2), + nn.LayerNorm(hidden_dim // 2), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim // 2, output_size), + ] + if with_sigmoid: + layers.append(nn.Sigmoid()) + super().__init__(*layers) + + +def decode_progress_outputs( + progress_logits: Tensor | None, + success_logits: Tensor | None, + *, + is_discrete_mode: bool, +) -> dict[str, list[list[float]]]: + """Decode RBM head outputs into per-frame floats. + + Args: + progress_logits: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete). + success_logits: ``(B, T)`` raw logits, ``sigmoid``-ed to probabilities. + is_discrete_mode: if True the progress logits get a softmax over bins + and are projected onto bin centers via :func:`convert_bins_to_continuous`. + + Returns: + Dict with ``progress_pred`` and ``success_probs``, each a list of + length ``B`` of per-frame float lists. + """ + progress_pred: list[list[float]] = [] + success_probs: list[list[float]] = [] + + if progress_logits is not None: + for sample_logits in progress_logits: + if is_discrete_mode: + continuous = convert_bins_to_continuous(sample_logits.detach().float().cpu()) + progress_pred.append(continuous.flatten().tolist()) + else: + progress_pred.append(sample_logits.detach().float().cpu().flatten().tolist()) + + if success_logits is not None: + for sample_logits in success_logits: + success_probs.append(torch.sigmoid(sample_logits.detach().float().cpu()).flatten().tolist()) + + return {"progress_pred": progress_pred, "success_probs": success_probs} + + +class RobometerRewardModel(PreTrainedRewardModel): + """Robometer (RBM) reward model — inference-only LeRobot port. + + Wraps a Qwen-VL backbone (default: ``Qwen/Qwen3-VL-4B-Instruct``) with three + prediction heads from the paper (progress, success, preference). At + inference time only the progress and success heads are queried; the + preference head is kept on the module so the published ``Robometer-4B`` + safetensors load unchanged. + """ + + name = "robometer" + config_class = RobometerConfig + + def __init__(self, config: RobometerConfig, *, dropout: float = 0.1) -> None: + require_package("transformers", extra="robometer") + super().__init__(config) + self.config = config + + # Two backbone-build paths (EO-1 style, branched on ``pretrained_path``): + # + # - Fresh training (``pretrained_path is None``): download the base + # Qwen weights and resize the embed table to match + # ``vlm_config.text_config.vocab_size`` — populated deterministically + # in ``RobometerConfig.__post_init__`` as + # ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)`` + # + # - Loading a saved checkpoint (``pretrained_path`` is set): rebuild + # the empty architecture from ``vlm_config`` via + # ``AutoModelForImageTextToText.from_config`` so the subsequent + # ``model.safetensors`` load is a direct fill of the right shape — + # no redundant Qwen weight download. + torch_dtype = _torch_dtype(config.torch_dtype) + if config.pretrained_path is None: + self.model = AutoModelForImageTextToText.from_pretrained( + config.base_model_id, + dtype=torch_dtype, + trust_remote_code=True, + ) + target_vocab = config.vlm_config["text_config"]["vocab_size"] + self.model.resize_token_embeddings(target_vocab) + else: + self.model = AutoModelForImageTextToText.from_config( + config.vlm_backbone_config, + dtype=torch_dtype, + trust_remote_code=True, + ) + + # All Qwen-VL backbones Robometer supports expose `text_config.hidden_size`. + # Falls back to the top-level `hidden_size` so future non-multimodal + # variants would still resolve. + backbone_config = self.model.config + text_config = getattr(backbone_config, "text_config", None) + hidden_size = getattr(text_config, "hidden_size", None) if text_config is not None else None + if hidden_size is None: + hidden_size = getattr(backbone_config, "hidden_size", None) + if hidden_size is None: + raise AttributeError( + f"Could not infer hidden_size from backbone config of {config.base_model_id}" + ) + hidden_dim = int(hidden_size) + + # Robometer's three prediction heads + frame-pool attention. + progress_output = config.progress_discrete_bins if config.use_discrete_progress else 1 + self.progress_head = RobometerPredictionHead( + hidden_dim, + progress_output, + dropout=dropout, + with_sigmoid=not config.use_discrete_progress, + ) + self.preference_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False) + self.success_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False) + self.frame_pool_attn = nn.Linear(hidden_dim, 1, bias=False) + + # Match the dtype of the loaded base model so weight loading is a no-op cast. + model_dtype = next(self.model.parameters()).dtype + self.progress_head.to(dtype=model_dtype) + self.preference_head.to(dtype=model_dtype) + self.success_head.to(dtype=model_dtype) + self.frame_pool_attn.to(dtype=model_dtype) + + def compute_reward(self, batch: dict[str, Tensor]) -> Tensor: + inputs = { + key: batch[f"{ROBOMETER_FEATURE_PREFIX}{key}"] + for key in ROBOMETER_INPUT_KEYS + if f"{ROBOMETER_FEATURE_PREFIX}{key}" in batch + } + if "input_ids" not in inputs: + raise KeyError( + f"Robometer batch missing pre-encoded inputs (expected " + f"`{ROBOMETER_FEATURE_PREFIX}input_ids`). Make sure the " + "RobometerEncoderProcessorStep ran before `compute_reward`." + ) + + device = next(self.model.parameters()).device + inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()} + + self.eval() + with torch.no_grad(): + progress_logits, success_logits = self._compute_rbm_logits(inputs) + + decoded = decode_progress_outputs( + progress_logits, + success_logits, + is_discrete_mode=self.config.use_discrete_progress, + ) + values = ( + decoded["success_probs"] if self.config.reward_output == "success" else decoded["progress_pred"] + ) + + rewards = torch.stack([torch.as_tensor(seq, dtype=torch.float32)[-1] for seq in values]) + if self.config.reward_output == "success": + rewards = (rewards > self.config.success_threshold).float() + else: + # Match upstream Robometer's ``extract_rewards_from_output``: per-frame + # progress predictions are clamped to ``[0, 1]`` before being returned. + rewards = rewards.clamp(0.0, 1.0) + return rewards.to(self.config.device or "cpu") + + def _compute_rbm_logits( + self, + inputs: dict[str, Any], + ) -> tuple[Tensor, Tensor]: + """Run the Qwen3-VL backbone and apply Robometer's heads. + + ``inputs`` is the encoded batch produced by + :class:`RobometerEncoderProcessorStep`. It carries Qwen tensors as well + as Robometer-specific metadata (``prog_token_id``, + ``vision_start_token_id``, ``vision_end_token_id``, ``video_merge_size``) + — the metadata is popped here so the rest can be forwarded straight to + the Qwen model. + + Returns ``(progress_logits, success_logits)``. Shapes: + + - ``progress_logits``: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete). + - ``success_logits``: ``(B, T)`` raw logits (sigmoid happens at decode time). + """ + prog_token_id = inputs.pop("prog_token_id", None) + vision_start_token_id = inputs.pop("vision_start_token_id", None) + vision_end_token_id = inputs.pop("vision_end_token_id", None) + video_merge_size = inputs.pop("video_merge_size", 14) + + # Qwen3-VL doesn't reliably populate `last_hidden_state`; ask for the + # full hidden-state tuple and take the last layer. This matches the + # `is_qwen3` path in upstream Robometer's `RBM.forward_qwen` (main). + outputs = self.model(**inputs, output_hidden_states=True, return_dict=True) + hidden_state = ( + outputs.hidden_states[-1] + if getattr(outputs, "hidden_states", None) + else outputs.last_hidden_state + ) + + input_ids = inputs["input_ids"] + if self.config.use_per_frame_progress_token: + if prog_token_id is None: + raise KeyError("`prog_token_id` missing in batch (run RobometerEncoderProcessorStep first)") + return self._process_token_extraction(hidden_state, input_ids, prog_token_id=prog_token_id) + if self.config.use_multi_image: + if vision_start_token_id is None or vision_end_token_id is None: + raise KeyError( + "`vision_start_token_id` / `vision_end_token_id` missing in batch " + "(run RobometerEncoderProcessorStep first)" + ) + return self._process_multi_image_frames( + hidden_state, + input_ids, + start_id=vision_start_token_id, + end_id=vision_end_token_id, + ) + video_grid_thw = inputs.get("video_grid_thw") + if video_grid_thw is None: + raise ValueError("video_grid_thw is required for video-mode Robometer inference") + if vision_start_token_id is None: + raise KeyError("`vision_start_token_id` missing in batch") + return self._process_video_frames( + hidden_state, + input_ids, + video_grid_thw, + start_id=vision_start_token_id, + merge_size=video_merge_size, + ) + + def _apply_heads_to_hidden_states(self, frame_embeddings: Tensor) -> tuple[Tensor, Tensor]: + """Apply progress + success heads to a tensor of frame embeddings.""" + progress_out = self.progress_head(frame_embeddings) + progress = progress_out if self.config.use_discrete_progress else _squeeze_last_safe(progress_out) + success = _squeeze_last_safe(self.success_head(frame_embeddings)) + return progress, success + + def _process_token_extraction( + self, + hidden_state: Tensor, + input_ids: Tensor, + *, + prog_token_id: int, + ) -> tuple[Tensor, Tensor]: + """Per-frame progress/success from ``<|prog_token|>`` positions.""" + token_mask = input_ids == prog_token_id + batch_indices, positions = token_mask.nonzero(as_tuple=True) + if positions.numel() == 0: + raise ValueError("`<|prog_token|>` not found in any sequence") + + per_sample_hidden = [ + hidden_state[i, positions[batch_indices == i]] for i in range(input_ids.shape[0]) + ] + progress_list, success_list = [], [] + for embeddings in per_sample_hidden: + if embeddings.shape[0] == 0: + raise ValueError("`<|prog_token|>` missing in a sequence") + progress, success = self._apply_heads_to_hidden_states(embeddings) + progress_list.append(progress) + success_list.append(success) + + return torch.stack(progress_list), torch.stack(success_list) + + def _process_multi_image_frames( + self, + hidden_state: Tensor, + input_ids: Tensor, + *, + start_id: int, + end_id: int, + ) -> tuple[Tensor, Tensor]: + """Per-frame progress/success in multi-image mode (Qwen-VL).""" + progress_list, success_list = [], [] + for batch_idx in range(input_ids.shape[0]): + seq_ids = input_ids[batch_idx] + seq_hidden = hidden_state[batch_idx] + frame_embeddings = self._extract_hidden_states_from_token_pairs( + seq_hidden, seq_ids, start_id, end_id + ) + progress, success = self._apply_heads_to_hidden_states(frame_embeddings) + progress_list.append(progress) + success_list.append(success) + + return torch.stack(progress_list), torch.stack(success_list) + + def _extract_hidden_states_from_token_pairs( + self, + hidden_state: Tensor, + input_ids: Tensor, + start_id: int, + end_id: int, + ) -> Tensor: + start_positions = (input_ids == start_id).nonzero(as_tuple=True)[0] + end_positions = (input_ids == end_id).nonzero(as_tuple=True)[0] + if start_positions.numel() == 0: + raise ValueError("`<|vision_start|>` not found in sequence") + if start_positions.numel() != end_positions.numel(): + raise ValueError( + f"Mismatched vision token counts: {start_positions.numel()} start vs " + f"{end_positions.numel()} end" + ) + + frames: list[Tensor] = [] + for start, end in zip(start_positions.tolist(), end_positions.tolist(), strict=True): + if start >= end: + raise ValueError(f"Invalid vision token pair: start={start} end={end}") + patch_tokens = hidden_state[start + 1 : end] + if patch_tokens.shape[0] == 0: + frames.append((hidden_state[start] + hidden_state[end]) / 2.0) + continue + + pooling = self.config.frame_pooling + if pooling == "mean": + frames.append(patch_tokens.mean(dim=0)) + elif pooling == "boundary": + frames.append(patch_tokens[-1]) + else: # attention + scores = ( + self.frame_pool_attn(patch_tokens).squeeze(-1) + / self.config.frame_pooling_attn_temperature + ) + weights = torch.softmax(scores, dim=0).unsqueeze(-1) + frames.append((weights * patch_tokens).sum(dim=0)) + + return torch.stack(frames) + + def _process_video_frames( + self, + hidden_state: Tensor, + input_ids: Tensor, + video_grid_thw: Tensor, + *, + start_id: int, + merge_size: int, + ) -> tuple[Tensor, Tensor]: + """Per-frame progress/success in video mode (Qwen-VL).""" + progress_list, success_list = [], [] + for batch_idx in range(input_ids.shape[0]): + seq_ids = input_ids[batch_idx] + seq_hidden = hidden_state[batch_idx] + start_positions = (seq_ids == start_id).nonzero(as_tuple=True)[0] + if start_positions.numel() == 0: + raise ValueError("`<|vision_start|>` not found in sequence") + t_dim, h_dim, w_dim = (int(x) for x in video_grid_thw[batch_idx].tolist()) + tokens_per_frame = (h_dim * w_dim) // (merge_size**2) + + cursor = start_positions[0].item() + frame_embeddings: list[Tensor] = [] + for _ in range(t_dim): + if self.config.average_temporal_patches: + patch = seq_hidden[cursor : cursor + tokens_per_frame] + frame_embeddings.append(patch.mean(dim=0)) + else: + frame_embeddings.append(seq_hidden[cursor + tokens_per_frame]) + cursor += tokens_per_frame + + stacked = torch.stack(frame_embeddings) + progress, success = self._apply_heads_to_hidden_states(stacked) + progress_list.append(progress) + success_list.append(success) + + return torch.stack(progress_list), torch.stack(success_list) diff --git a/src/lerobot/rewards/robometer/processor_robometer.py b/src/lerobot/rewards/robometer/processor_robometer.py new file mode 100644 index 000000000..d98f8b9aa --- /dev/null +++ b/src/lerobot/rewards/robometer/processor_robometer.py @@ -0,0 +1,338 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Robometer pre/post processing pipelines.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from PIL import Image +from torch import Tensor + +from lerobot.configs import PipelineFeatureType, PolicyFeature +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + policy_action_to_transition, +) +from lerobot.rewards.robometer.configuration_robometer import ( + ROBOMETER_SPECIAL_TOKENS, + RobometerConfig, +) +from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX +from lerobot.types import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + OBS_IMAGES, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) +from lerobot.utils.import_utils import _transformers_available, require_package + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoProcessor +else: + AutoProcessor = None + +PROGRESS_PROMPT = ( + "The task for the robot is '{task}'. Given the trajectory video, predict " + "the task progress at each frame, how far along the robot is towards " + "completing the task, a float between 0 and 1, where 0 is the starting " + "state and 1 is when the task is completed. If the robot is not " + "performing the same task, predict 0 progress." +) + + +def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]: + """Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images.""" + if frames.ndim != 4: + raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}") + if frames.dtype != np.uint8: + frames = np.clip(frames, 0, 255).astype(np.uint8) + return [Image.fromarray(frames[i]) for i in range(frames.shape[0])] + + +def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray: + """Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array.""" + if max_frames is not None: + video = video[-max_frames:] + if video.shape[1] in (1, 3): + video = video.permute(0, 2, 3, 1) + elif video.shape[-1] not in (1, 3): + raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}") + + array = video.detach().cpu().numpy() + if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0: + array = array * 255.0 + return np.clip(array, 0, 255).astype(np.uint8) + + +def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]: + if task is None: + task = default + if task is None: + raise KeyError("Robometer expected a task description in complementary data") + if isinstance(task, str): + return [task] * batch_size + if isinstance(task, tuple): + task = list(task) + if not (isinstance(task, list) and all(isinstance(item, str) for item in task)): + raise TypeError(f"Robometer task must be a string or list of strings, got {type(task)}") + if len(task) == 1 and batch_size > 1: + return task * batch_size + if len(task) != batch_size: + raise ValueError(f"Expected {batch_size} tasks, got {len(task)}") + return task + + +@dataclass +@ProcessorStepRegistry.register(name="robometer_encoder") +class RobometerEncoderProcessorStep(ProcessorStep): + """Encode raw frames + task into Qwen-VL tensors for the Robometer model. + + Loads a :class:`~transformers.AutoProcessor` matching ``base_model_id`` and + registers Robometer's special tokens on the tokenizer. The matching + embedding resize happens model-side in + :meth:`RobometerRewardModel.__init__`. + + At call time the step reads: + + - ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames. + - ``complementary_data[task_key]``: a string or list of strings. + + and writes ``observation[f"{ROBOMETER_FEATURE_PREFIX}"]`` for: + + - the Qwen-VL processor outputs: ``input_ids``, ``attention_mask``, + ``pixel_values``, ``image_grid_thw``, ``video_grid_thw``, ... + - Robometer-specific token ids consumed by the model heads: + ``prog_token_id``, ``vision_start_token_id``, ``vision_end_token_id``, + ``video_merge_size``. + """ + + base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct" + image_key: str = OBS_IMAGES + ".top" + task_key: str = "task" + default_task: str | None = None + max_frames: int | None = 8 + use_multi_image: bool = True + use_per_frame_progress_token: bool = True + max_length: int = 1024 + + _processor: Any = field(default=None, init=False, repr=False) + + def __post_init__(self) -> None: + require_package("transformers", extra="robometer") + require_package("qwen-vl-utils", extra="robometer", import_name="qwen_vl_utils") + + self._processor = AutoProcessor.from_pretrained( + self.base_model_id, + trust_remote_code=True, + do_sample_frames=False, + padding_side="right", + ) + + # Register Robometer's special tokens on the tokenizer. The matching + # embedding resize happens model-side in `RobometerRewardModel.__init__`. + tokenizer = self._processor.tokenizer + # Qwen tokenizers may not define a pad token, but batched prompts/videos + # require padding, so reuse EOS as the padding token. + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + for token in ROBOMETER_SPECIAL_TOKENS: + if token not in tokenizer.get_vocab(): + tokenizer.add_special_tokens({"additional_special_tokens": [token]}) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition.get(TransitionKey.OBSERVATION) + complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + if not isinstance(observation, dict): + raise ValueError("RobometerEncoderProcessorStep requires an observation dict") + + if self.image_key not in observation: + raise KeyError(f"Robometer expected image key {self.image_key!r} in observation") + + frames = observation[self.image_key] + tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames) + if tensor.ndim == 4: + tensor = tensor.unsqueeze(1) + elif tensor.ndim != 5: + raise ValueError( + f"Expected Robometer frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}" + ) + + batch_size = tensor.shape[0] + tasks = _expand_tasks( + complementary.get(self.task_key, self.default_task), + batch_size=batch_size, + default=self.default_task, + ) + + samples = [ + (_video_to_numpy(tensor[i], max_frames=self.max_frames), tasks[i]) for i in range(batch_size) + ] + encoded = self.encode_samples(samples) + + new_observation = dict(observation) + for key, value in encoded.items(): + new_observation[f"{ROBOMETER_FEATURE_PREFIX}{key}"] = value + + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = new_observation + return new_transition + + def encode_samples(self, samples: list[tuple[np.ndarray, str]]) -> dict[str, Tensor]: + """Run the Qwen-VL processor on a list of ``(frames, task)`` samples.""" + from qwen_vl_utils import process_vision_info + + conversations = [self._build_conversation(frames, task) for frames, task in samples] + + texts = [ + self._processor.apply_chat_template( + msg, + tokenize=False, + add_generation_prompt=False, + add_vision_id=True, + enable_thinking=False, + fps=1, + ) + for msg in conversations + ] + + process_kwargs: dict[str, Any] = { + "return_video_kwargs": True, + "return_video_metadata": True, + } + image_processor = getattr(self._processor, "image_processor", None) + if image_processor is not None and hasattr(image_processor, "patch_size"): + process_kwargs["image_patch_size"] = image_processor.patch_size + + image_inputs, video_inputs, video_kwargs = process_vision_info(conversations, **process_kwargs) + + videos: list[Any] | None = None + video_metadatas: list[Any] | None = None + if video_inputs: + if isinstance(video_inputs[0], tuple) and len(video_inputs[0]) == 2: + videos_seq, metadatas_seq = zip(*video_inputs, strict=False) + videos = list(videos_seq) + video_metadatas = list(metadatas_seq) + else: + videos = list(video_inputs) + + processor_kwargs: dict[str, Any] = { + "text": texts, + "images": image_inputs, + "padding": True, + "truncation": False, + "max_length": self.max_length, + "return_tensors": "pt", + "do_resize": False, + } + if videos is not None: + processor_kwargs["videos"] = videos + if video_metadatas is not None: + processor_kwargs["video_metadata"] = video_metadatas + if video_kwargs: + processor_kwargs.update(video_kwargs) + + encoded = self._processor(**processor_kwargs) + + # Write Robometer-specific token ids and the video patch merge size into + # the encoded batch so `RobometerRewardModel` doesn't need its own + # tokenizer at inference (EO1-style separation: the processor owns the + # tokenizer, the model owns the backbone and heads). + tokenizer = self._processor.tokenizer + encoded["prog_token_id"] = tokenizer.convert_tokens_to_ids("<|prog_token|>") + encoded["vision_start_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_start|>") + encoded["vision_end_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_end|>") + video_processor = getattr(self._processor, "video_processor", None) + encoded["video_merge_size"] = int(getattr(video_processor, "merge_size", 14)) + return encoded + + def _build_conversation(self, frames: np.ndarray, task: str) -> list[dict[str, Any]]: + pil_frames = _frames_to_pil(frames) + prompt = PROGRESS_PROMPT.format(task=task) + content: list[dict[str, Any]] = [{"type": "text", "text": prompt}] + + if self.use_multi_image: + for image in pil_frames: + content.append({"type": "image", "image": image}) + if self.use_per_frame_progress_token: + content.append({"type": "text", "text": "<|prog_token|>"}) + else: + content.append({"type": "video", "video": pil_frames, "sample_fps": 1.0}) + + return [{"role": "user", "content": content}] + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + def get_config(self) -> dict[str, Any]: + return { + "base_model_id": self.base_model_id, + "image_key": self.image_key, + "task_key": self.task_key, + "default_task": self.default_task, + "max_frames": self.max_frames, + "use_multi_image": self.use_multi_image, + "use_per_frame_progress_token": self.use_per_frame_progress_token, + "max_length": self.max_length, + } + + +def make_robometer_pre_post_processors( + config: RobometerConfig, + dataset_stats: dict[str, dict[str, Any]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Pipeline that pre-encodes frames + task into Qwen-VL tensors. + + The preprocessor adds a batch dimension if needed, runs Robometer's + encoder, and moves everything to the configured device. The + postprocessor is the identity since Robometer outputs a single reward + tensor. + """ + del dataset_stats # Robometer has its own normalisation inside the Qwen-VL processor. + + preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=[ + AddBatchDimensionProcessorStep(), + RobometerEncoderProcessorStep( + base_model_id=config.base_model_id, + image_key=config.image_key, + task_key=config.task_key, + default_task=config.default_task, + max_frames=config.max_frames, + use_multi_image=config.use_multi_image, + use_per_frame_progress_token=config.use_per_frame_progress_token, + ), + DeviceProcessorStep(device=config.device or "cpu"), + ], + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ) + postprocessor = PolicyProcessorPipeline( + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + ) + return preprocessor, postprocessor diff --git a/src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md b/src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md index 933bf7586..cc3a2c23a 100644 --- a/src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md +++ b/src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md @@ -13,6 +13,8 @@ A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable. {% elif model_name == "sarm" %} A Success-Aware Reward Model (SARM) predicts a dense reward signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable. +{% elif model_name == "robometer" %} +ROBOMETER is a general-purpose video-language robotic reward model built on a fine-tuned Qwen3-VL-4B backbone with progress, preference, and success heads. Given a trajectory video and a task description, it predicts dense, frame-level task progress in [0, 1] and frame-level success probabilities for downstream robot learning, including offline RL, online RL, data filtering and retrieval, and automated failure detection. {% else %} _Reward model type not recognized — please update this template._ {% endif %} diff --git a/tests/rewards/test_modeling_robometer.py b/tests/rewards/test_modeling_robometer.py new file mode 100644 index 000000000..19aba13fa --- /dev/null +++ b/tests/rewards/test_modeling_robometer.py @@ -0,0 +1,340 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Robometer reward model.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from lerobot.configs.rewards import RewardModelConfig +from lerobot.rewards.factory import get_reward_model_class, make_reward_model_config +from lerobot.rewards.robometer import RobometerConfig +from lerobot.rewards.robometer.configuration_robometer import ROBOMETER_SPECIAL_TOKENS +from lerobot.rewards.robometer.modeling_robometer import ( + ROBOMETER_FEATURE_PREFIX, + convert_bins_to_continuous, + decode_progress_outputs, +) +from tests.utils import skip_if_package_missing + +# Length of the fake tokenizer used in `_patch_build`. The deterministic +# resize target derived in ``RobometerConfig.__post_init__`` is therefore +# ``_FAKE_TOKENIZER_LEN + len(ROBOMETER_SPECIAL_TOKENS)``. +_FAKE_TOKENIZER_LEN = 100 +_EXPECTED_RESIZED_VOCAB = _FAKE_TOKENIZER_LEN + len(ROBOMETER_SPECIAL_TOKENS) + + +class _FakeQwenConfig: + """Stand-in for a Qwen3-VL config (the `model.config` attribute). + + ``to_dict`` matches HF's ``PretrainedConfig.to_dict`` closely enough for + ``RobometerConfig.__post_init__`` to snapshot a meaningful ``vlm_config`` + into the saved ``config.json`` and for the reload path to round-trip + through ``AutoConfig.for_model``. + """ + + def __init__(self, hidden_dim: int = 8, vocab_size: int = _FAKE_TOKENIZER_LEN) -> None: + # `vocab_size` here is the *pre-resize* value the fake backbone advertises. + # `__post_init__` is expected to overwrite it with `len(tokenizer) + 5`. + self.text_config = SimpleNamespace(hidden_size=hidden_dim, vocab_size=vocab_size) + self._hidden_dim = hidden_dim + self._vocab_size = vocab_size + + def to_dict(self) -> dict: + return { + "model_type": "fake_qwen", + "text_config": { + "hidden_size": self._hidden_dim, + "vocab_size": self._vocab_size, + }, + } + + +class _FakeEmbeddings(torch.nn.Module): + def __init__(self, num_embeddings: int = _FAKE_TOKENIZER_LEN) -> None: + super().__init__() + self.num_embeddings = num_embeddings + + +class _FakeBaseModel(torch.nn.Module): + """Stand-in for the Qwen3-VL backbone during tests. + + Provides the minimum surface `RobometerRewardModel.__init__` and + `_compute_rbm_logits` rely on: a `parameters()` iterator (for dtype + + device), a `config.text_config.hidden_size`, a `config.to_dict()` so + `_save_pretrained` can snapshot `vlm_config`, + `get_input_embeddings()` / `resize_token_embeddings()` so the fresh-init + embed resize is a no-op, and a forward that returns a `SimpleNamespace` + with a `hidden_states` tuple. + """ + + def __init__(self, hidden_dim: int = 8) -> None: + super().__init__() + self._param = torch.nn.Parameter(torch.zeros(1)) + self.hidden_dim = hidden_dim + self.config = _FakeQwenConfig(hidden_dim) + self._embeddings = _FakeEmbeddings() + + def get_input_embeddings(self) -> _FakeEmbeddings: + return self._embeddings + + def resize_token_embeddings(self, new_size: int) -> None: + self._embeddings.num_embeddings = new_size + + def forward(self, **kwargs): # noqa: ARG002 - intentional kwargs sink + input_ids = kwargs["input_ids"] + return SimpleNamespace( + hidden_states=(torch.zeros(input_ids.shape[0], input_ids.shape[1], self.hidden_dim),), + last_hidden_state=torch.zeros(input_ids.shape[0], input_ids.shape[1], self.hidden_dim), + ) + + +class _FakeTokenizer: + """Minimal stand-in for an HF tokenizer. + + ``RobometerConfig.__post_init__`` uses ``len(tokenizer)`` to compute the + deterministic resize target ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``, + so a working ``__len__`` is all we need. + """ + + def __init__(self, length: int = _FAKE_TOKENIZER_LEN) -> None: + self._length = length + + def __len__(self) -> int: + return self._length + + +def _patch_build(monkeypatch) -> None: + """Stub out the HF AutoX calls so Robometer construction stays cheap in tests. + + Covers (EO-1 style — no model-side override hooks): + * ``AutoConfig.from_pretrained`` (config side) — used by + ``RobometerConfig.__post_init__`` to snapshot the backbone config. + * ``AutoTokenizer.from_pretrained`` (config side) — used by + ``__post_init__`` to compute ``len(tokenizer) + 5``. + * ``AutoConfig.for_model`` — used by + ``RobometerConfig.vlm_backbone_config`` when rebuilding for ``from_config``. + * ``AutoModelForImageTextToText.from_pretrained`` — fresh-training path + (``pretrained_path is None``). + * ``AutoModelForImageTextToText.from_config`` — checkpoint-reload path + (``pretrained_path`` is set). + """ + from lerobot.rewards.robometer import configuration_robometer, modeling_robometer + + monkeypatch.setattr( + modeling_robometer.AutoModelForImageTextToText, + "from_pretrained", + lambda *args, **kwargs: _FakeBaseModel(hidden_dim=8), + ) + monkeypatch.setattr( + modeling_robometer.AutoModelForImageTextToText, + "from_config", + lambda *args, **kwargs: _FakeBaseModel(hidden_dim=8), + ) + monkeypatch.setattr( + configuration_robometer.AutoConfig, + "for_model", + lambda *args, **kwargs: _FakeQwenConfig(hidden_dim=8), + ) + monkeypatch.setattr( + configuration_robometer.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: _FakeQwenConfig(hidden_dim=8), + ) + monkeypatch.setattr( + configuration_robometer.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _FakeTokenizer(length=_FAKE_TOKENIZER_LEN), + ) + + +def _make_batch(features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Build a `compute_reward`-ready batch using Robometer's namespaced keys.""" + return {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in features.items()} + + +@skip_if_package_missing("transformers") +def test_robometer_config_registered(monkeypatch): + _patch_build(monkeypatch) + assert "robometer" in RewardModelConfig.get_known_choices() + assert RewardModelConfig.get_choice_class("robometer") is RobometerConfig + assert isinstance(make_reward_model_config("robometer", device="cpu"), RobometerConfig) + + +def test_robometer_factory_returns_in_tree_class(): + from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel + + assert get_reward_model_class("robometer") is RobometerRewardModel + + +def test_convert_bins_to_continuous_returns_expected_values(): + # Two frames: first peaks at bin 0 (center 0.0), second peaks at bin 9 (center 1.0). + bin_logits = torch.full((2, 10), -10.0) + bin_logits[0, 0] = 10.0 + bin_logits[1, -1] = 10.0 + values = convert_bins_to_continuous(bin_logits) + assert values.shape == (2,) + assert torch.allclose(values, torch.tensor([0.0, 1.0]), atol=1e-3) + + +def test_decode_progress_outputs_returns_last_frame_values(): + progress = torch.tensor([[0.1, 0.9], [0.4, 0.6]]) + success_logits = torch.tensor([[0.0, 5.0], [0.0, -5.0]]) + + outputs = decode_progress_outputs(progress, success_logits, is_discrete_mode=False) + + assert outputs["progress_pred"] == [pytest.approx([0.1, 0.9]), pytest.approx([0.4, 0.6])] + assert outputs["success_probs"][0][-1] == pytest.approx(torch.sigmoid(torch.tensor(5.0)).item(), abs=1e-3) + assert outputs["success_probs"][1][-1] == pytest.approx( + torch.sigmoid(torch.tensor(-5.0)).item(), abs=1e-3 + ) + + +def test_decode_progress_outputs_discrete_mode_softmaxes_over_bins(): + # 2 frames, peaks at bin 0 and bin 9 → continuous predictions 0.0 and 1.0 + bin_logits = torch.full((1, 2, 10), -10.0) + bin_logits[0, 0, 0] = 10.0 + bin_logits[0, 1, -1] = 10.0 + + outputs = decode_progress_outputs(bin_logits, success_logits=None, is_discrete_mode=True) + + assert outputs["success_probs"] == [] + assert outputs["progress_pred"][0] == pytest.approx([0.0, 1.0], abs=1e-3) + + +@skip_if_package_missing("transformers") +def test_robometer_post_init_overwrites_vocab_size_with_tokenizer_length(monkeypatch): + """``RobometerConfig.__post_init__`` must overwrite the backbone's stale + ``text_config.vocab_size`` (which on the real Qwen3-VL config is the + padded embedding size, ``151,936``) with ``len(tokenizer) + 5``. This is + the contract that makes the published ``Robometer-4B`` checkpoint load + byte-equivalently.""" + _patch_build(monkeypatch) + + cfg = RobometerConfig(device="cpu", progress_loss_type="l2") + + assert cfg.vlm_config["text_config"]["vocab_size"] == _EXPECTED_RESIZED_VOCAB + + +@skip_if_package_missing("transformers") +def test_robometer_compute_reward_reads_pre_encoded_inputs(monkeypatch): + from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel + + progress = torch.tensor([[0.1, 0.9], [0.4, 0.6]]) + success_logits = torch.tensor([[0.0, 5.0], [0.0, -5.0]]) + _patch_build(monkeypatch) + + cfg = RobometerConfig(device="cpu", reward_output="progress", progress_loss_type="l2") + model = RobometerRewardModel(cfg) + # Bypass the Qwen3-VL forward + head extraction with deterministic logits. + monkeypatch.setattr(model, "_compute_rbm_logits", lambda _inputs: (progress, success_logits)) + + batch = _make_batch({"input_ids": torch.zeros(2, 2, dtype=torch.long)}) + rewards = model.compute_reward(batch) + + assert torch.allclose(rewards, torch.tensor([0.9, 0.6])) + + +@skip_if_package_missing("transformers") +def test_robometer_compute_reward_can_return_binary_success(monkeypatch): + from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel + + progress = torch.tensor([[0.1, 0.9], [0.4, 0.6]]) + success_logits = torch.tensor([[0.0, 5.0], [0.0, -5.0]]) # sigmoid(5) > 0.5; sigmoid(-5) < 0.5 + _patch_build(monkeypatch) + + cfg = RobometerConfig( + device="cpu", + reward_output="success", + success_threshold=0.5, + progress_loss_type="l2", + ) + model = RobometerRewardModel(cfg) + monkeypatch.setattr(model, "_compute_rbm_logits", lambda _inputs: (progress, success_logits)) + + batch = _make_batch({"input_ids": torch.zeros(2, 2, dtype=torch.long)}) + rewards = model.compute_reward(batch) + + assert torch.equal(rewards, torch.tensor([1.0, 0.0])) + + +@skip_if_package_missing("transformers") +def test_robometer_compute_reward_errors_when_inputs_missing(monkeypatch): + from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel + + _patch_build(monkeypatch) + + cfg = RobometerConfig(device="cpu", progress_loss_type="l2") + model = RobometerRewardModel(cfg) + + with pytest.raises(KeyError, match=r"observation\.robometer\.input_ids"): + model.compute_reward({}) + + +@skip_if_package_missing("transformers") +def test_robometer_save_pretrained_roundtrips(monkeypatch, tmp_path): + """Saving and reloading a Robometer model in LeRobot HF format must produce + a single ``model.safetensors`` + ``config.json`` (no Hydra ``config.yaml``), + must round-trip user-tunable config fields, and must persist all three + prediction heads (``progress_head``, ``success_head``, ``preference_head``) + so the published ``Robometer-4B`` checkpoint loads byte-equivalently. + """ + from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE + from safetensors.torch import load_file + + from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel + + _patch_build(monkeypatch) + cfg = RobometerConfig( + device="cpu", + pretrained_path="robometer/Robometer-4B", + # Knobs the user might tweak — must survive the round-trip. + image_key="observation.images.cam_top", + task_key="task", + reward_output="success", + success_threshold=0.7, + progress_loss_type="l2", + ) + model = RobometerRewardModel(cfg) + model.save_pretrained(str(tmp_path)) + + # Exactly the files LeRobot's HubMixin promises. + assert (tmp_path / CONFIG_NAME).exists() + assert (tmp_path / SAFETENSORS_SINGLE_FILE).exists() + assert not (tmp_path / "config.yaml").exists() # we want HF-style, not Hydra + + # All three heads must be present in the saved safetensors. The preference + # head is unused at inference but the published checkpoint expects its + # rows — losing it would silently break weight loading. + state = load_file(str(tmp_path / SAFETENSORS_SINGLE_FILE)) + assert any(k.startswith("progress_head.") for k in state), "progress_head weights missing" + assert any(k.startswith("success_head.") for k in state), "success_head weights missing" + assert any(k.startswith("preference_head.") for k in state), "preference_head weights missing" + + # Reload from the local directory: no Hub fetch, no YAML overlay. The + # base class drives subclass dispatch via the `type` field in config.json. + reloaded_cfg = RewardModelConfig.from_pretrained(str(tmp_path)) + assert isinstance(reloaded_cfg, RobometerConfig) + reloaded_cfg.pretrained_path = str(tmp_path) # mimic lerobot-train's `validate()` + reloaded = RobometerRewardModel.from_pretrained(str(tmp_path), config=reloaded_cfg) + + assert reloaded.config.image_key == "observation.images.cam_top" + assert reloaded.config.task_key == "task" + assert reloaded.config.reward_output == "success" + assert reloaded.config.success_threshold == 0.7 + assert reloaded.config.progress_loss_type == "l2" # came back from config.json diff --git a/tests/rewards/test_robometer_processor.py b/tests/rewards/test_robometer_processor.py new file mode 100644 index 000000000..cba8ad564 --- /dev/null +++ b/tests/rewards/test_robometer_processor.py @@ -0,0 +1,354 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Robometer's pre-processing helpers and encoder step. + +Covers the pure helpers (``_video_to_numpy`` and ``_expand_tasks``) directly, +and exercises :class:`RobometerEncoderProcessorStep` with a stubbed +``AutoProcessor`` so we don't need to download Qwen-VL just to test the +dataclass plumbing (``transform_features`` / ``get_config``). + +The full ``__call__`` path that runs ``process_vision_info`` + the Qwen +processor is intentionally *not* covered here — it is essentially HF glue +that's exercised by the integration / parity scripts. +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest +import torch + +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.rewards.robometer.processor_robometer import ( + PROGRESS_PROMPT, + _expand_tasks, + _frames_to_pil, + _video_to_numpy, +) +from tests.utils import skip_if_package_missing + + +def _skip_if_robometer_extras_missing(func): + """Apply both optional-dependency guards in one shot. + + ``RobometerEncoderProcessorStep.__post_init__`` calls + ``require_package("transformers", ...)`` *and* + ``require_package("qwen-vl-utils", ...)``, so both need to be present + before we can instantiate the step. + """ + func = skip_if_package_missing("qwen-vl-utils", import_name="qwen_vl_utils")(func) + func = skip_if_package_missing("transformers")(func) + return func + + +# --------------------------------------------------------------------------- +# _video_to_numpy — pure tensor → uint8 (T, H, W, C) conversion +# --------------------------------------------------------------------------- + + +def test_video_to_numpy_chw_float_is_converted_to_thwc_uint8(): + video = torch.rand(4, 3, 8, 8) # (T, C, H, W) floats in [0, 1] + array = _video_to_numpy(video, max_frames=None) + + assert array.shape == (4, 8, 8, 3) + assert array.dtype == np.uint8 + assert array.min() >= 0 and array.max() <= 255 + + +def test_video_to_numpy_already_thwc_uint8_passes_through(): + video = torch.randint(0, 256, (3, 8, 8, 3), dtype=torch.uint8) # (T, H, W, C) + array = _video_to_numpy(video, max_frames=None) + + assert array.shape == (3, 8, 8, 3) + assert array.dtype == np.uint8 + + +def test_video_to_numpy_max_frames_tail_crops_recent_frames(): + """``max_frames`` should keep the **last** K frames (most recent).""" + video = torch.zeros(10, 3, 4, 4) + for t in range(10): + video[t] = t / 9.0 # marker: 0 at t=0, ≈1 at t=9 + + array = _video_to_numpy(video, max_frames=3) + + assert array.shape == (3, 4, 4, 3) + # The first kept frame is t=7 → marker ≈ 7/9 → uint8 ≈ 198 + assert int(array[0, 0, 0, 0]) == int(round(7 / 9 * 255)) + # The last kept frame is t=9 → marker = 1.0 → uint8 = 255 + assert int(array[-1, 0, 0, 0]) == 255 + + +def test_video_to_numpy_rejects_3d_input(): + with pytest.raises(ValueError, match="Expected channel dim"): + _video_to_numpy(torch.zeros(4, 8, 8), max_frames=None) + + +def test_video_to_numpy_floats_above_one_pass_through_without_rescaling(): + """If ``array.max() > 1`` the helper assumes the tensor is already in the + [0, 255] range (uint8-as-float), so values pass through unchanged.""" + video = torch.full((1, 3, 2, 2), 5.0) + array = _video_to_numpy(video, max_frames=None) + + assert array.shape == (1, 2, 2, 3) + assert int(array.max()) == 5 + + +def test_video_to_numpy_clips_very_large_floats_to_uint8_max(): + """Out-of-uint8-range floats are clipped at 255 before the cast.""" + video = torch.full((1, 3, 2, 2), 300.0) + array = _video_to_numpy(video, max_frames=None) + + assert int(array.max()) == 255 + + +# --------------------------------------------------------------------------- +# _expand_tasks — string / list / tuple broadcasting to batch size +# --------------------------------------------------------------------------- + + +def test_expand_tasks_string_is_broadcast_to_batch_size(): + assert _expand_tasks("pick up", batch_size=3, default=None) == ["pick up", "pick up", "pick up"] + + +def test_expand_tasks_list_of_matching_size_passes_through(): + assert _expand_tasks(["a", "b", "c"], batch_size=3, default=None) == ["a", "b", "c"] + + +def test_expand_tasks_tuple_is_normalised_to_list(): + assert _expand_tasks(("a", "b"), batch_size=2, default=None) == ["a", "b"] + + +def test_expand_tasks_single_element_list_is_broadcast(): + assert _expand_tasks(["only one"], batch_size=3, default=None) == ["only one"] * 3 + + +def test_expand_tasks_size_mismatch_raises(): + with pytest.raises(ValueError, match="Expected 3 tasks"): + _expand_tasks(["a", "b"], batch_size=3, default=None) + + +def test_expand_tasks_missing_uses_default(): + assert _expand_tasks(None, batch_size=2, default="fallback") == ["fallback", "fallback"] + + +def test_expand_tasks_missing_without_default_raises(): + with pytest.raises(KeyError, match="task description"): + _expand_tasks(None, batch_size=1, default=None) + + +def test_expand_tasks_wrong_type_raises(): + with pytest.raises(TypeError, match="must be a string or list"): + _expand_tasks(42, batch_size=1, default=None) + + +# --------------------------------------------------------------------------- +# _frames_to_pil — uint8 (T, H, W, C) → list[PIL.Image] +# --------------------------------------------------------------------------- + + +def test_frames_to_pil_returns_one_image_per_frame(): + frames = np.zeros((4, 8, 8, 3), dtype=np.uint8) + images = _frames_to_pil(frames) + + assert len(images) == 4 + assert all(img.size == (8, 8) for img in images) + + +def test_frames_to_pil_casts_floats_to_uint8(): + frames = np.full((2, 4, 4, 3), 200.0, dtype=np.float32) + images = _frames_to_pil(frames) + + assert len(images) == 2 + # PIL converted from clipped uint8 - sanity check pixel values come through. + assert np.asarray(images[0]).dtype == np.uint8 + + +def test_frames_to_pil_rejects_non_4d_input(): + with pytest.raises(ValueError, match=r"\(T,H,W,C\)"): + _frames_to_pil(np.zeros((4, 8, 8), dtype=np.uint8)) + + +# --------------------------------------------------------------------------- +# Encoder step plumbing — exercise dataclass surface with a stubbed AutoProcessor +# --------------------------------------------------------------------------- + + +class _FakeTokenizer: + """Tokenizer surface the encoder step touches in ``__post_init__``.""" + + def __init__(self) -> None: + self.pad_token: str | None = None + self.eos_token = "<|endoftext|>" + self._vocab: dict[str, int] = {"<|endoftext|>": 0} + self.added: list[str] = [] + + def get_vocab(self) -> dict[str, int]: + return self._vocab + + def add_special_tokens(self, payload: dict[str, Any]) -> int: + for token in payload.get("additional_special_tokens", []): + if token not in self._vocab: + self._vocab[token] = len(self._vocab) + self.added.append(token) + return len(self.added) + + +class _FakeAutoProcessor: + """Stand-in returned by ``AutoProcessor.from_pretrained`` during tests.""" + + def __init__(self) -> None: + self.tokenizer = _FakeTokenizer() + self.image_processor = None + self.video_processor = None + + @classmethod + def from_pretrained(cls, *args, **kwargs): # noqa: ARG003 + return cls() + + +def _build_step(monkeypatch, **overrides): + from lerobot.rewards.robometer import processor_robometer + + monkeypatch.setattr(processor_robometer, "AutoProcessor", _FakeAutoProcessor) + + return processor_robometer.RobometerEncoderProcessorStep(**overrides) + + +@_skip_if_robometer_extras_missing +def test_encoder_step_registers_special_tokens_on_tokenizer(monkeypatch): + """``__post_init__`` must register Robometer's five special tokens on the + tokenizer that ships with the chosen Qwen-VL checkpoint.""" + from lerobot.rewards.robometer.configuration_robometer import ROBOMETER_SPECIAL_TOKENS + + step = _build_step(monkeypatch) + + vocab = step._processor.tokenizer.get_vocab() + for token in ROBOMETER_SPECIAL_TOKENS: + assert token in vocab, f"{token} not registered on the tokenizer" + + +@_skip_if_robometer_extras_missing +def test_encoder_step_sets_pad_token_to_eos_when_missing(monkeypatch): + """Qwen tokenizers ship without a pad token; the step must reuse EOS so + batched processing doesn't crash on padding.""" + step = _build_step(monkeypatch) + + assert step._processor.tokenizer.pad_token == "<|endoftext|>" + + +@_skip_if_robometer_extras_missing +def test_encoder_step_get_config_roundtrips_user_fields(monkeypatch): + """``get_config`` must serialise every user-tunable field — these are what + the processor pipeline saves under ``preprocessor_config.json``.""" + step = _build_step( + monkeypatch, + base_model_id="Qwen/Qwen3-VL-4B-Instruct", + image_key="observation.images.cam_top", + task_key="task", + default_task="do the thing", + max_frames=12, + use_multi_image=True, + use_per_frame_progress_token=True, + max_length=2048, + ) + + cfg = step.get_config() + assert cfg == { + "base_model_id": "Qwen/Qwen3-VL-4B-Instruct", + "image_key": "observation.images.cam_top", + "task_key": "task", + "default_task": "do the thing", + "max_frames": 12, + "use_multi_image": True, + "use_per_frame_progress_token": True, + "max_length": 2048, + } + + +@_skip_if_robometer_extras_missing +def test_encoder_step_transform_features_is_identity(monkeypatch): + """The encoder step writes Qwen tensors into ``observation`` at call time, + but it does **not** advertise new typed features at pipeline-build time — + the downstream model consumes them via the ``ROBOMETER_FEATURE_PREFIX`` + namespace, not via the typed feature map. + """ + step = _build_step(monkeypatch) + + features = { + PipelineFeatureType.OBSERVATION: { + "observation.images.top": PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL), + } + } + assert step.transform_features(features) == features + + +@_skip_if_robometer_extras_missing +def test_encoder_step_build_conversation_inserts_prog_token_per_frame(monkeypatch): + """In multi-image mode with per-frame progress tokens, the conversation + must alternate ``image`` and ``<|prog_token|>`` text entries, one pair + per frame, after the task prompt.""" + step = _build_step( + monkeypatch, + use_multi_image=True, + use_per_frame_progress_token=True, + ) + + frames = np.zeros((3, 8, 8, 3), dtype=np.uint8) + conversation = step._build_conversation(frames, task="pick up the cube") + + assert len(conversation) == 1 and conversation[0]["role"] == "user" + content = conversation[0]["content"] + + # First entry is the task prompt. + assert content[0] == {"type": "text", "text": PROGRESS_PROMPT.format(task="pick up the cube")} + + # Then 3 (image, <|prog_token|>) pairs. + expected_tail = [ + item + for _ in range(3) + for item in ( + {"type": "image"}, # value asserted below + {"type": "text", "text": "<|prog_token|>"}, + ) + ] + assert len(content) == 1 + len(expected_tail) + for got, exp in zip(content[1:], expected_tail, strict=True): + assert got["type"] == exp["type"] + if exp["type"] == "text": + assert got["text"] == exp["text"] + + +@_skip_if_robometer_extras_missing +def test_encoder_step_build_conversation_video_mode_uses_single_video_entry(monkeypatch): + """When ``use_multi_image=False``, frames are bundled into a single + ``video`` content entry instead of individual ``image`` entries.""" + step = _build_step( + monkeypatch, + use_multi_image=False, + use_per_frame_progress_token=False, + ) + + frames = np.zeros((4, 8, 8, 3), dtype=np.uint8) + conversation = step._build_conversation(frames, task="pour the water") + + content = conversation[0]["content"] + # Exactly two entries: the prompt and one video entry. + assert len(content) == 2 + assert content[0]["type"] == "text" + assert content[1]["type"] == "video" + # The video entry carries all four frames. + assert len(content[1]["video"]) == 4 diff --git a/uv.lock b/uv.lock index 408a9a351..ea3bf4443 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "(python_full_version >= '3.15' and platform_machine == 'AMD64' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'linux')", @@ -1142,7 +1142,7 @@ name = "decord" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "numpy", marker = "(platform_machine != 'arm64' and platform_machine != 's390x' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, @@ -2972,6 +2972,11 @@ qwen-vl-utils-dep = [ reachy2 = [ { name = "reachy2-sdk" }, ] +robometer = [ + { name = "peft" }, + { name = "qwen-vl-utils" }, + { name = "transformers" }, +] robstride = [ { name = "python-can" }, ] @@ -3122,6 +3127,7 @@ requires-dist = [ { name = "lerobot", extras = ["peft"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'groot'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'peft'" }, + { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'robometer'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["phone"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["pi"], marker = "extra == 'all'" }, @@ -3139,6 +3145,7 @@ requires-dist = [ { name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'lekiwi'" }, { name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" }, { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" }, + { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'robometer'" }, { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" }, { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" }, @@ -3160,6 +3167,7 @@ requires-dist = [ { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" }, + { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'robometer'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" }, @@ -3227,7 +3235,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt"