Compare commits

..

2 Commits

Author SHA1 Message Date
Khalil Meftah fa3eb9fce3 test(rewards): add unit tests for distributional value function model 2026-06-10 16:07:43 +02:00
Khalil Meftah 500c91ba92 feat(rewards): introduce distributional value function model
- Added a new distributional value function (DistributionalVF) model for RECAP, including its configuration, modeling, and processor components.
- Updated the rewards factory to support the new model type.
- Updated  to include the new model in the dependencies.
2026-06-10 15:24:50 +02:00
13 changed files with 1839 additions and 421 deletions
+5 -3
View File
@@ -214,9 +214,10 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
recap = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
@@ -231,9 +232,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
@@ -296,6 +297,7 @@ all = [
"lerobot[sarm]",
"lerobot[robometer]",
"lerobot[topreward]",
"lerobot[recap]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
]
+1 -7
View File
@@ -30,7 +30,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
):
"""Sampler that optionally incorporates episode boundary information.
@@ -42,10 +41,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
@@ -78,11 +73,10 @@ class EpisodeAwareSampler:
self.indices = indices
self.shuffle = shuffle
self.generator = generator
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices), generator=self.generator):
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
+1 -21
View File
@@ -481,10 +481,8 @@ def reencode_video(
encoder_threads: int | None = None,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
start_time_s: float | None = None,
end_time_s: float | None = None,
) -> None:
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
"""Re-encode a video file using the given encoder configuration.
Args:
input_video_path: Existing video file to read.
@@ -493,17 +491,10 @@ def reencode_video(
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
start_time_s: When set, trim the output to start at this timestamp (seconds).
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
"""
camera_encoder = camera_encoder or camera_encoder_defaults()
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
@@ -535,10 +526,6 @@ def reencode_video(
width = int(in_stream.width)
height = int(in_stream.height)
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
if start_time_s is not None:
src.seek(int(start_time_s * av.time_base), backward=True)
with av.open(
tmp_output_video_path,
mode="w",
@@ -552,14 +539,7 @@ def reencode_video(
out_stream.height = height
for frame in src.decode(in_stream):
frame_time_s = frame.time
if start_time_s is not None and frame_time_s < start_time_s:
continue
if end_time_s is not None and frame_time_s >= end_time_s:
break
frame = frame.reformat(width=width, height=height, format=pix_fmt)
if start_time_s is not None:
frame.pts = None # reset timestamps so the trimmed output starts at t=0
packet = out_stream.encode(frame)
if packet:
dst.mux(packet)
+4
View File
@@ -13,6 +13,9 @@
# limitations under the License.
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import (
DistributionalVFConfig as DistributionalVFConfig,
)
from .factory import (
get_reward_model_class as get_reward_model_class,
make_reward_model as make_reward_model,
@@ -26,6 +29,7 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
__all__ = [
# Configuration classes
"DistributionalVFConfig",
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
@@ -0,0 +1,23 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_distributional_value_function import DistributionalVFConfig
from .modeling_distributional_value_function import DistributionalVFRewardModel
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
__all__ = [
"DistributionalVFConfig",
"DistributionalVFRewardModel",
"make_distributional_vf_pre_post_processors",
]
@@ -0,0 +1,108 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
with optional one-hot targets for terminal states; MC returns normalized per task.
Weights initialized from a pre-trained PI05 actor checkpoint.
"""
from dataclasses import dataclass, field
from lerobot.configs import FeatureType, NormalizationMode
from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
@RewardModelConfig.register_subclass("distributional_value_function")
@dataclass
class DistributionalVFConfig(RewardModelConfig):
"""Configuration for RECAP's distributional value function.
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
(Eq. 1 in the paper).
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
about 670M params and initialize from the PI05 actor checkpoint.
"""
# Backbone
paligemma_variant: str = "gemma_2b"
num_hidden_layers: int = 6
num_vision_layers: int = 13
# Distributional head
num_value_bins: int = 201
value_support_min: float = -1.0
value_support_max: float = 0.0
hl_gauss_sigma_ratio: float = 5.0
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
target_method: str = "hl_gauss"
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
# When False, terminal states use the same target method as non-terminal states.
use_one_hot_terminal: bool = True
# Image
image_resolution: tuple[int, int] = (224, 224)
# Tokenizer
tokenizer_max_length: int = 64
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
# Pass a PI05 checkpoint path or Hub repo_id here.
# After training, load the value function with RewardModel.from_pretrained() instead.
init_from_actor_path: str = ""
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
}
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=3e-4,
weight_decay=1e-4,
grad_clip_norm=1.0,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=500,
num_decay_steps=50000,
)
def validate_features(self) -> None:
if not self.input_features:
return
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
if not has_image:
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
@@ -0,0 +1,567 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modeling for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
Inputs: single image observation + task text prompt ("Task: {task}.")
Outputs: softmax distribution over value bins; expected value E[V] for inference.
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
with optional one-hot targets for terminal states; MC returns normalized per task.
Weight initialization: vision tower, multi-modal projector, token embeddings, and
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
"""
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.import_utils import _transformers_available, require_package
from .configuration_distributional_value_function import DistributionalVFConfig
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaRMSNorm,
_gated_residual,
_get_pi_gemma_decoder_layer_base,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
PaliGemmaForConditionalGenerationWithPiGemma = None
PiGemmaRMSNorm = None
_gated_residual = None
_get_pi_gemma_decoder_layer_base = None
PALIGEMMA_VOCAB_SIZE = 257152
class DistributionalVFRewardModel(PreTrainedRewardModel):
"""Distributional value function model for RECAP.
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
per-task normalized Monte Carlo returns.
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
causal attention, [CLS] token, and Linear(D, num_bins) value head.
The expected value is E[V] = sum(softmax(logits) * bin_centers).
"""
name = "distributional_value_function"
config_class = DistributionalVFConfig
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
require_package("transformers", extra="recap")
super().__init__(config)
self.config = config
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
# Get base dimensions from the paligemma variant (OpenPI config format)
base_config = get_gemma_config(config.paligemma_variant)
hidden_dim = base_config.width
mlp_dim = base_config.mlp_dim
num_layers = config.num_hidden_layers
# HuggingFace GemmaConfig for transformer layers
gemma_config = CONFIG_MAPPING["gemma"](
head_dim=base_config.head_dim,
hidden_size=hidden_dim,
intermediate_size=mlp_dim,
num_attention_heads=base_config.num_heads,
num_hidden_layers=num_layers,
num_key_value_heads=base_config.num_kv_heads,
vocab_size=PALIGEMMA_VOCAB_SIZE,
hidden_activation="gelu_pytorch_tanh",
)
self.gemma_config = gemma_config
self.hidden_dim = hidden_dim
self.num_value_bins = config.num_value_bins
# Single learned [CLS] token for value prediction
self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
# Value projection head: Linear(hidden_dim, num_bins)
self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
# Transformer layers (overwritten by _initialize_from_actor on first run)
self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
self.layers = nn.ModuleList(
[pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
)
self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
# Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
# PaliGemmaConfig wraps both vision and text configs into a single model
paligemma_config = CONFIG_MAPPING["paligemma"]()
paligemma_config.text_config = gemma_config
paligemma_config.vision_config.image_size = config.image_resolution[0]
paligemma_config.vision_config.intermediate_size = 4304
paligemma_config.vision_config.projection_dim = 2048
paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
self.vision_tower = paligemma_full.model.vision_tower
self.multi_modal_projector = paligemma_full.model.multi_modal_projector
self.token_embedding = paligemma_full.model.language_model.embed_tokens
del paligemma_full
# Truncate vision tower to num_vision_layers
if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
vision_encoder = self.vision_tower.vision_model.encoder
vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
# Bin support: evenly spaced centers from value_support_min to value_support_max
bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
self.register_buffer("bin_centers", bin_centers, persistent=False)
bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
# Overwrite with pre-trained PI05 actor weights (first training run only)
if config.init_from_actor_path:
self._initialize_from_actor()
def _initialize_from_actor(self) -> None:
"""Overwrite weights from a pre-trained PI05 actor checkpoint.
Called on first training run only (when init_from_actor_path is set).
"""
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
actor_model = actor_policy.model
paligemma_model = actor_model.paligemma_with_expert.paligemma
source_language_model = paligemma_model.model.language_model
# Transformer components
self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
num_layers = self.gemma_config.num_hidden_layers
for i in range(num_layers):
self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
self.norm.load_state_dict(source_language_model.norm.state_dict())
# Vision tower (truncate source first, then copy)
source_vision_tower = paligemma_model.model.vision_tower
if hasattr(source_vision_tower, "vision_model") and hasattr(
source_vision_tower.vision_model, "encoder"
):
source_encoder = source_vision_tower.vision_model.encoder
source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
self.vision_tower.load_state_dict(source_vision_tower.state_dict())
# Multi-modal projector
self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())
# Token embedding table
self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())
del actor_policy
def embed_image(self, image: Tensor) -> Tensor:
"""Embed images using the value function's SigLIP vision tower.
Args:
image: [batch_size, channels, height, width] preprocessed images in [-1, 1].
Returns:
[batch_size, num_patches, hidden_dim] projected image features.
"""
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.vision_tower(image, return_dict=True)
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
image_features = image_features / (self.hidden_dim**0.5)
if image_features.dtype != out_dtype:
image_features = image_features.to(out_dtype)
return image_features
def embed_text(self, token_ids: Tensor) -> Tensor:
"""Embed text token IDs using the value function's token embedding table.
Args:
token_ids: [batch_size, seq_len] integer token IDs
Returns:
[batch_size, seq_len, hidden_dim] text embeddings
"""
return self.token_embedding(token_ids)
def _get_cls_embedding(self, batch_size: int) -> Tensor:
"""Get [CLS] token embedding expanded to batch size.
Args:
batch_size: number of samples in the batch.
Returns:
[batch_size, 1, hidden_dim] learned [CLS] embedding.
"""
return self.cls_embedding.expand(batch_size, -1, -1)
def forward_value(
self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
) -> dict[str, Tensor]:
"""Core forward pass through the distributional value function.
Args:
vision_features: [batch_size, num_patches, hidden_dim]
text_embeddings: [batch_size, seq_len, hidden_dim]
text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
Returns:
logits: [batch_size, num_value_bins]
probs: [batch_size, num_value_bins]
value: [batch_size, 1]
"""
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE
batch_size = text_embeddings.shape[0]
device = text_embeddings.device
# Build sequence: [vision, text, CLS]
cls_embedding = self._get_cls_embedding(batch_size)
hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)
# Build causal attention mask
vision_len = vision_features.shape[1]
vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
full_seq_len = full_padding_mask.shape[1]
# Causal mask
causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
# Combine causal mask with padding mask
padding_mask_4d = full_padding_mask[:, None, None, :].expand(
batch_size, 1, full_seq_len, full_seq_len
)
attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)
position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
cos, sin = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
norm_output = layer.input_layernorm(hidden_states, cond=None)
if isinstance(norm_output, tuple):
hidden_states_normed, gate = norm_output
else:
hidden_states_normed, gate = norm_output, None
input_shape = hidden_states_normed.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
attention_output, _ = modeling_gemma.eager_attention_forward(
layer.self_attn,
query_states,
key_states,
value_states,
attention_mask,
layer.self_attn.scaling,
)
attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
projected_attention = layer.self_attn.o_proj(attention_output)
if gate is not None:
projected_attention = _gated_residual(hidden_states, projected_attention, gate)
else:
projected_attention = hidden_states + projected_attention
after_attention_residual = projected_attention.clone()
norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
if isinstance(norm_output, tuple):
mlp_input, gate = norm_output
else:
mlp_input, gate = norm_output, None
mlp_output = layer.mlp(mlp_input)
if gate is not None:
hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
else:
hidden_states = after_attention_residual + mlp_output
hidden_states = self.norm(hidden_states)
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
# Extract [CLS] token (last position in the sequence)
cls_hidden_state = hidden_states[:, -1, :] # [batch_size, hidden_dim]
# Value head: Linear(hidden_dim, num_bins) -> logits
value_logits = self.value_head(cls_hidden_state) # [batch_size, num_value_bins]
value_probs = F.softmax(value_logits, dim=-1)
predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
dim=-1, keepdim=True
)
return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
"""HL-Gauss soft target distribution.
Places a Gaussian N(target, sigma^2) over the bin support and computes
per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.
Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
Functions via Classification for Scalable Deep RL", Section 3.1.
arXiv:2403.03950
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] target probability distribution.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.to(dtype=self.bin_centers.dtype)
# Bin edges: half a bin-width outside the first/last center
bin_width = (self.config.value_support_max - self.config.value_support_min) / (
self.num_value_bins - 1
)
support_edges = torch.linspace(
self.config.value_support_min - bin_width / 2,
self.config.value_support_max + bin_width / 2,
self.num_value_bins + 1,
device=target_value.device,
dtype=target_value.dtype,
)
# CDF of N(target, sigma^2) evaluated at each edge
cdf_at_edges = 0.5 * (
1.0
+ torch.erf(
(support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
/ (self.hl_gauss_sigma * math.sqrt(2))
)
) # [batch_size, num_bins + 1]
# Normalize: z = cdf(max_edge) - cdf(min_edge)
normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)
# Bin probabilities = differences of consecutive CDF values, normalized
bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant
return bin_probabilities
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
"""Dirac delta (C51) projection: split probability between two nearest bins.
Standard distributional RL projection from Bellemare et al. 2017.
"A Distributional Perspective on Reinforcement Learning"
arXiv:1707.06887
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] target probability distribution.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
target_value = target_value.to(dtype=self.bin_centers.dtype)
bin_width = self.bin_centers[1] - self.bin_centers[0]
normalized_position = (target_value - self.config.value_support_min) / bin_width
lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)
weight_upper = normalized_position - lower_bin_idx.float()
weight_lower = upper_bin_idx.float() - normalized_position
same_bin = lower_bin_idx == upper_bin_idx
weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)
batch_size = target_value.shape[0]
target_distribution = torch.zeros(batch_size, self.num_value_bins, device=target_value.device)
batch_indices = torch.arange(batch_size, device=target_value.device)
target_distribution[batch_indices, lower_bin_idx] += weight_lower
target_distribution[batch_indices, upper_bin_idx] += weight_upper
return target_distribution
def one_hot_target(self, target_value: Tensor) -> Tensor:
"""One-hot target for terminal states (exact return, no smoothing).
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] one-hot distribution at the nearest bin.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.to(dtype=self.bin_centers.dtype)
nearest_bin_idx = torch.argmin(
torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
)
return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)
def compute_target_distribution(
self,
target_value: Tensor,
is_terminal: Tensor,
method: str = "hl_gauss",
use_one_hot_terminal: bool = True,
) -> Tensor:
"""Compute target distribution using configured method.
Args:
target_value: [batch_size] scalar return targets
is_terminal: [batch_size] boolean terminal flags
method: "hl_gauss" or "dirac_delta"
use_one_hot_terminal: if True, terminal states get one-hot targets
(exact return, no smoothing). If False, all states use the same method.
Returns:
[batch_size, num_value_bins] target probability distribution
"""
if method == "hl_gauss":
base_distribution = self.hl_gauss_target(target_value)
elif method == "dirac_delta":
base_distribution = self.dirac_delta_target(target_value)
else:
raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")
if not use_one_hot_terminal:
return base_distribution
terminal_distribution = self.one_hot_target(target_value)
return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
"""Training forward pass — computes cross-entropy loss against MC return targets.
The batch is expected to be preprocessed by the processor pipeline.
Keys expected in batch:
- observation.images.*: [B, C, H, W] preprocessed images
- observation.language_tokens: [B, seq_len] tokenized task prompt
- observation.language_attention_mask: [B, seq_len] padding mask
- mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
- is_terminal: [B] boolean terminal flags
Returns:
(loss, output_dict) where loss is scalar cross-entropy
"""
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
# Get first image key from batch
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
if not image_keys:
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
images = batch[image_keys[0]]
token_ids = batch[OBS_LANGUAGE_TOKENS]
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
mc_return = batch["mc_return"]
is_terminal = batch["is_terminal"]
# Embed observations
vision_features = self.embed_image(images)
text_embeddings = self.embed_text(token_ids)
# Forward through value function transformer
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
value_logits = vf_output["logits"]
predicted_value = vf_output["value"]
# Compute target distribution
target_distribution = self.compute_target_distribution(
mc_return,
is_terminal,
method=self.config.target_method,
use_one_hot_terminal=self.config.use_one_hot_terminal,
)
# Cross-entropy loss (Eq. 1 in pi*0.6 paper)
log_probs = F.log_softmax(value_logits, dim=-1)
loss = -(target_distribution * log_probs).sum(dim=-1).mean()
output_dict = {
"loss": loss.item(),
"predicted_value_mean": predicted_value.mean().item(),
"mc_return_mean": mc_return.mean().item(),
}
return loss, output_dict
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
"""Compute V(s) for a batch of observations. Used for advantage scoring.
Args:
batch: preprocessed batch with images and tokenized text
Returns:
[batch_size] tensor of predicted values V(s)
"""
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
if not image_keys:
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
images = batch[image_keys[0]]
token_ids = batch[OBS_LANGUAGE_TOKENS]
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
vision_features = self.embed_image(images)
text_embeddings = self.embed_text(token_ids)
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
return vf_output["value"].squeeze(-1) # [batch_size]
@@ -0,0 +1,235 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
Training targets (mc_return, is_terminal) are NOT routed through the processor.
They are dataset columns read directly from the batch in the model's forward().
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
)
from lerobot.processor.converters import to_tensor
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_distributional_value_function import DistributionalVFConfig
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
@dataclass
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
"""Format the task string for the distributional value function.
The value function receives only visual observations and task text.
Builds prompt: "Task: {task}."
"""
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
tasks = complementary_data.get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
if isinstance(tasks, str):
tasks = [tasks]
full_prompts = []
for task in tasks:
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
full_prompts.append(f"Task: {cleaned_text}.")
new_complementary_data = dict(complementary_data)
new_complementary_data[self.task_key] = full_prompts
transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {"task_key": self.task_key}
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
@dataclass
class DistributionalVFImagePreprocessorStep(ProcessorStep):
"""Resize and normalize images for the value function's SigLIP vision tower.
Expects float images in [0, 1].
- Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
- Scale to [-1, 1] for SigLIP
"""
image_resolution: tuple[int, int] = (224, 224)
image_keys: tuple[str, ...] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch
observation = transition.get(TransitionKey.OBSERVATION)
if not isinstance(observation, dict):
raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")
image_keys = self.image_keys or tuple(
key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
)
if not image_keys:
raise KeyError(
f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
)
new_observation = dict(observation)
for image_key in image_keys:
image = new_observation[image_key]
if not isinstance(image, Tensor):
image = to_tensor(image)
if image.dtype != torch.float32:
image = image.to(torch.float32)
is_channels_first = image.ndim == 4 and image.shape[1] == 3
if is_channels_first:
image = image.permute(0, 2, 3, 1)
if image.shape[1:3] != self.image_resolution:
image = resize_with_pad_torch(image, *self.image_resolution)
image = image * 2.0 - 1.0
if is_channels_first:
image = image.permute(0, 3, 1, 2)
new_observation[image_key] = image
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = new_observation
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {
"image_resolution": self.image_resolution,
"image_keys": list(self.image_keys) if self.image_keys is not None else None,
}
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
return tuple(
feature_name
for feature_name, feature in config.input_features.items()
if feature.type == FeatureType.VISUAL
)
def make_distributional_vf_pre_post_processors(
config: DistributionalVFConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create pre/post processors for the distributional value function.
Preprocessor steps:
1. Rename observations (no-op by default)
2. Add a batch dimension
3. Normalize features (images use identity, so they stay in [0, 1])
4. Format task prompt: "Task: {task}."
5. Tokenize with the PaliGemma tokenizer
6. Resize-with-pad and scale images to [-1, 1] for SigLIP
7. Move tensors to the configured device
Training targets (mc_return, is_terminal) are not processed here.
The model reads them directly from the batch in forward().
The postprocessor is a no-op because the value function does not need
action postprocessing.
"""
image_keys = _visual_image_keys(config)
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DistributionalVFPrepareTaskPromptStep(),
TokenizerProcessorStep(
tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DistributionalVFImagePreprocessorStep(
image_resolution=config.image_resolution,
image_keys=image_keys or None,
),
DeviceProcessorStep(device=config.device or "cpu"),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
to_transition=batch_to_transition,
to_output=transition_to_batch,
)
postprocessor = PolicyProcessorPipeline(
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
)
return preprocessor, postprocessor
+19
View File
@@ -24,6 +24,7 @@ from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
@@ -63,6 +64,12 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
return TOPRewardModel
elif name == "distributional_value_function":
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
return DistributionalVFRewardModel
else:
try:
return _get_reward_model_cls_from_name(name=name)
@@ -96,6 +103,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RobometerConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
elif reward_type == "distributional_value_function":
return DistributionalVFConfig(**kwargs)
else:
try:
config_cls = RewardModelConfig.get_choice_class(reward_type)
@@ -191,6 +200,16 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, DistributionalVFConfig):
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
make_distributional_vf_pre_post_processors,
)
return make_distributional_vf_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_reward_model_config(
+5 -15
View File
@@ -232,18 +232,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
# Now all other processes can safely load the dataset
if not is_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -389,19 +386,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
generator=sampler_generator,
)
else:
shuffle = True
-24
View File
@@ -114,30 +114,6 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
@@ -0,0 +1,518 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RECAP's distributional value function."""
from __future__ import annotations
import pytest
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
DistributionalVFConfig,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_IMAGES
from tests.utils import skip_if_package_missing
BATCH_SIZE = 4
NUM_BINS = 201
IMAGE_KEY = f"{OBS_IMAGES}.top"
def _make_config(**overrides) -> DistributionalVFConfig:
defaults = {
"init_from_actor_path": "",
"device": "cpu",
"image_resolution": (224, 224),
}
defaults.update(overrides)
config = DistributionalVFConfig(**defaults)
config.input_features = {
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {}
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
}
return config
def _make_model():
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
return DistributionalVFRewardModel(_make_config())
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
return {
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
"mc_return": torch.rand(batch_size, device=device) * -1.0,
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
}
def test_config_registered_in_reward_model_registry():
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
known = RewardModelConfig.get_known_choices()
assert "distributional_value_function" in known
def test_factory_returns_correct_class():
"""get_reward_model_class returns DistributionalVFRewardModel."""
from lerobot.rewards.factory import get_reward_model_class
cls = get_reward_model_class("distributional_value_function")
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
assert cls is DistributionalVFRewardModel
def test_make_reward_model_config_factory():
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
from lerobot.rewards.factory import make_reward_model_config
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
assert isinstance(config, DistributionalVFConfig)
assert config.num_value_bins == 101
@skip_if_package_missing("transformers")
def test_hl_gauss_sums_to_one():
"""HL-Gauss target distribution sums to 1 for each sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
dist = model.hl_gauss_target(targets)
assert dist.shape == (4, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_hl_gauss_non_negative():
"""HL-Gauss target probabilities are all non-negative."""
model = _make_model()
targets = torch.linspace(-1.0, 0.0, 10)
dist = model.hl_gauss_target(targets)
assert (dist >= 0).all()
@skip_if_package_missing("transformers")
def test_hl_gauss_expected_value_matches():
"""E[V] under HL-Gauss distribution matches the target value."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9])
dist = model.hl_gauss_target(targets)
expected = (dist * model.bin_centers).sum(dim=-1)
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
@skip_if_package_missing("transformers")
def test_hl_gauss_handles_2d_input():
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
dist = model.hl_gauss_target(targets)
assert dist.shape == (2, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_sums_to_one():
"""Dirac delta target distribution sums to 1 for each sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
dist = model.dirac_delta_target(targets)
assert dist.shape == (5, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_at_most_two_nonzero():
"""Dirac delta places probability on at most two adjacent bins."""
model = _make_model()
targets = torch.tensor([-0.7523, -0.0013])
dist = model.dirac_delta_target(targets)
for i in range(2):
assert (dist[i] > 0).sum() <= 2
@skip_if_package_missing("transformers")
def test_dirac_delta_expected_value_matches():
"""E[V] under Dirac delta distribution matches the target value."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9])
dist = model.dirac_delta_target(targets)
expected = (dist * model.bin_centers).sum(dim=-1)
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_boundary_values_clamped():
"""Values outside support are clamped to boundary bins."""
model = _make_model()
targets = torch.tensor([-1.5, 0.5])
dist = model.dirac_delta_target(targets)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
assert dist[0, 0] == 1.0
assert dist[1, -1] == 1.0
@skip_if_package_missing("transformers")
def test_one_hot_single_nonzero():
"""One-hot target has exactly one non-zero bin per sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
dist = model.one_hot_target(targets)
assert dist.shape == (4, NUM_BINS)
for i in range(4):
assert (dist[i] > 0).sum() == 1
assert dist[i].sum() == 1.0
@skip_if_package_missing("transformers")
def test_one_hot_nearest_bin():
"""One-hot target activates the bin closest to the target value."""
model = _make_model()
targets = torch.tensor([-0.5])
dist = model.one_hot_target(targets)
hot_idx = dist[0].argmax()
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
@skip_if_package_missing("transformers")
def test_terminal_gets_one_hot():
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
is_terminal = torch.tensor([False, True, False, True])
dist = model.compute_target_distribution(
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
)
for i in range(4):
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
assert (dist[1] > 0).sum() == 1
assert (dist[3] > 0).sum() == 1
assert (dist[0] > 0).sum() > 2
assert (dist[2] > 0).sum() > 2
@skip_if_package_missing("transformers")
def test_no_terminal_override_when_disabled():
"""When use_one_hot_terminal=False, terminal states use the base method."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3])
is_terminal = torch.tensor([False, True])
dist = model.compute_target_distribution(
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
)
assert (dist[1] > 0).sum() > 2
@skip_if_package_missing("transformers")
def test_model_has_expected_components():
"""Model scaffold contains all architectural components."""
model = _make_model()
assert hasattr(model, "vision_tower")
assert hasattr(model, "multi_modal_projector")
assert hasattr(model, "token_embedding")
assert hasattr(model, "layers")
assert hasattr(model, "value_head")
assert hasattr(model, "cls_embedding")
assert hasattr(model, "norm")
assert hasattr(model, "rotary_emb")
assert hasattr(model, "bin_centers")
@skip_if_package_missing("transformers")
def test_model_bin_centers_shape():
"""Bin centers buffer has shape (num_value_bins,)."""
model = _make_model()
assert model.bin_centers.shape == (NUM_BINS,)
@skip_if_package_missing("transformers")
def test_model_layer_count():
"""Transformer has num_hidden_layers (6) layers."""
model = _make_model()
assert len(model.layers) == 6
@skip_if_package_missing("transformers")
def test_model_value_head_output_dim():
"""Value head outputs num_value_bins logits."""
model = _make_model()
assert model.value_head.out_features == NUM_BINS
@skip_if_package_missing("transformers")
def test_forward_returns_loss_and_dict():
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
model = _make_model()
batch = _make_batch()
loss, output_dict = model.forward(batch)
assert loss.shape == ()
assert torch.isfinite(loss)
assert "loss" in output_dict
assert "predicted_value_mean" in output_dict
assert "mc_return_mean" in output_dict
@skip_if_package_missing("transformers")
def test_forward_loss_is_positive():
"""Cross-entropy loss is strictly positive for random weights."""
model = _make_model()
batch = _make_batch()
loss, _ = model.forward(batch)
assert loss.item() > 0
@skip_if_package_missing("transformers")
def test_compute_reward_returns_correct_shape():
"""compute_reward returns [batch_size] tensor of finite float32 values."""
model = _make_model()
model.eval()
batch = _make_batch(batch_size=3)
with torch.no_grad():
values = model.compute_reward(batch)
assert values.shape == (3,)
assert values.dtype == torch.float32
assert torch.isfinite(values).all()
@skip_if_package_missing("transformers")
def test_compute_reward_values_in_support_range():
"""Predicted values lie within [value_support_min, value_support_max]."""
model = _make_model()
model.eval()
batch = _make_batch(batch_size=8)
with torch.no_grad():
values = model.compute_reward(batch)
assert (values >= -1.0 - 0.01).all()
assert (values <= 0.0 + 0.01).all()
@skip_if_package_missing("transformers")
def test_processor_pipeline_produces_expected_keys():
"""Full preprocessor pipeline produces tokenized text and processed images."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
make_distributional_vf_pre_post_processors,
)
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
config = _make_config()
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
raw_batch = {
IMAGE_KEY: torch.rand(3, 224, 224),
"task": "pick up the cup",
}
processed = preprocessor(raw_batch)
assert OBS_LANGUAGE_TOKENS in processed
assert OBS_LANGUAGE_ATTENTION_MASK in processed
assert IMAGE_KEY in processed
@skip_if_package_missing("transformers")
def test_gradient_flows_through_value_head():
"""Backprop produces non-zero gradients on the value head."""
model = _make_model()
model.train()
batch = _make_batch()
loss, _ = model.forward(batch)
loss.backward()
assert model.value_head.weight.grad is not None
assert not torch.all(model.value_head.weight.grad == 0)
@skip_if_package_missing("transformers")
def test_gradient_flows_through_cls_embedding():
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
model = _make_model()
model.train()
batch = _make_batch()
loss, _ = model.forward(batch)
loss.backward()
assert model.cls_embedding.grad is not None
assert not torch.all(model.cls_embedding.grad == 0)
def test_config_requires_visual_feature():
"""validate_features raises if no VISUAL feature is present."""
config = DistributionalVFConfig(init_from_actor_path="")
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
}
with pytest.raises(ValueError, match="VISUAL"):
config.validate_features()
def test_config_passes_with_visual_feature():
"""validate_features succeeds when a VISUAL feature is present."""
config = _make_config()
config.validate_features()
@skip_if_package_missing("transformers")
def test_save_load_pretrained_roundtrip(tmp_path):
"""Saved model can be loaded back with identical weights."""
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
model = _make_model()
model._save_pretrained(tmp_path)
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
orig_sd = model.state_dict()
loaded_sd = loaded.state_dict()
assert set(orig_sd.keys()) == set(loaded_sd.keys())
for key in orig_sd:
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
@skip_if_package_missing("transformers")
def test_image_preprocessor_normalizes_to_minus_one_one():
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFImagePreprocessorStep,
)
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
transition = {
TransitionKey.OBSERVATION: {
IMAGE_KEY: torch.rand(1, 224, 224, 3),
},
}
result = step(transition)
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
assert image.min() >= -1.0 - 1e-5
assert image.max() <= 1.0 + 1e-5
@skip_if_package_missing("transformers")
def test_image_preprocessor_resizes_with_pad():
"""Image preprocessor resizes non-square images to target resolution."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFImagePreprocessorStep,
)
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
transition = {
TransitionKey.OBSERVATION: {
IMAGE_KEY: torch.rand(1, 480, 640, 3),
},
}
result = step(transition)
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
assert image.shape[1:3] == (224, 224)
def test_task_prompt_formats_correctly():
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
}
result = step(transition)
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
assert prompt == "Task: pick up the cup."
def test_task_prompt_handles_string_input():
"""Task prompt step accepts a plain string (not just a list)."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
}
result = step(transition)
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
assert prompt == "Task: open drawer."
def test_task_prompt_raises_on_missing_task():
"""Task prompt step raises ValueError when task key is absent."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {},
}
with pytest.raises(ValueError, match="No task found"):
step(transition)
Generated
+353 -351
View File
File diff suppressed because it is too large Load Diff