mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-12 05:59:53 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 87242cfced | |||
| 1edc83a0ef | |||
| 6fbcf67249 | |||
| 41166b39fb | |||
| 79c6821407 | |||
| 507083249f | |||
| bd22407d93 |
@@ -58,7 +58,7 @@ action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
@@ -101,11 +101,13 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
|
||||
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
|
||||
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
@@ -133,6 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -140,7 +143,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
|
||||
+15
-11
@@ -115,8 +115,8 @@ dataset = [
|
||||
]
|
||||
training = [
|
||||
"lerobot[dataset]",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
"wandb>=0.24.0,<0.28.0",
|
||||
"lerobot[accelerate-dep]",
|
||||
]
|
||||
hardware = [
|
||||
"lerobot[pynput-dep]",
|
||||
@@ -142,7 +142,8 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
|
||||
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
@@ -177,7 +178,12 @@ unitree_g1 = [
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
|
||||
reachy2 = [
|
||||
"reachy2_sdk>=1.0.15,<1.1.0",
|
||||
"grpcio<=1.73.1",
|
||||
"protobuf<=6.32.0",
|
||||
]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
@@ -199,7 +205,7 @@ wallx = [
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
@@ -214,10 +220,9 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
recap = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -225,16 +230,16 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -297,7 +302,6 @@ all = [
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[recap]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
|
||||
@@ -30,6 +30,7 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
@@ -41,6 +42,10 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
@@ -73,10 +78,11 @@ class EpisodeAwareSampler:
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
|
||||
@@ -13,9 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig as DistributionalVFConfig,
|
||||
)
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
@@ -29,7 +26,6 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"DistributionalVFConfig",
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .modeling_distributional_value_function import DistributionalVFRewardModel
|
||||
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"DistributionalVFConfig",
|
||||
"DistributionalVFRewardModel",
|
||||
"make_distributional_vf_pre_post_processors",
|
||||
]
|
||||
-108
@@ -1,108 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
Weights initialized from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("distributional_value_function")
|
||||
@dataclass
|
||||
class DistributionalVFConfig(RewardModelConfig):
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
|
||||
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
|
||||
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
|
||||
(Eq. 1 in the paper).
|
||||
|
||||
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
|
||||
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
|
||||
about 670M params and initialize from the PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
# Backbone
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
num_hidden_layers: int = 6
|
||||
num_vision_layers: int = 13
|
||||
|
||||
# Distributional head
|
||||
num_value_bins: int = 201
|
||||
value_support_min: float = -1.0
|
||||
value_support_max: float = 0.0
|
||||
hl_gauss_sigma_ratio: float = 5.0
|
||||
|
||||
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
|
||||
target_method: str = "hl_gauss"
|
||||
|
||||
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
|
||||
# When False, terminal states use the same target method as non-terminal states.
|
||||
use_one_hot_terminal: bool = True
|
||||
|
||||
# Image
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 64
|
||||
|
||||
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
|
||||
# Pass a PI05 checkpoint path or Hub repo_id here.
|
||||
# After training, load the value function with RewardModel.from_pretrained() instead.
|
||||
init_from_actor_path: str = ""
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=3e-4,
|
||||
weight_decay=1e-4,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.input_features:
|
||||
return
|
||||
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
|
||||
if not has_image:
|
||||
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
|
||||
-567
@@ -1,567 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Modeling for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Inputs: single image observation + task text prompt ("Task: {task}.")
|
||||
Outputs: softmax distribution over value bins; expected value E[V] for inference.
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
|
||||
Weight initialization: vision tower, multi-modal projector, token embeddings, and
|
||||
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaRMSNorm,
|
||||
_gated_residual,
|
||||
_get_pi_gemma_decoder_layer_base,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
PiGemmaRMSNorm = None
|
||||
_gated_residual = None
|
||||
_get_pi_gemma_decoder_layer_base = None
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257152
|
||||
|
||||
|
||||
class DistributionalVFRewardModel(PreTrainedRewardModel):
|
||||
"""Distributional value function model for RECAP.
|
||||
|
||||
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
|
||||
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
|
||||
per-task normalized Monte Carlo returns.
|
||||
|
||||
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
|
||||
causal attention, [CLS] token, and Linear(D, num_bins) value head.
|
||||
The expected value is E[V] = sum(softmax(logits) * bin_centers).
|
||||
"""
|
||||
|
||||
name = "distributional_value_function"
|
||||
config_class = DistributionalVFConfig
|
||||
|
||||
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
|
||||
require_package("transformers", extra="recap")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
|
||||
|
||||
# Get base dimensions from the paligemma variant (OpenPI config format)
|
||||
base_config = get_gemma_config(config.paligemma_variant)
|
||||
hidden_dim = base_config.width
|
||||
mlp_dim = base_config.mlp_dim
|
||||
num_layers = config.num_hidden_layers
|
||||
|
||||
# HuggingFace GemmaConfig for transformer layers
|
||||
gemma_config = CONFIG_MAPPING["gemma"](
|
||||
head_dim=base_config.head_dim,
|
||||
hidden_size=hidden_dim,
|
||||
intermediate_size=mlp_dim,
|
||||
num_attention_heads=base_config.num_heads,
|
||||
num_hidden_layers=num_layers,
|
||||
num_key_value_heads=base_config.num_kv_heads,
|
||||
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
)
|
||||
self.gemma_config = gemma_config
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_value_bins = config.num_value_bins
|
||||
|
||||
# Single learned [CLS] token for value prediction
|
||||
self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
|
||||
|
||||
# Value projection head: Linear(hidden_dim, num_bins)
|
||||
self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
|
||||
|
||||
# Transformer layers (overwritten by _initialize_from_actor on first run)
|
||||
self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
|
||||
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||
self.layers = nn.ModuleList(
|
||||
[pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
|
||||
)
|
||||
self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
|
||||
|
||||
# Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
|
||||
# PaliGemmaConfig wraps both vision and text configs into a single model
|
||||
paligemma_config = CONFIG_MAPPING["paligemma"]()
|
||||
paligemma_config.text_config = gemma_config
|
||||
paligemma_config.vision_config.image_size = config.image_resolution[0]
|
||||
paligemma_config.vision_config.intermediate_size = 4304
|
||||
paligemma_config.vision_config.projection_dim = 2048
|
||||
paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
|
||||
|
||||
paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
|
||||
self.vision_tower = paligemma_full.model.vision_tower
|
||||
self.multi_modal_projector = paligemma_full.model.multi_modal_projector
|
||||
self.token_embedding = paligemma_full.model.language_model.embed_tokens
|
||||
del paligemma_full
|
||||
|
||||
# Truncate vision tower to num_vision_layers
|
||||
if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
|
||||
vision_encoder = self.vision_tower.vision_model.encoder
|
||||
vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
|
||||
|
||||
# Bin support: evenly spaced centers from value_support_min to value_support_max
|
||||
bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
|
||||
self.register_buffer("bin_centers", bin_centers, persistent=False)
|
||||
bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
|
||||
self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
|
||||
|
||||
# Overwrite with pre-trained PI05 actor weights (first training run only)
|
||||
if config.init_from_actor_path:
|
||||
self._initialize_from_actor()
|
||||
|
||||
def _initialize_from_actor(self) -> None:
|
||||
"""Overwrite weights from a pre-trained PI05 actor checkpoint.
|
||||
|
||||
Called on first training run only (when init_from_actor_path is set).
|
||||
"""
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
|
||||
actor_model = actor_policy.model
|
||||
|
||||
paligemma_model = actor_model.paligemma_with_expert.paligemma
|
||||
source_language_model = paligemma_model.model.language_model
|
||||
|
||||
# Transformer components
|
||||
self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
|
||||
num_layers = self.gemma_config.num_hidden_layers
|
||||
for i in range(num_layers):
|
||||
self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
|
||||
self.norm.load_state_dict(source_language_model.norm.state_dict())
|
||||
|
||||
# Vision tower (truncate source first, then copy)
|
||||
source_vision_tower = paligemma_model.model.vision_tower
|
||||
if hasattr(source_vision_tower, "vision_model") and hasattr(
|
||||
source_vision_tower.vision_model, "encoder"
|
||||
):
|
||||
source_encoder = source_vision_tower.vision_model.encoder
|
||||
source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
|
||||
self.vision_tower.load_state_dict(source_vision_tower.state_dict())
|
||||
|
||||
# Multi-modal projector
|
||||
self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())
|
||||
|
||||
# Token embedding table
|
||||
self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())
|
||||
|
||||
del actor_policy
|
||||
|
||||
def embed_image(self, image: Tensor) -> Tensor:
|
||||
"""Embed images using the value function's SigLIP vision tower.
|
||||
|
||||
Args:
|
||||
image: [batch_size, channels, height, width] preprocessed images in [-1, 1].
|
||||
|
||||
Returns:
|
||||
[batch_size, num_patches, hidden_dim] projected image features.
|
||||
"""
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
image_outputs = self.vision_tower(image, return_dict=True)
|
||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||
image_features = image_features / (self.hidden_dim**0.5)
|
||||
|
||||
if image_features.dtype != out_dtype:
|
||||
image_features = image_features.to(out_dtype)
|
||||
return image_features
|
||||
|
||||
def embed_text(self, token_ids: Tensor) -> Tensor:
|
||||
"""Embed text token IDs using the value function's token embedding table.
|
||||
|
||||
Args:
|
||||
token_ids: [batch_size, seq_len] integer token IDs
|
||||
|
||||
Returns:
|
||||
[batch_size, seq_len, hidden_dim] text embeddings
|
||||
"""
|
||||
return self.token_embedding(token_ids)
|
||||
|
||||
def _get_cls_embedding(self, batch_size: int) -> Tensor:
|
||||
"""Get [CLS] token embedding expanded to batch size.
|
||||
|
||||
Args:
|
||||
batch_size: number of samples in the batch.
|
||||
|
||||
Returns:
|
||||
[batch_size, 1, hidden_dim] learned [CLS] embedding.
|
||||
"""
|
||||
return self.cls_embedding.expand(batch_size, -1, -1)
|
||||
|
||||
def forward_value(
|
||||
self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
|
||||
) -> dict[str, Tensor]:
|
||||
"""Core forward pass through the distributional value function.
|
||||
|
||||
Args:
|
||||
vision_features: [batch_size, num_patches, hidden_dim]
|
||||
text_embeddings: [batch_size, seq_len, hidden_dim]
|
||||
text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
|
||||
|
||||
Returns:
|
||||
logits: [batch_size, num_value_bins]
|
||||
probs: [batch_size, num_value_bins]
|
||||
value: [batch_size, 1]
|
||||
"""
|
||||
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE
|
||||
|
||||
batch_size = text_embeddings.shape[0]
|
||||
device = text_embeddings.device
|
||||
|
||||
# Build sequence: [vision, text, CLS]
|
||||
cls_embedding = self._get_cls_embedding(batch_size)
|
||||
hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)
|
||||
|
||||
# Build causal attention mask
|
||||
vision_len = vision_features.shape[1]
|
||||
vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
|
||||
cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
|
||||
full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
|
||||
|
||||
full_seq_len = full_padding_mask.shape[1]
|
||||
|
||||
# Causal mask
|
||||
causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
|
||||
# Combine causal mask with padding mask
|
||||
padding_mask_4d = full_padding_mask[:, None, None, :].expand(
|
||||
batch_size, 1, full_seq_len, full_seq_len
|
||||
)
|
||||
attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
|
||||
attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
|
||||
position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
|
||||
cos, sin = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for layer in self.layers:
|
||||
norm_output = layer.input_layernorm(hidden_states, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
hidden_states_normed, gate = norm_output
|
||||
else:
|
||||
hidden_states_normed, gate = norm_output, None
|
||||
|
||||
input_shape = hidden_states_normed.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
attention_output, _ = modeling_gemma.eager_attention_forward(
|
||||
layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
layer.self_attn.scaling,
|
||||
)
|
||||
|
||||
attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
|
||||
if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
projected_attention = layer.self_attn.o_proj(attention_output)
|
||||
|
||||
if gate is not None:
|
||||
projected_attention = _gated_residual(hidden_states, projected_attention, gate)
|
||||
else:
|
||||
projected_attention = hidden_states + projected_attention
|
||||
|
||||
after_attention_residual = projected_attention.clone()
|
||||
|
||||
norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
mlp_input, gate = norm_output
|
||||
else:
|
||||
mlp_input, gate = norm_output, None
|
||||
|
||||
mlp_output = layer.mlp(mlp_input)
|
||||
|
||||
if gate is not None:
|
||||
hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
|
||||
else:
|
||||
hidden_states = after_attention_residual + mlp_output
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
# Extract [CLS] token (last position in the sequence)
|
||||
cls_hidden_state = hidden_states[:, -1, :] # [batch_size, hidden_dim]
|
||||
|
||||
# Value head: Linear(hidden_dim, num_bins) -> logits
|
||||
value_logits = self.value_head(cls_hidden_state) # [batch_size, num_value_bins]
|
||||
value_probs = F.softmax(value_logits, dim=-1)
|
||||
predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
|
||||
|
||||
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
|
||||
"""HL-Gauss soft target distribution.
|
||||
|
||||
Places a Gaussian N(target, sigma^2) over the bin support and computes
|
||||
per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.
|
||||
|
||||
Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
|
||||
Functions via Classification for Scalable Deep RL", Section 3.1.
|
||||
arXiv:2403.03950
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
# Bin edges: half a bin-width outside the first/last center
|
||||
bin_width = (self.config.value_support_max - self.config.value_support_min) / (
|
||||
self.num_value_bins - 1
|
||||
)
|
||||
support_edges = torch.linspace(
|
||||
self.config.value_support_min - bin_width / 2,
|
||||
self.config.value_support_max + bin_width / 2,
|
||||
self.num_value_bins + 1,
|
||||
device=target_value.device,
|
||||
dtype=target_value.dtype,
|
||||
)
|
||||
|
||||
# CDF of N(target, sigma^2) evaluated at each edge
|
||||
cdf_at_edges = 0.5 * (
|
||||
1.0
|
||||
+ torch.erf(
|
||||
(support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
|
||||
/ (self.hl_gauss_sigma * math.sqrt(2))
|
||||
)
|
||||
) # [batch_size, num_bins + 1]
|
||||
|
||||
# Normalize: z = cdf(max_edge) - cdf(min_edge)
|
||||
normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)
|
||||
|
||||
# Bin probabilities = differences of consecutive CDF values, normalized
|
||||
bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant
|
||||
|
||||
return bin_probabilities
|
||||
|
||||
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
|
||||
"""Dirac delta (C51) projection: split probability between two nearest bins.
|
||||
|
||||
Standard distributional RL projection from Bellemare et al. 2017.
|
||||
"A Distributional Perspective on Reinforcement Learning"
|
||||
arXiv:1707.06887
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
bin_width = self.bin_centers[1] - self.bin_centers[0]
|
||||
normalized_position = (target_value - self.config.value_support_min) / bin_width
|
||||
lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
|
||||
upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)
|
||||
|
||||
weight_upper = normalized_position - lower_bin_idx.float()
|
||||
weight_lower = upper_bin_idx.float() - normalized_position
|
||||
|
||||
same_bin = lower_bin_idx == upper_bin_idx
|
||||
weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
|
||||
weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)
|
||||
|
||||
batch_size = target_value.shape[0]
|
||||
target_distribution = torch.zeros(batch_size, self.num_value_bins, device=target_value.device)
|
||||
batch_indices = torch.arange(batch_size, device=target_value.device)
|
||||
target_distribution[batch_indices, lower_bin_idx] += weight_lower
|
||||
target_distribution[batch_indices, upper_bin_idx] += weight_upper
|
||||
|
||||
return target_distribution
|
||||
|
||||
def one_hot_target(self, target_value: Tensor) -> Tensor:
|
||||
"""One-hot target for terminal states (exact return, no smoothing).
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] one-hot distribution at the nearest bin.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
nearest_bin_idx = torch.argmin(
|
||||
torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
|
||||
)
|
||||
return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)
|
||||
|
||||
def compute_target_distribution(
|
||||
self,
|
||||
target_value: Tensor,
|
||||
is_terminal: Tensor,
|
||||
method: str = "hl_gauss",
|
||||
use_one_hot_terminal: bool = True,
|
||||
) -> Tensor:
|
||||
"""Compute target distribution using configured method.
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] scalar return targets
|
||||
is_terminal: [batch_size] boolean terminal flags
|
||||
method: "hl_gauss" or "dirac_delta"
|
||||
use_one_hot_terminal: if True, terminal states get one-hot targets
|
||||
(exact return, no smoothing). If False, all states use the same method.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution
|
||||
"""
|
||||
if method == "hl_gauss":
|
||||
base_distribution = self.hl_gauss_target(target_value)
|
||||
elif method == "dirac_delta":
|
||||
base_distribution = self.dirac_delta_target(target_value)
|
||||
else:
|
||||
raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")
|
||||
|
||||
if not use_one_hot_terminal:
|
||||
return base_distribution
|
||||
|
||||
terminal_distribution = self.one_hot_target(target_value)
|
||||
|
||||
return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Training forward pass — computes cross-entropy loss against MC return targets.
|
||||
|
||||
The batch is expected to be preprocessed by the processor pipeline.
|
||||
Keys expected in batch:
|
||||
- observation.images.*: [B, C, H, W] preprocessed images
|
||||
- observation.language_tokens: [B, seq_len] tokenized task prompt
|
||||
- observation.language_attention_mask: [B, seq_len] padding mask
|
||||
- mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
|
||||
- is_terminal: [B] boolean terminal flags
|
||||
|
||||
Returns:
|
||||
(loss, output_dict) where loss is scalar cross-entropy
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
# Get first image key from batch
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
mc_return = batch["mc_return"]
|
||||
is_terminal = batch["is_terminal"]
|
||||
|
||||
# Embed observations
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
# Forward through value function transformer
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
value_logits = vf_output["logits"]
|
||||
predicted_value = vf_output["value"]
|
||||
|
||||
# Compute target distribution
|
||||
target_distribution = self.compute_target_distribution(
|
||||
mc_return,
|
||||
is_terminal,
|
||||
method=self.config.target_method,
|
||||
use_one_hot_terminal=self.config.use_one_hot_terminal,
|
||||
)
|
||||
|
||||
# Cross-entropy loss (Eq. 1 in pi*0.6 paper)
|
||||
log_probs = F.log_softmax(value_logits, dim=-1)
|
||||
loss = -(target_distribution * log_probs).sum(dim=-1).mean()
|
||||
|
||||
output_dict = {
|
||||
"loss": loss.item(),
|
||||
"predicted_value_mean": predicted_value.mean().item(),
|
||||
"mc_return_mean": mc_return.mean().item(),
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute V(s) for a batch of observations. Used for advantage scoring.
|
||||
|
||||
Args:
|
||||
batch: preprocessed batch with images and tokenized text
|
||||
|
||||
Returns:
|
||||
[batch_size] tensor of predicted values V(s)
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
return vf_output["value"].squeeze(-1) # [batch_size]
|
||||
-235
@@ -1,235 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Processor for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
|
||||
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
|
||||
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
|
||||
|
||||
Training targets (mc_return, is_terminal) are NOT routed through the processor.
|
||||
They are dataset columns read directly from the batch in the model's forward().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
|
||||
@dataclass
|
||||
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
|
||||
"""Format the task string for the distributional value function.
|
||||
|
||||
The value function receives only visual observations and task text.
|
||||
Builds prompt: "Task: {task}."
|
||||
"""
|
||||
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
tasks = complementary_data.get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
full_prompts = []
|
||||
for task in tasks:
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
full_prompts.append(f"Task: {cleaned_text}.")
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[self.task_key] = full_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"task_key": self.task_key}
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
|
||||
@dataclass
|
||||
class DistributionalVFImagePreprocessorStep(ProcessorStep):
|
||||
"""Resize and normalize images for the value function's SigLIP vision tower.
|
||||
|
||||
Expects float images in [0, 1].
|
||||
- Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
|
||||
- Scale to [-1, 1] for SigLIP
|
||||
"""
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
image_keys: tuple[str, ...] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")
|
||||
|
||||
image_keys = self.image_keys or tuple(
|
||||
key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
|
||||
)
|
||||
if not image_keys:
|
||||
raise KeyError(
|
||||
f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
|
||||
)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for image_key in image_keys:
|
||||
image = new_observation[image_key]
|
||||
if not isinstance(image, Tensor):
|
||||
image = to_tensor(image)
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
is_channels_first = image.ndim == 4 and image.shape[1] == 3
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != self.image_resolution:
|
||||
image = resize_with_pad_torch(image, *self.image_resolution)
|
||||
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2)
|
||||
|
||||
new_observation[image_key] = image
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"image_resolution": self.image_resolution,
|
||||
"image_keys": list(self.image_keys) if self.image_keys is not None else None,
|
||||
}
|
||||
|
||||
|
||||
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
|
||||
return tuple(
|
||||
feature_name
|
||||
for feature_name, feature in config.input_features.items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
)
|
||||
|
||||
|
||||
def make_distributional_vf_pre_post_processors(
|
||||
config: DistributionalVFConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create pre/post processors for the distributional value function.
|
||||
|
||||
Preprocessor steps:
|
||||
1. Rename observations (no-op by default)
|
||||
2. Add a batch dimension
|
||||
3. Normalize features (images use identity, so they stay in [0, 1])
|
||||
4. Format task prompt: "Task: {task}."
|
||||
5. Tokenize with the PaliGemma tokenizer
|
||||
6. Resize-with-pad and scale images to [-1, 1] for SigLIP
|
||||
7. Move tensors to the configured device
|
||||
|
||||
Training targets (mc_return, is_terminal) are not processed here.
|
||||
The model reads them directly from the batch in forward().
|
||||
|
||||
The postprocessor is a no-op because the value function does not need
|
||||
action postprocessing.
|
||||
"""
|
||||
image_keys = _visual_image_keys(config)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DistributionalVFPrepareTaskPromptStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DistributionalVFImagePreprocessorStep(
|
||||
image_resolution=config.image_resolution,
|
||||
image_keys=image_keys or None,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -24,7 +24,6 @@ from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
@@ -64,12 +63,6 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
elif name == "distributional_value_function":
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -103,8 +96,6 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RobometerConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
elif reward_type == "distributional_value_function":
|
||||
return DistributionalVFConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -200,16 +191,6 @@ def make_reward_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, DistributionalVFConfig):
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
|
||||
return make_distributional_vf_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
|
||||
@@ -99,6 +99,9 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
@@ -158,6 +161,8 @@ def update_policy(
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
if torch.cuda.is_available():
|
||||
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@@ -232,15 +237,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -386,12 +394,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
@@ -424,12 +439,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
|
||||
# is actually optimizing. grad_norm and lr are already identical on every rank (post
|
||||
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
|
||||
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
|
||||
# true straggler instead of rank 0's view.
|
||||
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
|
||||
# Derived from the post-reduce max step time; set once per log window on the main rank.
|
||||
"samples_per_s": AverageMeter("smp/s", ":.0f"),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
# max() because headroom is gated by the worst-case rank.
|
||||
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
|
||||
|
||||
# Keep global batch size for logging; MetricsTracker handles world size internally.
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
@@ -481,21 +506,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
# Collective reduce must run on every rank, before the main-process gate below.
|
||||
train_tracker.reduce_across_ranks()
|
||||
if is_main_process:
|
||||
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
|
||||
# reflects the slowest rank — which is what actually gates the next iteration.
|
||||
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
|
||||
if step_time > 0:
|
||||
train_tracker.samples_per_s = effective_batch_size / step_time
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
|
||||
@@ -13,21 +13,39 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import format_big_number
|
||||
|
||||
_VALID_REDUCTIONS = ("none", "max", "mean", "sum")
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""
|
||||
Computes and stores the average and current value
|
||||
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
|
||||
|
||||
Args:
|
||||
name: Display name of the metric.
|
||||
fmt: Format string used when rendering the metric.
|
||||
reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks`
|
||||
before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``,
|
||||
or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or
|
||||
update wall time) so multi-GPU runs report the slowest rank rather than rank 0.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, fmt: str = ":f"):
|
||||
def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"):
|
||||
if reduction not in _VALID_REDUCTIONS:
|
||||
raise ValueError(
|
||||
f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}."
|
||||
)
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reduction = reduction
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -138,6 +156,37 @@ class MetricsTracker:
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def reduce_across_ranks(self) -> None:
|
||||
"""
|
||||
Synchronises the running averages of every metric whose ``reduction`` is not ``"none"``
|
||||
across all distributed processes (in-place).
|
||||
|
||||
This is a collective operation and MUST be invoked on every rank — typically just before
|
||||
logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics
|
||||
reported by the main process only reflect rank 0; for bottleneck-style timings
|
||||
(``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible.
|
||||
"""
|
||||
if self.accelerator is None or self.accelerator.num_processes <= 1:
|
||||
return
|
||||
|
||||
buckets: dict[str, list[str]] = defaultdict(list)
|
||||
for name, meter in self.metrics.items():
|
||||
if meter.reduction != "none":
|
||||
buckets[meter.reduction].append(name)
|
||||
if not buckets:
|
||||
return
|
||||
|
||||
device = self.accelerator.device
|
||||
for reduction, names in buckets.items():
|
||||
tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
|
||||
reduced = self.accelerator.reduce(tensor, reduction=reduction)
|
||||
for name, value in zip(names, reduced.tolist(), strict=True):
|
||||
meter = self.metrics[name]
|
||||
# Preserve avg == sum / count so a later .update() on this meter accumulates
|
||||
# against the cluster view, not the stale per-rank history.
|
||||
meter.avg = value
|
||||
meter.sum = value * meter.count
|
||||
|
||||
def __str__(self) -> str:
|
||||
display_list = [
|
||||
f"step:{format_big_number(self.steps)}",
|
||||
|
||||
@@ -114,6 +114,30 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_with_generator_is_deterministic():
|
||||
# Two samplers shuffling with same-seed generators must yield identical permutations.
|
||||
# This is what keeps batch shards disjoint across ranks in distributed training, where
|
||||
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
assert list(sampler_a) == list(sampler_b)
|
||||
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
order_before = list(sampler_c)
|
||||
sampler_c.generator.manual_seed(42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
assert list(sampler_c) == order_before
|
||||
|
||||
|
||||
def test_generator_attribute_defaults_to_none():
|
||||
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
|
||||
# so the attribute must exist even when no generator is passed.
|
||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
|
||||
assert sampler.generator is None
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
@@ -1,518 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RECAP's distributional value function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
BATCH_SIZE = 4
|
||||
NUM_BINS = 201
|
||||
IMAGE_KEY = f"{OBS_IMAGES}.top"
|
||||
|
||||
|
||||
def _make_config(**overrides) -> DistributionalVFConfig:
|
||||
defaults = {
|
||||
"init_from_actor_path": "",
|
||||
"device": "cpu",
|
||||
"image_resolution": (224, 224),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
config = DistributionalVFConfig(**defaults)
|
||||
config.input_features = {
|
||||
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def _make_model():
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel(_make_config())
|
||||
|
||||
|
||||
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
return {
|
||||
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
|
||||
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
|
||||
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
|
||||
"mc_return": torch.rand(batch_size, device=device) * -1.0,
|
||||
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
|
||||
def test_config_registered_in_reward_model_registry():
|
||||
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
|
||||
known = RewardModelConfig.get_known_choices()
|
||||
assert "distributional_value_function" in known
|
||||
|
||||
|
||||
def test_factory_returns_correct_class():
|
||||
"""get_reward_model_class returns DistributionalVFRewardModel."""
|
||||
from lerobot.rewards.factory import get_reward_model_class
|
||||
|
||||
cls = get_reward_model_class("distributional_value_function")
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
assert cls is DistributionalVFRewardModel
|
||||
|
||||
|
||||
def test_make_reward_model_config_factory():
|
||||
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
|
||||
from lerobot.rewards.factory import make_reward_model_config
|
||||
|
||||
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
|
||||
assert isinstance(config, DistributionalVFConfig)
|
||||
assert config.num_value_bins == 101
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_sums_to_one():
|
||||
"""HL-Gauss target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_non_negative():
|
||||
"""HL-Gauss target probabilities are all non-negative."""
|
||||
model = _make_model()
|
||||
targets = torch.linspace(-1.0, 0.0, 10)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert (dist >= 0).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_expected_value_matches():
|
||||
"""E[V] under HL-Gauss distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_handles_2d_input():
|
||||
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (2, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_sums_to_one():
|
||||
"""Dirac delta target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
assert dist.shape == (5, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_at_most_two_nonzero():
|
||||
"""Dirac delta places probability on at most two adjacent bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.7523, -0.0013])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
for i in range(2):
|
||||
assert (dist[i] > 0).sum() <= 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_expected_value_matches():
|
||||
"""E[V] under Dirac delta distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_boundary_values_clamped():
|
||||
"""Values outside support are clamped to boundary bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-1.5, 0.5])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
|
||||
assert dist[0, 0] == 1.0
|
||||
assert dist[1, -1] == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_single_nonzero():
|
||||
"""One-hot target has exactly one non-zero bin per sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
for i in range(4):
|
||||
assert (dist[i] > 0).sum() == 1
|
||||
assert dist[i].sum() == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_nearest_bin():
|
||||
"""One-hot target activates the bin closest to the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
hot_idx = dist[0].argmax()
|
||||
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_terminal_gets_one_hot():
|
||||
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
|
||||
is_terminal = torch.tensor([False, True, False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
|
||||
)
|
||||
|
||||
for i in range(4):
|
||||
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
|
||||
assert (dist[1] > 0).sum() == 1
|
||||
assert (dist[3] > 0).sum() == 1
|
||||
assert (dist[0] > 0).sum() > 2
|
||||
assert (dist[2] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_no_terminal_override_when_disabled():
|
||||
"""When use_one_hot_terminal=False, terminal states use the base method."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3])
|
||||
is_terminal = torch.tensor([False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
|
||||
)
|
||||
|
||||
assert (dist[1] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_has_expected_components():
|
||||
"""Model scaffold contains all architectural components."""
|
||||
model = _make_model()
|
||||
|
||||
assert hasattr(model, "vision_tower")
|
||||
assert hasattr(model, "multi_modal_projector")
|
||||
assert hasattr(model, "token_embedding")
|
||||
assert hasattr(model, "layers")
|
||||
assert hasattr(model, "value_head")
|
||||
assert hasattr(model, "cls_embedding")
|
||||
assert hasattr(model, "norm")
|
||||
assert hasattr(model, "rotary_emb")
|
||||
assert hasattr(model, "bin_centers")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_bin_centers_shape():
|
||||
"""Bin centers buffer has shape (num_value_bins,)."""
|
||||
model = _make_model()
|
||||
assert model.bin_centers.shape == (NUM_BINS,)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_layer_count():
|
||||
"""Transformer has num_hidden_layers (6) layers."""
|
||||
model = _make_model()
|
||||
assert len(model.layers) == 6
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_value_head_output_dim():
|
||||
"""Value head outputs num_value_bins logits."""
|
||||
model = _make_model()
|
||||
assert model.value_head.out_features == NUM_BINS
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_returns_loss_and_dict():
|
||||
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, output_dict = model.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert "loss" in output_dict
|
||||
assert "predicted_value_mean" in output_dict
|
||||
assert "mc_return_mean" in output_dict
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_loss_is_positive():
|
||||
"""Cross-entropy loss is strictly positive for random weights."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
|
||||
assert loss.item() > 0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_returns_correct_shape():
|
||||
"""compute_reward returns [batch_size] tensor of finite float32 values."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=3)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert values.shape == (3,)
|
||||
assert values.dtype == torch.float32
|
||||
assert torch.isfinite(values).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_values_in_support_range():
|
||||
"""Predicted values lie within [value_support_min, value_support_max]."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=8)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert (values >= -1.0 - 0.01).all()
|
||||
assert (values <= 0.0 + 0.01).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_processor_pipeline_produces_expected_keys():
|
||||
"""Full preprocessor pipeline produces tokenized text and processed images."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
config = _make_config()
|
||||
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
|
||||
|
||||
raw_batch = {
|
||||
IMAGE_KEY: torch.rand(3, 224, 224),
|
||||
"task": "pick up the cup",
|
||||
}
|
||||
|
||||
processed = preprocessor(raw_batch)
|
||||
|
||||
assert OBS_LANGUAGE_TOKENS in processed
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed
|
||||
assert IMAGE_KEY in processed
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_value_head():
|
||||
"""Backprop produces non-zero gradients on the value head."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.value_head.weight.grad is not None
|
||||
assert not torch.all(model.value_head.weight.grad == 0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_cls_embedding():
|
||||
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.cls_embedding.grad is not None
|
||||
assert not torch.all(model.cls_embedding.grad == 0)
|
||||
|
||||
|
||||
def test_config_requires_visual_feature():
|
||||
"""validate_features raises if no VISUAL feature is present."""
|
||||
config = DistributionalVFConfig(init_from_actor_path="")
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="VISUAL"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_config_passes_with_visual_feature():
|
||||
"""validate_features succeeds when a VISUAL feature is present."""
|
||||
config = _make_config()
|
||||
config.validate_features()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_save_load_pretrained_roundtrip(tmp_path):
|
||||
"""Saved model can be loaded back with identical weights."""
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
model = _make_model()
|
||||
model._save_pretrained(tmp_path)
|
||||
|
||||
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
orig_sd = model.state_dict()
|
||||
loaded_sd = loaded.state_dict()
|
||||
|
||||
assert set(orig_sd.keys()) == set(loaded_sd.keys())
|
||||
for key in orig_sd:
|
||||
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_normalizes_to_minus_one_one():
|
||||
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 224, 224, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.min() >= -1.0 - 1e-5
|
||||
assert image.max() <= 1.0 + 1e-5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_resizes_with_pad():
|
||||
"""Image preprocessor resizes non-square images to target resolution."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 480, 640, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.shape[1:3] == (224, 224)
|
||||
|
||||
|
||||
def test_task_prompt_formats_correctly():
|
||||
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: pick up the cup."
|
||||
|
||||
|
||||
def test_task_prompt_handles_string_input():
|
||||
"""Task prompt step accepts a plain string (not just a list)."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: open drawer."
|
||||
|
||||
|
||||
def test_task_prompt_raises_on_missing_task():
|
||||
"""Task prompt step raises ValueError when task key is absent."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="No task found"):
|
||||
step(transition)
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
|
||||
@@ -25,8 +26,16 @@ def mock_metrics():
|
||||
|
||||
|
||||
class MockAccelerator:
|
||||
def __init__(self, num_processes: int):
|
||||
def __init__(self, num_processes: int, reduce_fn=None):
|
||||
self.num_processes = num_processes
|
||||
self.device = torch.device("cpu")
|
||||
self._reduce_fn = reduce_fn
|
||||
|
||||
def reduce(self, tensor, reduction="mean"):
|
||||
# In single-process tests we just want a deterministic stand-in for accelerate's reduce.
|
||||
if self._reduce_fn is not None:
|
||||
return self._reduce_fn(tensor, reduction)
|
||||
return tensor
|
||||
|
||||
|
||||
def test_average_meter_initialization():
|
||||
@@ -157,3 +166,70 @@ def test_metrics_tracker_reset_averages(mock_metrics):
|
||||
tracker.reset_averages()
|
||||
assert tracker.loss.avg == 0.0
|
||||
assert tracker.accuracy.avg == 0.0
|
||||
|
||||
|
||||
def test_average_meter_invalid_reduction():
|
||||
with pytest.raises(ValueError):
|
||||
AverageMeter("loss", reduction="median")
|
||||
|
||||
|
||||
def test_average_meter_reduction_stored():
|
||||
meter = AverageMeter("updt_s", reduction="max")
|
||||
assert meter.reduction == "max"
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_no_accelerator():
|
||||
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics)
|
||||
tracker.update_s = 0.5
|
||||
tracker.reduce_across_ranks() # no-op without accelerator
|
||||
assert tracker.update_s.avg == 0.5
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_single_process():
|
||||
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=metrics,
|
||||
accelerator=MockAccelerator(num_processes=1),
|
||||
)
|
||||
tracker.update_s = 0.5
|
||||
tracker.reduce_across_ranks() # no-op when world size is 1
|
||||
assert tracker.update_s.avg == 0.5
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_invokes_reduce():
|
||||
captured = {}
|
||||
|
||||
def fake_reduce(tensor, reduction):
|
||||
captured["reduction"] = reduction
|
||||
captured["values"] = tensor.clone()
|
||||
# Pretend the slowest rank reported 0.9 instead of this rank's 0.4.
|
||||
return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
metrics = {
|
||||
"loss": AverageMeter("loss"), # reduction="none" -> not touched
|
||||
"update_s": AverageMeter("update_s", reduction="max"),
|
||||
}
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=metrics,
|
||||
accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce),
|
||||
)
|
||||
tracker.loss = 1.0
|
||||
tracker.update_s = 0.4
|
||||
tracker.reduce_across_ranks()
|
||||
|
||||
assert captured["reduction"] == "max"
|
||||
assert torch.allclose(captured["values"], torch.tensor([0.4]))
|
||||
assert tracker.update_s.avg == pytest.approx(0.9)
|
||||
# Metrics without a reduction stay untouched.
|
||||
assert tracker.loss.avg == 1.0
|
||||
# Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls
|
||||
# accumulate against the cluster view rather than the stale per-rank sum.
|
||||
meter = tracker.update_s
|
||||
assert meter.sum / meter.count == pytest.approx(meter.avg)
|
||||
|
||||
Reference in New Issue
Block a user