Compare commits

..

2 Commits

Author SHA1 Message Date
Khalil Meftah fa3eb9fce3 test(rewards): add unit tests for distributional value function model 2026-06-10 16:07:43 +02:00
Khalil Meftah 500c91ba92 feat(rewards): introduce distributional value function model
- Added a new distributional value function (DistributionalVF) model for RECAP, including its configuration, modeling, and processor components.
- Updated the rewards factory to support the new model type.
- Updated  to include the new model in the dependencies.
2026-06-10 15:24:50 +02:00
13 changed files with 2169 additions and 562 deletions
+2
View File
@@ -214,6 +214,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
recap = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
@@ -296,6 +297,7 @@ all = [
"lerobot[sarm]",
"lerobot[robometer]",
"lerobot[topreward]",
"lerobot[recap]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
]
+303 -157
View File
@@ -17,10 +17,12 @@ from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -53,13 +55,12 @@ class VLAJEPAModel(nn.Module):
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Inputs are batched tensors kept on the model device
- images: List[List[Tensor [C, H, W]]] (float [0,1]) — per sample, per view (Qwen messages)
- instructions: List[str]
- videos: Tensor [B, V, T, C, H, W] (float [0,1], world model only)
- actions: Tensor [B, T, action_dim] (optional, training only)
- state: Tensor [B, 1, state_dim] (optional)
- action_is_pad: Tensor [B, T] (optional)
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
@@ -160,123 +161,166 @@ class VLAJEPAModel(nn.Module):
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def _encode_qwen(
self, images: list[list[Tensor]], instructions: list[str], *, need_action_tokens: bool
) -> tuple[Tensor, Tensor, Tensor | None]:
"""Run Qwen and gather the embodied-action (and optionally action) token hidden states."""
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] — multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str — task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=images,
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
input_ids = qwen_inputs["input_ids"]
embodied_idx = (input_ids == self.embodied_action_token_id).nonzero(as_tuple=True)
action_idx = None
if need_action_tokens:
action_mask = torch.isin(input_ids, torch.tensor(self.action_token_ids, device=input_ids.device))
action_idx = action_mask.nonzero(as_tuple=True)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
action_tokens = (
last_hidden[action_idx[0], action_idx[1], :].view(b, -1, h)
if action_idx is not None
else None
)
return last_hidden, embodied_action_tokens, action_tokens
def _world_model_loss(self, videos: Tensor, action_tokens: Tensor) -> Tensor:
"""JEPA encode + predictor L1 loss. `videos` is [B, V, T, C, H, W] float in [0, 1]."""
# Match the world model's expected view count: pad with the first view, or trim extras.
num_views = self.config.jepa_tubelet_size
if videos.shape[1] < num_views:
missing = num_views - videos.shape[1]
videos = torch.cat([videos, videos[:, :1].repeat(1, missing, 1, 1, 1, 1)], dim=1)
elif videos.shape[1] > num_views:
videos = videos[:, :num_views]
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
b, v, t_frames, c, h_img, w_img = videos.shape
flat = videos.reshape(b * v, t_frames, c, h_img, w_img)
# Fast (torchvision) video processor on-device, do_rescale=False (frames already in [0, 1]).
video_pixels = self.video_processor(
videos=list(flat),
return_tensors="pt",
device=self.video_encoder.device,
do_rescale=False,
)["pixel_values_videos"] # [B*V, T, C, H, W]
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
return torch.zeros((), device=video_embeddings.device)
# Shift-by-one JEPA split: input_states = positions 0..T-2, gt_states = positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(), action_tokens[:, :expected_actions].float()
)
return F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
def _action_loss(
self,
embodied_action_tokens: Tensor,
actions: Tensor,
state: Tensor | None,
action_is_pad: Tensor | None,
) -> Tensor:
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.float32):
r = self.config.repeated_diffusion_steps
horizon = self.config.chunk_size
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
embodied = embodied_action_tokens.repeat(r, 1, 1)
state_rep = state.to(embodied_action_tokens.dtype).repeat(r, 1, 1) if state is not None else None
pad_rep = action_is_pad[:, -horizon:].repeat(r, 1) if action_is_pad is not None else None
return self.action_model(embodied, actions_target, state_rep, pad_rep)
def forward(
self,
images: list[list[Tensor]],
instructions: list[str],
videos: Tensor | None = None,
actions: Tensor | None = None,
state: Tensor | None = None,
action_is_pad: Tensor | None = None,
) -> dict[str, Tensor]:
"""Native forward: Qwen encode → optional world-model loss → optional action-head loss."""
last_hidden, embodied_action_tokens, action_tokens = self._encode_qwen(
images, instructions, need_action_tokens=self.config.enable_world_model
)
if self.config.enable_world_model:
wm_loss = self._world_model_loss(videos, action_tokens)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
wm_loss = torch.zeros((), device=last_hidden.device)
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
if actions is None:
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
action_loss = self._action_loss(embodied_action_tokens, actions, state, action_is_pad)
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@@ -284,24 +328,58 @@ class VLAJEPAModel(nn.Module):
@torch.no_grad()
def predict_action(
self,
images: list[list[Tensor]],
batch_images: list[list[Image.Image]],
instructions: list[str],
state: Tensor | None = None,
) -> Tensor:
"""Predict an action chunk. `images` is per-sample, per-view float [0,1] [C, H, W] tensors."""
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] — predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
images = [
[F.interpolate(img[None], size=(height, width), mode="area")[0] for img in views]
for views in images
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
_, embodied_action_tokens, _ = self._encode_qwen(images, instructions, need_action_tokens=False)
state = state.to(embodied_action_tokens.dtype) if state is not None else None
return self.action_model.predict_action(
embodied_action_tokens.float(), state.float() if state is not None else None
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
@@ -312,9 +390,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the batched tensors
the native model expects (keeping everything on-device), calls the native model, and
converts outputs back to LeRobot format.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
@@ -341,8 +419,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Any]:
"""Convert a LeRobot batch to the model's batched, on-device inputs.
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
@@ -352,25 +431,65 @@ class VLAJEPAPolicy(PreTrainedPolicy):
"task": str | List[str], (optional instruction)
}
Returns the kwargs for `VLAJEPAModel.forward` / `.predict_action` (everything stays
on the batch device; no per-sample shredding): `images` (per-sample, per-view list for
Qwen messages), `instructions`, and the batched `videos` / `actions` / `state` /
`action_is_pad` when present.
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
batch_size = batch[image_keys[0]].shape[0]
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# Current-frame image per view ([B, C, H, W]); regroup per sample for Qwen messages.
frames = []
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
t = batch[key]
if t.ndim == 5: # [B, T, C, H, W] -> current observation (delta=0)
t = t[:, 0]
frames.append(self.model.qwen.to_pixel_values(t))
images = [[frame[b] for frame in frames] for b in range(batch_size)]
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
@@ -379,32 +498,52 @@ class VLAJEPAPolicy(PreTrainedPolicy):
else:
instructions = list(tasks)
inputs: dict[str, Any] = {"images": images, "instructions": instructions}
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# Videos [B, V, T, C, H, W] - only assembled when the world model consumes them.
if self.model.config.enable_world_model:
views = [batch[k].unsqueeze(1) if batch[k].ndim == 4 else batch[k] for k in image_keys]
inputs["videos"] = self.model.qwen.to_pixel_values(torch.stack(views, dim=1))
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
actions = batch.get(ACTION)
if actions is not None:
inputs["actions"] = (actions.unsqueeze(1) if actions.ndim == 2 else actions).float()
if (pad := batch.get("action_is_pad")) is not None:
inputs["action_is_pad"] = pad
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
state = batch.get(OBS_STATE)
if state is not None:
if state.ndim > 2:
state = state[:, -1, :]
inputs["state"] = (state.unsqueeze(1) if state.ndim == 2 else state).float() # [B, 1, dim]
return inputs
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
native_output = self.model.forward(**self._prepare_model_inputs(batch))
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
@@ -422,9 +561,16 @@ class VLAJEPAPolicy(PreTrainedPolicy):
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
inputs = self._prepare_model_inputs(batch)
actions = self.model.predict_action(inputs["images"], inputs["instructions"], inputs.get("state"))
return actions.to(device=self.config.device, dtype=torch.float32)
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+15 -31
View File
@@ -17,7 +17,9 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
@@ -76,7 +78,7 @@ class Qwen3VLInterface(torch.nn.Module):
def build_inputs(
self,
images: Sequence[Sequence[torch.Tensor]],
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
@@ -92,42 +94,24 @@ class Qwen3VLInterface(torch.nn.Module):
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
# The Qwen image processor is a torchvision-backed fast processor: passing the
# images as GPU tensors (with `device`) keeps the whole vision pipeline on-device
# and avoids a GPU->CPU->GPU roundtrip. The image tensors are forwarded through
# apply_chat_template untouched into Qwen3VLProcessor.__call__.
# do_rescale=False: images already arrive as float in [0, 1] (the dataset decoder
# yields float32/255 and VISUAL normalization is IDENTITY), so we skip the
# processor's /255 rescale instead of round-tripping through uint8.
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={
"padding": True,
"return_tensors": "pt",
"device": self.model.device,
"do_rescale": False,
},
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def to_pixel_values(image_tensor: torch.Tensor) -> torch.Tensor:
"""Prepare an image/video tensor for the fast processors (used with do_rescale=False).
The dataset decoder yields float32 in [0, 1] (channels-first) and VISUAL
normalization is IDENTITY, so the tensor already arrives in [0, 1]; we pass it
through as float and let the processors normalize (no rescale, no uint8
quantization). A single channel is expanded to 3 to match the RGB processors.
Works for any channels-first layout (channel dim is -3): [C, H, W], [B, C, H, W],
[T, C, H, W], [B, V, T, C, H, W], ...
"""
image = image_tensor.detach().float()
if image.shape[-3] == 1:
repeats = [1] * image.ndim
repeats[-3] = 3
image = image.repeat(*repeats)
return image
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)
+4
View File
@@ -13,6 +13,9 @@
# limitations under the License.
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import (
DistributionalVFConfig as DistributionalVFConfig,
)
from .factory import (
get_reward_model_class as get_reward_model_class,
make_reward_model as make_reward_model,
@@ -26,6 +29,7 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
__all__ = [
# Configuration classes
"DistributionalVFConfig",
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
@@ -0,0 +1,23 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_distributional_value_function import DistributionalVFConfig
from .modeling_distributional_value_function import DistributionalVFRewardModel
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
__all__ = [
"DistributionalVFConfig",
"DistributionalVFRewardModel",
"make_distributional_vf_pre_post_processors",
]
@@ -0,0 +1,108 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
with optional one-hot targets for terminal states; MC returns normalized per task.
Weights initialized from a pre-trained PI05 actor checkpoint.
"""
from dataclasses import dataclass, field
from lerobot.configs import FeatureType, NormalizationMode
from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
@RewardModelConfig.register_subclass("distributional_value_function")
@dataclass
class DistributionalVFConfig(RewardModelConfig):
"""Configuration for RECAP's distributional value function.
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
(Eq. 1 in the paper).
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
about 670M params and initialize from the PI05 actor checkpoint.
"""
# Backbone
paligemma_variant: str = "gemma_2b"
num_hidden_layers: int = 6
num_vision_layers: int = 13
# Distributional head
num_value_bins: int = 201
value_support_min: float = -1.0
value_support_max: float = 0.0
hl_gauss_sigma_ratio: float = 5.0
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
target_method: str = "hl_gauss"
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
# When False, terminal states use the same target method as non-terminal states.
use_one_hot_terminal: bool = True
# Image
image_resolution: tuple[int, int] = (224, 224)
# Tokenizer
tokenizer_max_length: int = 64
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
# Pass a PI05 checkpoint path or Hub repo_id here.
# After training, load the value function with RewardModel.from_pretrained() instead.
init_from_actor_path: str = ""
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
}
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=3e-4,
weight_decay=1e-4,
grad_clip_norm=1.0,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=500,
num_decay_steps=50000,
)
def validate_features(self) -> None:
if not self.input_features:
return
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
if not has_image:
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
@@ -0,0 +1,567 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modeling for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
Inputs: single image observation + task text prompt ("Task: {task}.")
Outputs: softmax distribution over value bins; expected value E[V] for inference.
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
with optional one-hot targets for terminal states; MC returns normalized per task.
Weight initialization: vision tower, multi-modal projector, token embeddings, and
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
"""
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.import_utils import _transformers_available, require_package
from .configuration_distributional_value_function import DistributionalVFConfig
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaRMSNorm,
_gated_residual,
_get_pi_gemma_decoder_layer_base,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
PaliGemmaForConditionalGenerationWithPiGemma = None
PiGemmaRMSNorm = None
_gated_residual = None
_get_pi_gemma_decoder_layer_base = None
PALIGEMMA_VOCAB_SIZE = 257152
class DistributionalVFRewardModel(PreTrainedRewardModel):
"""Distributional value function model for RECAP.
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
per-task normalized Monte Carlo returns.
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
causal attention, [CLS] token, and Linear(D, num_bins) value head.
The expected value is E[V] = sum(softmax(logits) * bin_centers).
"""
name = "distributional_value_function"
config_class = DistributionalVFConfig
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
require_package("transformers", extra="recap")
super().__init__(config)
self.config = config
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
# Get base dimensions from the paligemma variant (OpenPI config format)
base_config = get_gemma_config(config.paligemma_variant)
hidden_dim = base_config.width
mlp_dim = base_config.mlp_dim
num_layers = config.num_hidden_layers
# HuggingFace GemmaConfig for transformer layers
gemma_config = CONFIG_MAPPING["gemma"](
head_dim=base_config.head_dim,
hidden_size=hidden_dim,
intermediate_size=mlp_dim,
num_attention_heads=base_config.num_heads,
num_hidden_layers=num_layers,
num_key_value_heads=base_config.num_kv_heads,
vocab_size=PALIGEMMA_VOCAB_SIZE,
hidden_activation="gelu_pytorch_tanh",
)
self.gemma_config = gemma_config
self.hidden_dim = hidden_dim
self.num_value_bins = config.num_value_bins
# Single learned [CLS] token for value prediction
self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
# Value projection head: Linear(hidden_dim, num_bins)
self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
# Transformer layers (overwritten by _initialize_from_actor on first run)
self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
self.layers = nn.ModuleList(
[pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
)
self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
# Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
# PaliGemmaConfig wraps both vision and text configs into a single model
paligemma_config = CONFIG_MAPPING["paligemma"]()
paligemma_config.text_config = gemma_config
paligemma_config.vision_config.image_size = config.image_resolution[0]
paligemma_config.vision_config.intermediate_size = 4304
paligemma_config.vision_config.projection_dim = 2048
paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
self.vision_tower = paligemma_full.model.vision_tower
self.multi_modal_projector = paligemma_full.model.multi_modal_projector
self.token_embedding = paligemma_full.model.language_model.embed_tokens
del paligemma_full
# Truncate vision tower to num_vision_layers
if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
vision_encoder = self.vision_tower.vision_model.encoder
vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
# Bin support: evenly spaced centers from value_support_min to value_support_max
bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
self.register_buffer("bin_centers", bin_centers, persistent=False)
bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
# Overwrite with pre-trained PI05 actor weights (first training run only)
if config.init_from_actor_path:
self._initialize_from_actor()
def _initialize_from_actor(self) -> None:
"""Overwrite weights from a pre-trained PI05 actor checkpoint.
Called on first training run only (when init_from_actor_path is set).
"""
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
actor_model = actor_policy.model
paligemma_model = actor_model.paligemma_with_expert.paligemma
source_language_model = paligemma_model.model.language_model
# Transformer components
self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
num_layers = self.gemma_config.num_hidden_layers
for i in range(num_layers):
self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
self.norm.load_state_dict(source_language_model.norm.state_dict())
# Vision tower (truncate source first, then copy)
source_vision_tower = paligemma_model.model.vision_tower
if hasattr(source_vision_tower, "vision_model") and hasattr(
source_vision_tower.vision_model, "encoder"
):
source_encoder = source_vision_tower.vision_model.encoder
source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
self.vision_tower.load_state_dict(source_vision_tower.state_dict())
# Multi-modal projector
self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())
# Token embedding table
self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())
del actor_policy
def embed_image(self, image: Tensor) -> Tensor:
"""Embed images using the value function's SigLIP vision tower.
Args:
image: [batch_size, channels, height, width] preprocessed images in [-1, 1].
Returns:
[batch_size, num_patches, hidden_dim] projected image features.
"""
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.vision_tower(image, return_dict=True)
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
image_features = image_features / (self.hidden_dim**0.5)
if image_features.dtype != out_dtype:
image_features = image_features.to(out_dtype)
return image_features
def embed_text(self, token_ids: Tensor) -> Tensor:
"""Embed text token IDs using the value function's token embedding table.
Args:
token_ids: [batch_size, seq_len] integer token IDs
Returns:
[batch_size, seq_len, hidden_dim] text embeddings
"""
return self.token_embedding(token_ids)
def _get_cls_embedding(self, batch_size: int) -> Tensor:
"""Get [CLS] token embedding expanded to batch size.
Args:
batch_size: number of samples in the batch.
Returns:
[batch_size, 1, hidden_dim] learned [CLS] embedding.
"""
return self.cls_embedding.expand(batch_size, -1, -1)
def forward_value(
self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
) -> dict[str, Tensor]:
"""Core forward pass through the distributional value function.
Args:
vision_features: [batch_size, num_patches, hidden_dim]
text_embeddings: [batch_size, seq_len, hidden_dim]
text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
Returns:
logits: [batch_size, num_value_bins]
probs: [batch_size, num_value_bins]
value: [batch_size, 1]
"""
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE
batch_size = text_embeddings.shape[0]
device = text_embeddings.device
# Build sequence: [vision, text, CLS]
cls_embedding = self._get_cls_embedding(batch_size)
hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)
# Build causal attention mask
vision_len = vision_features.shape[1]
vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
full_seq_len = full_padding_mask.shape[1]
# Causal mask
causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
# Combine causal mask with padding mask
padding_mask_4d = full_padding_mask[:, None, None, :].expand(
batch_size, 1, full_seq_len, full_seq_len
)
attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)
position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
cos, sin = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
norm_output = layer.input_layernorm(hidden_states, cond=None)
if isinstance(norm_output, tuple):
hidden_states_normed, gate = norm_output
else:
hidden_states_normed, gate = norm_output, None
input_shape = hidden_states_normed.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
attention_output, _ = modeling_gemma.eager_attention_forward(
layer.self_attn,
query_states,
key_states,
value_states,
attention_mask,
layer.self_attn.scaling,
)
attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
projected_attention = layer.self_attn.o_proj(attention_output)
if gate is not None:
projected_attention = _gated_residual(hidden_states, projected_attention, gate)
else:
projected_attention = hidden_states + projected_attention
after_attention_residual = projected_attention.clone()
norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
if isinstance(norm_output, tuple):
mlp_input, gate = norm_output
else:
mlp_input, gate = norm_output, None
mlp_output = layer.mlp(mlp_input)
if gate is not None:
hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
else:
hidden_states = after_attention_residual + mlp_output
hidden_states = self.norm(hidden_states)
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
# Extract [CLS] token (last position in the sequence)
cls_hidden_state = hidden_states[:, -1, :] # [batch_size, hidden_dim]
# Value head: Linear(hidden_dim, num_bins) -> logits
value_logits = self.value_head(cls_hidden_state) # [batch_size, num_value_bins]
value_probs = F.softmax(value_logits, dim=-1)
predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
dim=-1, keepdim=True
)
return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
"""HL-Gauss soft target distribution.
Places a Gaussian N(target, sigma^2) over the bin support and computes
per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.
Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
Functions via Classification for Scalable Deep RL", Section 3.1.
arXiv:2403.03950
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] target probability distribution.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.to(dtype=self.bin_centers.dtype)
# Bin edges: half a bin-width outside the first/last center
bin_width = (self.config.value_support_max - self.config.value_support_min) / (
self.num_value_bins - 1
)
support_edges = torch.linspace(
self.config.value_support_min - bin_width / 2,
self.config.value_support_max + bin_width / 2,
self.num_value_bins + 1,
device=target_value.device,
dtype=target_value.dtype,
)
# CDF of N(target, sigma^2) evaluated at each edge
cdf_at_edges = 0.5 * (
1.0
+ torch.erf(
(support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
/ (self.hl_gauss_sigma * math.sqrt(2))
)
) # [batch_size, num_bins + 1]
# Normalize: z = cdf(max_edge) - cdf(min_edge)
normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)
# Bin probabilities = differences of consecutive CDF values, normalized
bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant
return bin_probabilities
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
"""Dirac delta (C51) projection: split probability between two nearest bins.
Standard distributional RL projection from Bellemare et al. 2017.
"A Distributional Perspective on Reinforcement Learning"
arXiv:1707.06887
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] target probability distribution.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
target_value = target_value.to(dtype=self.bin_centers.dtype)
bin_width = self.bin_centers[1] - self.bin_centers[0]
normalized_position = (target_value - self.config.value_support_min) / bin_width
lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)
weight_upper = normalized_position - lower_bin_idx.float()
weight_lower = upper_bin_idx.float() - normalized_position
same_bin = lower_bin_idx == upper_bin_idx
weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)
batch_size = target_value.shape[0]
target_distribution = torch.zeros(batch_size, self.num_value_bins, device=target_value.device)
batch_indices = torch.arange(batch_size, device=target_value.device)
target_distribution[batch_indices, lower_bin_idx] += weight_lower
target_distribution[batch_indices, upper_bin_idx] += weight_upper
return target_distribution
def one_hot_target(self, target_value: Tensor) -> Tensor:
"""One-hot target for terminal states (exact return, no smoothing).
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] one-hot distribution at the nearest bin.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.to(dtype=self.bin_centers.dtype)
nearest_bin_idx = torch.argmin(
torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
)
return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)
def compute_target_distribution(
self,
target_value: Tensor,
is_terminal: Tensor,
method: str = "hl_gauss",
use_one_hot_terminal: bool = True,
) -> Tensor:
"""Compute target distribution using configured method.
Args:
target_value: [batch_size] scalar return targets
is_terminal: [batch_size] boolean terminal flags
method: "hl_gauss" or "dirac_delta"
use_one_hot_terminal: if True, terminal states get one-hot targets
(exact return, no smoothing). If False, all states use the same method.
Returns:
[batch_size, num_value_bins] target probability distribution
"""
if method == "hl_gauss":
base_distribution = self.hl_gauss_target(target_value)
elif method == "dirac_delta":
base_distribution = self.dirac_delta_target(target_value)
else:
raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")
if not use_one_hot_terminal:
return base_distribution
terminal_distribution = self.one_hot_target(target_value)
return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
"""Training forward pass — computes cross-entropy loss against MC return targets.
The batch is expected to be preprocessed by the processor pipeline.
Keys expected in batch:
- observation.images.*: [B, C, H, W] preprocessed images
- observation.language_tokens: [B, seq_len] tokenized task prompt
- observation.language_attention_mask: [B, seq_len] padding mask
- mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
- is_terminal: [B] boolean terminal flags
Returns:
(loss, output_dict) where loss is scalar cross-entropy
"""
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
# Get first image key from batch
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
if not image_keys:
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
images = batch[image_keys[0]]
token_ids = batch[OBS_LANGUAGE_TOKENS]
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
mc_return = batch["mc_return"]
is_terminal = batch["is_terminal"]
# Embed observations
vision_features = self.embed_image(images)
text_embeddings = self.embed_text(token_ids)
# Forward through value function transformer
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
value_logits = vf_output["logits"]
predicted_value = vf_output["value"]
# Compute target distribution
target_distribution = self.compute_target_distribution(
mc_return,
is_terminal,
method=self.config.target_method,
use_one_hot_terminal=self.config.use_one_hot_terminal,
)
# Cross-entropy loss (Eq. 1 in pi*0.6 paper)
log_probs = F.log_softmax(value_logits, dim=-1)
loss = -(target_distribution * log_probs).sum(dim=-1).mean()
output_dict = {
"loss": loss.item(),
"predicted_value_mean": predicted_value.mean().item(),
"mc_return_mean": mc_return.mean().item(),
}
return loss, output_dict
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
"""Compute V(s) for a batch of observations. Used for advantage scoring.
Args:
batch: preprocessed batch with images and tokenized text
Returns:
[batch_size] tensor of predicted values V(s)
"""
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
if not image_keys:
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
images = batch[image_keys[0]]
token_ids = batch[OBS_LANGUAGE_TOKENS]
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
vision_features = self.embed_image(images)
text_embeddings = self.embed_text(token_ids)
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
return vf_output["value"].squeeze(-1) # [batch_size]
@@ -0,0 +1,235 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
Training targets (mc_return, is_terminal) are NOT routed through the processor.
They are dataset columns read directly from the batch in the model's forward().
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
)
from lerobot.processor.converters import to_tensor
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_distributional_value_function import DistributionalVFConfig
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
@dataclass
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
"""Format the task string for the distributional value function.
The value function receives only visual observations and task text.
Builds prompt: "Task: {task}."
"""
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
tasks = complementary_data.get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
if isinstance(tasks, str):
tasks = [tasks]
full_prompts = []
for task in tasks:
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
full_prompts.append(f"Task: {cleaned_text}.")
new_complementary_data = dict(complementary_data)
new_complementary_data[self.task_key] = full_prompts
transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {"task_key": self.task_key}
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
@dataclass
class DistributionalVFImagePreprocessorStep(ProcessorStep):
"""Resize and normalize images for the value function's SigLIP vision tower.
Expects float images in [0, 1].
- Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
- Scale to [-1, 1] for SigLIP
"""
image_resolution: tuple[int, int] = (224, 224)
image_keys: tuple[str, ...] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch
observation = transition.get(TransitionKey.OBSERVATION)
if not isinstance(observation, dict):
raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")
image_keys = self.image_keys or tuple(
key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
)
if not image_keys:
raise KeyError(
f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
)
new_observation = dict(observation)
for image_key in image_keys:
image = new_observation[image_key]
if not isinstance(image, Tensor):
image = to_tensor(image)
if image.dtype != torch.float32:
image = image.to(torch.float32)
is_channels_first = image.ndim == 4 and image.shape[1] == 3
if is_channels_first:
image = image.permute(0, 2, 3, 1)
if image.shape[1:3] != self.image_resolution:
image = resize_with_pad_torch(image, *self.image_resolution)
image = image * 2.0 - 1.0
if is_channels_first:
image = image.permute(0, 3, 1, 2)
new_observation[image_key] = image
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = new_observation
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {
"image_resolution": self.image_resolution,
"image_keys": list(self.image_keys) if self.image_keys is not None else None,
}
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
return tuple(
feature_name
for feature_name, feature in config.input_features.items()
if feature.type == FeatureType.VISUAL
)
def make_distributional_vf_pre_post_processors(
config: DistributionalVFConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create pre/post processors for the distributional value function.
Preprocessor steps:
1. Rename observations (no-op by default)
2. Add a batch dimension
3. Normalize features (images use identity, so they stay in [0, 1])
4. Format task prompt: "Task: {task}."
5. Tokenize with the PaliGemma tokenizer
6. Resize-with-pad and scale images to [-1, 1] for SigLIP
7. Move tensors to the configured device
Training targets (mc_return, is_terminal) are not processed here.
The model reads them directly from the batch in forward().
The postprocessor is a no-op because the value function does not need
action postprocessing.
"""
image_keys = _visual_image_keys(config)
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DistributionalVFPrepareTaskPromptStep(),
TokenizerProcessorStep(
tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DistributionalVFImagePreprocessorStep(
image_resolution=config.image_resolution,
image_keys=image_keys or None,
),
DeviceProcessorStep(device=config.device or "cpu"),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
to_transition=batch_to_transition,
to_output=transition_to_batch,
)
postprocessor = PolicyProcessorPipeline(
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
)
return preprocessor, postprocessor
+19
View File
@@ -24,6 +24,7 @@ from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
@@ -63,6 +64,12 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
return TOPRewardModel
elif name == "distributional_value_function":
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
return DistributionalVFRewardModel
else:
try:
return _get_reward_model_cls_from_name(name=name)
@@ -96,6 +103,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RobometerConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
elif reward_type == "distributional_value_function":
return DistributionalVFConfig(**kwargs)
else:
try:
config_cls = RewardModelConfig.get_choice_class(reward_type)
@@ -191,6 +200,16 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, DistributionalVFConfig):
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
make_distributional_vf_pre_post_processors,
)
return make_distributional_vf_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_reward_model_config(
+9 -11
View File
@@ -8,6 +8,7 @@ from types import SimpleNamespace
import numpy as np
import pytest
import torch
from PIL import Image
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -190,7 +191,7 @@ class _FakeQwenInterface(nn.Module):
def build_inputs(
self,
images: list[list[Tensor]],
images: list[list[Image.Image]],
instructions: list[str],
action_prompt: str,
embodied_prompt: str,
@@ -213,13 +214,12 @@ class _FakeQwenInterface(nn.Module):
}
@staticmethod
def to_pixel_values(image_tensor: Tensor) -> Tensor:
image = image_tensor.detach().float()
if image.shape[-3] == 1:
repeats = [1] * image.ndim
repeats[-3] = 3
image = image.repeat(*repeats)
return image
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
return Image.fromarray(image)
class _FakeVideoEncoder(nn.Module):
@@ -242,14 +242,12 @@ class _FakeVideoEncoder(nn.Module):
class _FakeVideoProcessor:
def __call__(self, videos, return_tensors: str, device=None, **kwargs) -> dict[str, Tensor]:
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
assert return_tensors == "pt"
if isinstance(videos, list):
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
else:
pixel_values = torch.as_tensor(videos).unsqueeze(0)
if device is not None:
pixel_values = pixel_values.to(device)
return {"pixel_values_videos": pixel_values}
+24 -26
View File
@@ -211,42 +211,40 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
inputs = policy._prepare_model_inputs(make_train_batch())
from PIL import Image
assert set(inputs) >= {"images", "instructions", "videos", "actions", "state"}
# images: per-sample, per-view [C, H, W] float tensors (kept as a list for Qwen messages)
assert len(inputs["images"]) == BATCH_SIZE and len(inputs["images"][0]) == 1
img = inputs["images"][0][0]
assert isinstance(img, torch.Tensor) and img.dtype == torch.float32 and img.ndim == 3
assert len(inputs["instructions"]) == BATCH_SIZE
# videos: batched [B, V, T, C, H, W] float
assert inputs["videos"].ndim == 6 and inputs["videos"].shape[0] == BATCH_SIZE
assert inputs["videos"].dtype == torch.float32
assert inputs["actions"].shape == (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
assert inputs["state"].shape == (BATCH_SIZE, 1, STATE_DIM)
policy = VLAJEPAPolicy(make_config())
examples = policy._prepare_model_inputs(make_train_batch())
assert len(examples) == BATCH_SIZE
for ex in examples:
assert set(ex) >= {"image", "video", "lang", "action", "state"}
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
assert ex["state"].shape == (1, STATE_DIM)
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
inputs = policy._prepare_model_inputs(make_inference_batch())
assert "actions" not in inputs and "action_is_pad" not in inputs
assert {"images", "instructions", "state"} <= set(inputs)
for ex in policy._prepare_model_inputs(make_inference_batch()):
assert "action" not in ex
assert "image" in ex and "video" in ex and "lang" in ex
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch["task"]
instructions = policy._prepare_model_inputs(batch)["instructions"]
assert all(isinstance(s, str) and len(s) > 0 for s in instructions)
examples = policy._prepare_model_inputs(batch)
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
batch["task"] = "open the drawer"
assert policy._prepare_model_inputs(batch)["instructions"] == ["open the drawer"] * BATCH_SIZE
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
@@ -255,7 +253,7 @@ def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: N
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch[OBS_STATE]
assert "state" not in policy._prepare_model_inputs(batch)
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
# ---------------------------------------------------------------------------
@@ -448,14 +446,14 @@ def test_postprocessor_applied_after_predict_action_chunk(
"""
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
raw_actions = torch.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=torch.float32)
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
cfg = make_config()
cfg.clip_normalized_actions = False
cfg.binarize_gripper_action = False
policy = VLAJEPAPolicy(cfg)
policy.eval()
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.clone())
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
dataset_stats = _make_dataset_stats()
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
@@ -566,9 +564,9 @@ def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_model
original_processor = policy.model.video_processor
class _CapturingProcessor:
def __call__(self, videos: list, return_tensors: str, **kwargs) -> dict:
def __call__(self, videos: list, return_tensors: str) -> dict:
captured_videos.extend(videos)
return original_processor(videos=videos, return_tensors=return_tensors, **kwargs)
return original_processor(videos=videos, return_tensors=return_tensors)
policy.model.video_processor = _CapturingProcessor()
policy.forward(_make_multiview_train_batch(num_views=1))
@@ -589,9 +587,9 @@ def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: No
original_processor = policy.model.video_processor
class _CapturingProcessor:
def __call__(self, videos: list, return_tensors: str, **kwargs) -> dict:
def __call__(self, videos: list, return_tensors: str) -> dict:
captured_videos.extend(videos)
return original_processor(videos=videos, return_tensors=return_tensors, **kwargs)
return original_processor(videos=videos, return_tensors=return_tensors)
policy.model.video_processor = _CapturingProcessor()
policy.forward(_make_multiview_train_batch(num_views=3))
@@ -0,0 +1,518 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RECAP's distributional value function."""
from __future__ import annotations
import pytest
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
DistributionalVFConfig,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_IMAGES
from tests.utils import skip_if_package_missing
BATCH_SIZE = 4
NUM_BINS = 201
IMAGE_KEY = f"{OBS_IMAGES}.top"
def _make_config(**overrides) -> DistributionalVFConfig:
defaults = {
"init_from_actor_path": "",
"device": "cpu",
"image_resolution": (224, 224),
}
defaults.update(overrides)
config = DistributionalVFConfig(**defaults)
config.input_features = {
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {}
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
}
return config
def _make_model():
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
return DistributionalVFRewardModel(_make_config())
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
return {
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
"mc_return": torch.rand(batch_size, device=device) * -1.0,
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
}
def test_config_registered_in_reward_model_registry():
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
known = RewardModelConfig.get_known_choices()
assert "distributional_value_function" in known
def test_factory_returns_correct_class():
"""get_reward_model_class returns DistributionalVFRewardModel."""
from lerobot.rewards.factory import get_reward_model_class
cls = get_reward_model_class("distributional_value_function")
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
assert cls is DistributionalVFRewardModel
def test_make_reward_model_config_factory():
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
from lerobot.rewards.factory import make_reward_model_config
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
assert isinstance(config, DistributionalVFConfig)
assert config.num_value_bins == 101
@skip_if_package_missing("transformers")
def test_hl_gauss_sums_to_one():
"""HL-Gauss target distribution sums to 1 for each sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
dist = model.hl_gauss_target(targets)
assert dist.shape == (4, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_hl_gauss_non_negative():
"""HL-Gauss target probabilities are all non-negative."""
model = _make_model()
targets = torch.linspace(-1.0, 0.0, 10)
dist = model.hl_gauss_target(targets)
assert (dist >= 0).all()
@skip_if_package_missing("transformers")
def test_hl_gauss_expected_value_matches():
"""E[V] under HL-Gauss distribution matches the target value."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9])
dist = model.hl_gauss_target(targets)
expected = (dist * model.bin_centers).sum(dim=-1)
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
@skip_if_package_missing("transformers")
def test_hl_gauss_handles_2d_input():
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
dist = model.hl_gauss_target(targets)
assert dist.shape == (2, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_sums_to_one():
"""Dirac delta target distribution sums to 1 for each sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
dist = model.dirac_delta_target(targets)
assert dist.shape == (5, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_at_most_two_nonzero():
"""Dirac delta places probability on at most two adjacent bins."""
model = _make_model()
targets = torch.tensor([-0.7523, -0.0013])
dist = model.dirac_delta_target(targets)
for i in range(2):
assert (dist[i] > 0).sum() <= 2
@skip_if_package_missing("transformers")
def test_dirac_delta_expected_value_matches():
"""E[V] under Dirac delta distribution matches the target value."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9])
dist = model.dirac_delta_target(targets)
expected = (dist * model.bin_centers).sum(dim=-1)
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_boundary_values_clamped():
"""Values outside support are clamped to boundary bins."""
model = _make_model()
targets = torch.tensor([-1.5, 0.5])
dist = model.dirac_delta_target(targets)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
assert dist[0, 0] == 1.0
assert dist[1, -1] == 1.0
@skip_if_package_missing("transformers")
def test_one_hot_single_nonzero():
"""One-hot target has exactly one non-zero bin per sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
dist = model.one_hot_target(targets)
assert dist.shape == (4, NUM_BINS)
for i in range(4):
assert (dist[i] > 0).sum() == 1
assert dist[i].sum() == 1.0
@skip_if_package_missing("transformers")
def test_one_hot_nearest_bin():
"""One-hot target activates the bin closest to the target value."""
model = _make_model()
targets = torch.tensor([-0.5])
dist = model.one_hot_target(targets)
hot_idx = dist[0].argmax()
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
@skip_if_package_missing("transformers")
def test_terminal_gets_one_hot():
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
is_terminal = torch.tensor([False, True, False, True])
dist = model.compute_target_distribution(
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
)
for i in range(4):
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
assert (dist[1] > 0).sum() == 1
assert (dist[3] > 0).sum() == 1
assert (dist[0] > 0).sum() > 2
assert (dist[2] > 0).sum() > 2
@skip_if_package_missing("transformers")
def test_no_terminal_override_when_disabled():
"""When use_one_hot_terminal=False, terminal states use the base method."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3])
is_terminal = torch.tensor([False, True])
dist = model.compute_target_distribution(
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
)
assert (dist[1] > 0).sum() > 2
@skip_if_package_missing("transformers")
def test_model_has_expected_components():
"""Model scaffold contains all architectural components."""
model = _make_model()
assert hasattr(model, "vision_tower")
assert hasattr(model, "multi_modal_projector")
assert hasattr(model, "token_embedding")
assert hasattr(model, "layers")
assert hasattr(model, "value_head")
assert hasattr(model, "cls_embedding")
assert hasattr(model, "norm")
assert hasattr(model, "rotary_emb")
assert hasattr(model, "bin_centers")
@skip_if_package_missing("transformers")
def test_model_bin_centers_shape():
"""Bin centers buffer has shape (num_value_bins,)."""
model = _make_model()
assert model.bin_centers.shape == (NUM_BINS,)
@skip_if_package_missing("transformers")
def test_model_layer_count():
"""Transformer has num_hidden_layers (6) layers."""
model = _make_model()
assert len(model.layers) == 6
@skip_if_package_missing("transformers")
def test_model_value_head_output_dim():
"""Value head outputs num_value_bins logits."""
model = _make_model()
assert model.value_head.out_features == NUM_BINS
@skip_if_package_missing("transformers")
def test_forward_returns_loss_and_dict():
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
model = _make_model()
batch = _make_batch()
loss, output_dict = model.forward(batch)
assert loss.shape == ()
assert torch.isfinite(loss)
assert "loss" in output_dict
assert "predicted_value_mean" in output_dict
assert "mc_return_mean" in output_dict
@skip_if_package_missing("transformers")
def test_forward_loss_is_positive():
"""Cross-entropy loss is strictly positive for random weights."""
model = _make_model()
batch = _make_batch()
loss, _ = model.forward(batch)
assert loss.item() > 0
@skip_if_package_missing("transformers")
def test_compute_reward_returns_correct_shape():
"""compute_reward returns [batch_size] tensor of finite float32 values."""
model = _make_model()
model.eval()
batch = _make_batch(batch_size=3)
with torch.no_grad():
values = model.compute_reward(batch)
assert values.shape == (3,)
assert values.dtype == torch.float32
assert torch.isfinite(values).all()
@skip_if_package_missing("transformers")
def test_compute_reward_values_in_support_range():
"""Predicted values lie within [value_support_min, value_support_max]."""
model = _make_model()
model.eval()
batch = _make_batch(batch_size=8)
with torch.no_grad():
values = model.compute_reward(batch)
assert (values >= -1.0 - 0.01).all()
assert (values <= 0.0 + 0.01).all()
@skip_if_package_missing("transformers")
def test_processor_pipeline_produces_expected_keys():
"""Full preprocessor pipeline produces tokenized text and processed images."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
make_distributional_vf_pre_post_processors,
)
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
config = _make_config()
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
raw_batch = {
IMAGE_KEY: torch.rand(3, 224, 224),
"task": "pick up the cup",
}
processed = preprocessor(raw_batch)
assert OBS_LANGUAGE_TOKENS in processed
assert OBS_LANGUAGE_ATTENTION_MASK in processed
assert IMAGE_KEY in processed
@skip_if_package_missing("transformers")
def test_gradient_flows_through_value_head():
"""Backprop produces non-zero gradients on the value head."""
model = _make_model()
model.train()
batch = _make_batch()
loss, _ = model.forward(batch)
loss.backward()
assert model.value_head.weight.grad is not None
assert not torch.all(model.value_head.weight.grad == 0)
@skip_if_package_missing("transformers")
def test_gradient_flows_through_cls_embedding():
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
model = _make_model()
model.train()
batch = _make_batch()
loss, _ = model.forward(batch)
loss.backward()
assert model.cls_embedding.grad is not None
assert not torch.all(model.cls_embedding.grad == 0)
def test_config_requires_visual_feature():
"""validate_features raises if no VISUAL feature is present."""
config = DistributionalVFConfig(init_from_actor_path="")
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
}
with pytest.raises(ValueError, match="VISUAL"):
config.validate_features()
def test_config_passes_with_visual_feature():
"""validate_features succeeds when a VISUAL feature is present."""
config = _make_config()
config.validate_features()
@skip_if_package_missing("transformers")
def test_save_load_pretrained_roundtrip(tmp_path):
"""Saved model can be loaded back with identical weights."""
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
model = _make_model()
model._save_pretrained(tmp_path)
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
orig_sd = model.state_dict()
loaded_sd = loaded.state_dict()
assert set(orig_sd.keys()) == set(loaded_sd.keys())
for key in orig_sd:
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
@skip_if_package_missing("transformers")
def test_image_preprocessor_normalizes_to_minus_one_one():
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFImagePreprocessorStep,
)
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
transition = {
TransitionKey.OBSERVATION: {
IMAGE_KEY: torch.rand(1, 224, 224, 3),
},
}
result = step(transition)
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
assert image.min() >= -1.0 - 1e-5
assert image.max() <= 1.0 + 1e-5
@skip_if_package_missing("transformers")
def test_image_preprocessor_resizes_with_pad():
"""Image preprocessor resizes non-square images to target resolution."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFImagePreprocessorStep,
)
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
transition = {
TransitionKey.OBSERVATION: {
IMAGE_KEY: torch.rand(1, 480, 640, 3),
},
}
result = step(transition)
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
assert image.shape[1:3] == (224, 224)
def test_task_prompt_formats_correctly():
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
}
result = step(transition)
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
assert prompt == "Task: pick up the cup."
def test_task_prompt_handles_string_input():
"""Task prompt step accepts a plain string (not just a list)."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
}
result = step(transition)
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
assert prompt == "Task: open drawer."
def test_task_prompt_raises_on_missing_task():
"""Task prompt step raises ValueError when task key is absent."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {},
}
with pytest.raises(ValueError, match="No task found"):
step(transition)
Generated
+342 -337
View File
File diff suppressed because it is too large Load Diff