mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-11 21:49:47 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fa3eb9fce3 | |||
| 500c91ba92 | |||
| 49755a3d9e |
@@ -22,10 +22,6 @@ outputs
|
||||
rl
|
||||
media
|
||||
|
||||
# Local virtualenvs (the image provides its own)
|
||||
.venv
|
||||
venv
|
||||
|
||||
|
||||
# Logging
|
||||
logs
|
||||
|
||||
@@ -67,8 +67,6 @@
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: lingbot_va
|
||||
title: LingBot-VA
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
# LingBot-VA
|
||||
|
||||
LingBot-VA is an **autoregressive video-action world-model policy** built on the **Wan2.2**
|
||||
video-diffusion stack. It interleaves, in one autoregressive sequence, the prediction of
|
||||
future **video latents** and **robot actions** ("VA" = Video-Action). The LeRobot
|
||||
integration wires LingBot-VA into the standard training, evaluation and processor
|
||||
interfaces.
|
||||
|
||||
## Model Overview
|
||||
|
||||
LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream
|
||||
(`patch_embedding_mlp → blocks → proj_out`) and an action stream
|
||||
(`action_embedder → blocks → action_proj_out`) share the same 30 transformer blocks and
|
||||
text conditioning.
|
||||
|
||||
| Component | Class | Role |
|
||||
| ------------------------ | ----------------------- | ----------------------------------------------------------- |
|
||||
| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer. |
|
||||
| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. |
|
||||
| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |
|
||||
|
||||
At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent
|
||||
stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent
|
||||
flow-matching schedulers, maintaining a KV cache across chunks. Real observed keyframes are
|
||||
fed back into the KV cache as the chunk is executed (closed-loop world modeling).
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=lingbot_va` configuration through LeRobot.
|
||||
- Ready-to-use LeRobot-format checkpoints on the Hub (converted from the released upstream ones).
|
||||
- Autoregressive dual-stream inference behind the standard `select_action` interface
|
||||
(single-environment eval, `--eval.batch_size=1`).
|
||||
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
|
||||
- Evaluation with `lerobot-eval` on LIBERO and RoboTwin.
|
||||
- Training / fine-tuning via the dual-stream flow-matching loss (`policy.forward`), see below.
|
||||
|
||||
## Installation
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install the LingBot-VA extra:
|
||||
|
||||
```bash
|
||||
pip install -e ".[lingbot_va]"
|
||||
```
|
||||
|
||||
## Checkpoints
|
||||
|
||||
The released upstream checkpoints have been converted to LeRobot format and pushed to the Hub:
|
||||
|
||||
| Variant | LeRobot checkpoint |
|
||||
| ---------------------- | -------------------------------- |
|
||||
| LIBERO-Long post-train | `lerobot/lingbot_va_libero_long` |
|
||||
| RoboTwin post-train | `lerobot/lingbot_va_robotwin` |
|
||||
| Pretrained base | `lerobot/lingbot_va_base` |
|
||||
|
||||
Only the trainable ~5B transformer is stored in the LeRobot
|
||||
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are pulled from
|
||||
`config.wan_pretrained_path` at load time (defaults to the source `robbyant/*` repo). The
|
||||
UMT5-XXL text encoder runs on CPU by default (`config.text_encoder_device`) so the 5B
|
||||
transformer + VAE fit on a single 24–32 GB GPU.
|
||||
|
||||
## Evaluation (LIBERO)
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/lingbot_va_libero_long \
|
||||
--policy.device=cuda \
|
||||
--env.type=libero --env.task=libero_10 \
|
||||
--env.observation_height=128 --env.observation_width=128 \
|
||||
--eval.n_episodes=50 --eval.batch_size=1 \
|
||||
--output_dir=outputs/eval/lingbot_va_libero
|
||||
```
|
||||
|
||||
LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
|
||||
single-environment eval; use `--eval.batch_size=1`.
|
||||
|
||||
## Evaluation (RoboTwin)
|
||||
|
||||
RoboTwin 2.0 needs the SAPIEN + CuRobo simulator stack. You can use the benchmark Docker image
|
||||
(`docker/Dockerfile.benchmark.robotwin`, which also needs `warp-lang==1.3.1` and CuRobo built
|
||||
with the GPU's compute capability in `TORCH_CUDA_ARCH_LIST`). RoboTwin uses **end-effector-pose
|
||||
control**, so run with `--env.action_mode=ee`: the policy predicts per-arm `xyz+quaternion+gripper`
|
||||
deltas (`robotwin_tshape` latent layout) that are composed onto the episode's initial eef pose and
|
||||
executed via CuRobo IK.
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/lingbot_va_robotwin \
|
||||
--policy.device=cuda \
|
||||
--env.type=robotwin --env.task=beat_block_hammer --env.action_mode=ee \
|
||||
--eval.n_episodes=10 --eval.batch_size=1 \
|
||||
--output_dir=outputs/eval/lingbot_va_robotwin
|
||||
```
|
||||
|
||||
### Saving predicted (imagined) videos
|
||||
|
||||
Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
|
||||
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
|
||||
The same flag works for the periodic eval during `lerobot-train`.
|
||||
|
||||
## Training / fine-tuning
|
||||
|
||||
`LingBotVAPolicy.forward(batch)` implements the dual-stream **flow-matching** loss
|
||||
(`latent_loss + action_loss`, timestep-weighted, action-masked) from the paper: it VAE-encodes
|
||||
the camera clips into video latents, UMT5-encodes the task, noises both streams, runs the
|
||||
transformer's block-causal training pass and returns `(loss, metrics)`. Optimizer preset is AdamW
|
||||
with a linear-warmup-then-constant schedule (matching upstream).
|
||||
|
||||
Requirements:
|
||||
|
||||
- The block-causal masks use PyTorch **flex-attention**, so build the policy with
|
||||
`--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only).
|
||||
- The full 5B DiT does not fit a single 24–32 GB GPU under AdamW; fine-tune with **LoRA**
|
||||
(`--policy.use_peft=true`) and/or optimizer offload. `get_optim_params` returns only the
|
||||
trainable (e.g. adapter) parameters; the VAE + UMT5 text encoder stay frozen.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/lingbot_va_libero_long --policy.attn_mode=flex \
|
||||
--policy.use_peft=true \
|
||||
--dataset.repo_id=<your LeRobot-format dataset> \
|
||||
--batch_size=1 --steps=... --output_dir=outputs/train/lingbot_va
|
||||
```
|
||||
|
||||
The dataset must provide camera clips (a temporal window per camera, VAE-encoded to
|
||||
`frame_chunk_size` latent frames) and `frame_chunk_size * action_per_frame` action steps per item.
|
||||
|
||||
## Data format (action channels & camera order)
|
||||
|
||||
LingBot-VA is an **end-effector (Cartesian) pose** policy, it predicts EEF poses + gripper, not
|
||||
joint positions. Actions live in a fixed multi-embodiment **30-dim** layout; map your robot's
|
||||
action dimensions into these channels and pad the rest with `0` (`used_action_channel_ids` selects
|
||||
the channels a given checkpoint actually uses):
|
||||
|
||||
| channels | meaning |
|
||||
| -------- | ----------------------------------------------------- |
|
||||
| 0–6 | Left-arm end-effector pose |
|
||||
| 7–13 | Right-arm end-effector pose |
|
||||
| 14–20 | Left-arm joints (unused by the released checkpoints) |
|
||||
| 21–27 | Right-arm joints (unused by the released checkpoints) |
|
||||
| 28 | Left gripper |
|
||||
| 29 | Right gripper |
|
||||
|
||||
- **LIBERO** uses channels `0–6`: a 6-DoF EEF delta (xyz + rotation) + gripper (single arm).
|
||||
- **RoboTwin** uses channels `[0–6, 28, 7–13, 29]`: left EEF (xyz + quaternion) + left gripper +
|
||||
right EEF + right gripper (16 dims). The env converts these poses to joint trajectories via
|
||||
CuRobo IK — joints are never predicted.
|
||||
|
||||
Joint-space datasets (or a different EEF convention) must be remapped into this schema before
|
||||
fine-tuning these checkpoints.
|
||||
|
||||
**Camera order is fixed and order-sensitive**, per-camera latents are concatenated spatially in
|
||||
`obs_cam_keys` order, so the physical camera→slot mapping must match training:
|
||||
|
||||
| benchmark | `obs_cam_keys` (in order) | `camera_layout` |
|
||||
| --------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------- |
|
||||
| LIBERO | `observation.images.image` (agentview / 3rd-person), `observation.images.image2` (eye-in-hand wrist) | `width_concat` (latents concatenated on width) |
|
||||
| RoboTwin | `observation.images.head_camera`, `observation.images.left_camera`, `observation.images.right_camera` | `robotwin_tshape` (full-res head below, two half-res wrists on top) |
|
||||
|
||||
The first camera is the exterior/head view and the rest are wrist views.
|
||||
|
||||
## Inference Hyperparameters (LIBERO)
|
||||
|
||||
| Key | Value |
|
||||
| -------------------------------------- | --------------------------------------------------------------------------------- |
|
||||
| height × width | 128 × 128 |
|
||||
| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
|
||||
| action channels used | 0–6 (7-DoF arm + gripper) |
|
||||
| action_per_frame / frame_chunk_size | 4 / 4 |
|
||||
| attn_window | 30 |
|
||||
| video / action denoising steps | 20 / 50 |
|
||||
| guidance_scale / action_guidance_scale | 5 / 1 |
|
||||
| snr_shift / action_snr_shift | 5.0 / 0.05 |
|
||||
|
||||
These are the defaults of `LingBotVAConfig`; override any of them via `--policy.<name>=...`.
|
||||
|
||||
## Notes
|
||||
|
||||
- **Attention backend:** inference uses the `torch` SDPA backend (always available). The
|
||||
`flashattn` and `flex` backends are optional; `flex` is only needed for training.
|
||||
- **Model size:** the DiT is ~5B params and the frozen VAE+UMT5 add ~20 GB; inference needs
|
||||
roughly 18–24 GB of VRAM.
|
||||
|
||||
## License
|
||||
|
||||
LingBot-VA is released under Apache-2.0. See the
|
||||
[upstream repository](https://github.com/Robbyant/lingbot-va).
|
||||
+3
-7
@@ -146,8 +146,7 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
diffusers-dep = ["diffusers>=0.27.2,<0.37.0"]
|
||||
imageio-dep = ["imageio[ffmpeg]>=2.34.0,<3.0.0"]
|
||||
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
|
||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||
pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
@@ -215,11 +214,11 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
recap = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]", "accelerate>=1.10.0,<2.0.0", "ftfy>=6.0.0,<7.0.0"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -286,7 +285,6 @@ all = [
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[lingbot_va]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -299,6 +297,7 @@ all = [
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[recap]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
@@ -378,9 +377,6 @@ ignore = [
|
||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
# Vendored Wan2.2 / LingBot-VA model code uses tensor-dimension names (B, F, H, W) and `F` for
|
||||
# torch.nn.functional.
|
||||
"src/lerobot/policies/lingbot_va/**" = ["N803", "N806", "N812", "SIM102"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
|
||||
@@ -768,9 +768,6 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
# must equal what SAPIEN actually renders.
|
||||
observation_height: int = 240
|
||||
observation_width: int = 320
|
||||
# "joint": 14-d joint-space control. "ee": 16-d end-effector-pose deltas executed via CuRobo IK
|
||||
# (for world-model policies like LingBot-VA that predict per-arm xyz+quaternion+gripper poses).
|
||||
action_mode: str = "joint"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
@@ -787,8 +784,6 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.action_mode == "ee":
|
||||
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(16,))
|
||||
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
|
||||
for cam in cam_list:
|
||||
self.features[f"pixels/{cam}"] = PolicyFeature(
|
||||
@@ -831,7 +826,6 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
observation_height=self.observation_height,
|
||||
observation_width=self.observation_width,
|
||||
episode_length=self.episode_length,
|
||||
action_mode=self.action_mode,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,42 +41,10 @@ ROBOTWIN_CAMERA_NAMES: tuple[str, ...] = (
|
||||
"right_camera",
|
||||
)
|
||||
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms (joint-space control mode)
|
||||
# End-effector-pose control mode: per arm [x, y, z, qx, qy, qz, qw, gripper] = 8, dual-arm = 16.
|
||||
# Used by world-model policies (e.g. LingBot-VA) that predict eef-pose deltas executed via CuRobo IK.
|
||||
EEF_ACTION_DIM = 16
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
DEFAULT_EPISODE_LENGTH = 300
|
||||
|
||||
|
||||
def _compose_eef_pose(new_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
|
||||
"""Compose a single-arm predicted delta pose onto the initial pose.
|
||||
|
||||
``new_pose`` / ``init_pose`` are 8-vectors ``[x, y, z, qx, qy, qz, qw, gripper]``. Translation
|
||||
is added, rotation is composed (``init_R * new_R``), and the gripper is taken from the
|
||||
prediction. Mirrors ``add_eef_pose`` in the upstream LingBot-VA RoboTwin client.
|
||||
"""
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
new_r = Rotation.from_quat(new_pose[3:7])
|
||||
init_r = Rotation.from_quat(init_pose[3:7])
|
||||
out_rot = (init_r * new_r).as_quat()
|
||||
out_trans = new_pose[:3] + init_pose[:3]
|
||||
return np.concatenate([out_trans, out_rot, new_pose[7:8]])
|
||||
|
||||
|
||||
def _add_init_eef_pose(delta_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
|
||||
"""Compose a dual-arm (16-d) predicted delta pose onto the initial eef pose, normalizing quats."""
|
||||
left = _compose_eef_pose(delta_pose[:8], init_pose[:8])
|
||||
right = _compose_eef_pose(delta_pose[8:], init_pose[8:])
|
||||
out = np.concatenate([left, right])
|
||||
# Normalize the two quaternions (indices 3:7 and 11:15) as the upstream client does.
|
||||
out[3:7] = out[3:7] / (np.linalg.norm(out[3:7]) + 1e-8)
|
||||
out[11:15] = out[11:15] / (np.linalg.norm(out[11:15]) + 1e-8)
|
||||
return out
|
||||
|
||||
|
||||
# D435 dims from task_config/_camera_config.yml (what demo_clean.yml selects).
|
||||
DEFAULT_CAMERA_H = 240
|
||||
DEFAULT_CAMERA_W = 320
|
||||
@@ -266,7 +234,6 @@ class RoboTwinEnv(gym.Env):
|
||||
observation_width: int | None = None,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
render_mode: str = "rgb_array",
|
||||
action_mode: str = "joint",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_name = task_name
|
||||
@@ -274,13 +241,6 @@ class RoboTwinEnv(gym.Env):
|
||||
self.task_description = task_name.replace("_", " ")
|
||||
self.episode_index = episode_index
|
||||
self._reset_stride = n_envs
|
||||
# "joint": 14-d joint-space actions via take_action(action). "ee": 16-d end-effector-pose
|
||||
# deltas (added onto the episode's initial eef pose) executed via take_action(.., "ee") + IK.
|
||||
if action_mode not in ("joint", "ee"):
|
||||
raise ValueError(f"action_mode must be 'joint' or 'ee'; got {action_mode!r}")
|
||||
self.action_mode = action_mode
|
||||
self._action_dim = EEF_ACTION_DIM if action_mode == "ee" else ACTION_DIM
|
||||
self._init_eef_pose: np.ndarray | None = None
|
||||
self.camera_names = list(camera_names)
|
||||
# Default to D435 dims (the camera type baked into task_config/demo_clean.yml).
|
||||
# The YAML-driven lookup is deferred to reset() so construction doesn't
|
||||
@@ -311,7 +271,7 @@ class RoboTwinEnv(gym.Env):
|
||||
}
|
||||
)
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(self._action_dim,), dtype=np.float32
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
|
||||
)
|
||||
|
||||
def _ensure_env(self) -> None:
|
||||
@@ -357,18 +317,6 @@ class RoboTwinEnv(gym.Env):
|
||||
|
||||
return {"pixels": images, "agent_pos": joint_state}
|
||||
|
||||
def _read_eef_pose(self) -> np.ndarray:
|
||||
"""Read the current 16-d dual-arm eef pose [left(xyz+quat)+grip, right(xyz+quat)+grip]."""
|
||||
assert self._env is not None, "_read_eef_pose called before _ensure_env()"
|
||||
ep = self._env.get_obs()["endpose"]
|
||||
pose = (
|
||||
list(ep["left_endpose"])
|
||||
+ [ep["left_gripper"]]
|
||||
+ list(ep["right_endpose"])
|
||||
+ [ep["right_gripper"]]
|
||||
)
|
||||
return np.asarray(pose, dtype=np.float64)
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[RobotObservation, dict]:
|
||||
self._ensure_env()
|
||||
super().reset(seed=seed)
|
||||
@@ -382,23 +330,16 @@ class RoboTwinEnv(gym.Env):
|
||||
self.episode_index += self._reset_stride
|
||||
self._step_count = 0
|
||||
|
||||
# In eef mode the policy predicts pose deltas relative to the initial eef pose.
|
||||
if self.action_mode == "ee":
|
||||
self._init_eef_pose = self._read_eef_pose()
|
||||
|
||||
obs = self._get_obs()
|
||||
return obs, {"is_success": False, "task": self.task_name}
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
assert self._env is not None, "step() called before reset()"
|
||||
if action.ndim != 1 or action.shape[0] != self._action_dim:
|
||||
raise ValueError(f"Expected 1-D action of shape ({self._action_dim},), got {action.shape}")
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}")
|
||||
|
||||
with torch.enable_grad():
|
||||
if self.action_mode == "ee":
|
||||
ee_action = _add_init_eef_pose(np.asarray(action, dtype=np.float64), self._init_eef_pose)
|
||||
self._env.take_action(ee_action, action_type="ee")
|
||||
elif hasattr(self._env, "take_action"):
|
||||
if hasattr(self._env, "take_action"):
|
||||
self._env.take_action(action)
|
||||
else:
|
||||
self._env.step(action)
|
||||
@@ -457,7 +398,6 @@ def _make_env_fns(
|
||||
observation_height: int,
|
||||
observation_width: int,
|
||||
episode_length: int,
|
||||
action_mode: str = "joint",
|
||||
) -> list[Callable[[], RoboTwinEnv]]:
|
||||
"""Return n_envs factory callables for a single task."""
|
||||
|
||||
@@ -470,7 +410,6 @@ def _make_env_fns(
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
action_mode=action_mode,
|
||||
)
|
||||
|
||||
return [partial(_make_one, i) for i in range(n_envs)]
|
||||
@@ -484,7 +423,6 @@ def create_robotwin_envs(
|
||||
observation_height: int = DEFAULT_CAMERA_H,
|
||||
observation_width: int = DEFAULT_CAMERA_W,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
action_mode: str = "joint",
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboTwin 2.0 environments.
|
||||
|
||||
@@ -535,7 +473,6 @@ def create_robotwin_envs(
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
action_mode=action_mode,
|
||||
)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
|
||||
@@ -83,28 +83,6 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("constant_with_warmup")
|
||||
@dataclass
|
||||
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Linear warmup followed by a constant learning rate.
|
||||
|
||||
Mirrors the ``warmup_constant_lambda`` used by LingBot-VA (upstream ``wan_va/train.py``):
|
||||
the LR ramps linearly from 0 to the peak over ``num_warmup_steps`` steps, then stays flat.
|
||||
"""
|
||||
|
||||
num_warmup_steps: int = 1000
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
warmup_steps = self.num_warmup_steps or 0
|
||||
|
||||
def lr_lambda(current_step):
|
||||
if current_step < warmup_steps:
|
||||
return float(current_step) / float(max(1, warmup_steps))
|
||||
return 1.0
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
|
||||
@@ -20,7 +20,6 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig as LingBotVAConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
@@ -45,7 +44,6 @@ __all__ = [
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"LingBotVAConfig",
|
||||
"MolmoAct2Config",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
|
||||
@@ -49,7 +49,6 @@ from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
@@ -163,10 +162,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
elif name == "lingbot_va":
|
||||
from .lingbot_va.modeling_lingbot_va import LingBotVAPolicy
|
||||
|
||||
return LingBotVAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -223,8 +218,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
elif policy_type == "lingbot_va":
|
||||
return LingBotVAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -455,14 +448,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, LingBotVAConfig):
|
||||
from .lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
|
||||
processors = make_lingbot_va_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/lingbot_va.mdx
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# NOTE: ``LingBotVAPolicy`` (and the Wan transformer it owns) imports ``diffusers`` as a
|
||||
# hard dependency at class-definition time (it subclasses diffusers' ModelMixin/ConfigMixin).
|
||||
# To keep base ``import lerobot`` working without the optional ``lingbot_va`` extra, the
|
||||
# policy is exposed lazily via module ``__getattr__`` — the heavy import only happens when
|
||||
# ``LingBotVAPolicy`` is actually accessed (mirroring the lazy import in policies/factory.py).
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
from .processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
|
||||
__all__ = ["LingBotVAConfig", "LingBotVAPolicy", "make_lingbot_va_pre_post_processors"]
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name == "LingBotVAPolicy":
|
||||
from .modeling_lingbot_va import LingBotVAPolicy
|
||||
|
||||
return LingBotVAPolicy
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
@@ -1,168 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for the LingBot-VA policy.
|
||||
|
||||
LingBot-VA is an autoregressive video-action world-model policy built on the Wan2.2
|
||||
video-diffusion stack. It interleaves prediction of future video latents and robot
|
||||
actions in a single dual-stream transformer. See ``docs/source/lingbot_va.mdx`` and the
|
||||
upstream repository (https://github.com/Robbyant/lingbot-va).
|
||||
|
||||
Defaults below match the upstream LIBERO configuration (``wan_va/configs/va_libero_cfg.py``)
|
||||
and the ``transformer/config.json`` of the released checkpoints.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("lingbot_va")
|
||||
@dataclass
|
||||
class LingBotVAConfig(PreTrainedConfig):
|
||||
"""Configuration for the native LingBot-VA policy integration in LeRobot."""
|
||||
|
||||
# Wan transformer architecture
|
||||
patch_size: tuple[int, int, int] = (1, 2, 2)
|
||||
num_attention_heads: int = 24
|
||||
attention_head_dim: int = 128
|
||||
in_channels: int = 48
|
||||
out_channels: int = 48
|
||||
action_dim: int = 30
|
||||
text_dim: int = 4096
|
||||
freq_dim: int = 256
|
||||
ffn_dim: int = 14336
|
||||
num_layers: int = 30
|
||||
cross_attn_norm: bool = True
|
||||
eps: float = 1e-6
|
||||
rope_max_seq_len: int = 1024
|
||||
# "flex" = training only (needs recent torch); inference uses "torch" SDPA or "flashattn".
|
||||
attn_mode: str = "torch"
|
||||
|
||||
# Frozen sub-models (VAE + UMT5 text encoder + tokenizer)
|
||||
# ~20 GB of frozen weights, NOT bundled in the checkpoint; lazily pulled from this HF repo /
|
||||
# local dir (must hold diffusers-style ``vae/``, ``text_encoder/``, ``tokenizer/`` sub-folders).
|
||||
wan_pretrained_path: str = "robbyant/lingbot-va-base"
|
||||
dtype: str = "bfloat16" # transformer / VAE / text-encoder dtype: "bfloat16", "float16", "float32"
|
||||
# Frozen UMT5-XXL encoder device; "cpu" frees ~11 GB VRAM (it runs once per episode).
|
||||
text_encoder_device: str = "cpu"
|
||||
|
||||
# Observation cameras (order matters: latents are concatenated on width; LIBERO defaults)
|
||||
obs_cam_keys: list[str] = field(
|
||||
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
|
||||
)
|
||||
# Undo the LIBERO env processor's extra horizontal flip to match the model's training orientation.
|
||||
image_hflip: bool = False
|
||||
# Camera latent layout: "width_concat" (cameras concatenated on width; LIBERO) or
|
||||
# "robotwin_tshape" (full-res head + half-res wrists in a "T"; RoboTwin).
|
||||
camera_layout: str = "width_concat"
|
||||
|
||||
# Inference hyperparameters (LIBERO defaults)
|
||||
n_obs_steps: int = 1
|
||||
height: int = 128
|
||||
width: int = 128
|
||||
action_per_frame: int = 4
|
||||
frame_chunk_size: int = 4
|
||||
attn_window: int = 30
|
||||
num_inference_steps: int = 20
|
||||
video_exec_step: int = -1
|
||||
action_num_inference_steps: int = 50
|
||||
guidance_scale: float = 5.0
|
||||
action_guidance_scale: float = 1.0
|
||||
snr_shift: float = 5.0
|
||||
action_snr_shift: float = 0.05
|
||||
max_sequence_length: int = 512 # UMT5 prompt length
|
||||
|
||||
# Subset of the 30-d action space used by the benchmark (LIBERO = 7-DoF). The action
|
||||
# (un)normalization quantiles live in the checkpoint's ``policy_postprocessor.json``, not here.
|
||||
used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
|
||||
|
||||
# Opt-in: VAE-decode predicted video latents to ``self.last_predicted_frames`` for saving MP4s.
|
||||
save_predicted_video: bool = False
|
||||
|
||||
# Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
|
||||
# quantile-(un)normalized inside the policy / dedicated processor steps.
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py)
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
scheduler_warmup_steps: int = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.attn_mode not in ("torch", "flashattn", "flex"):
|
||||
raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
"""Number of single-step actions produced per autoregressive chunk."""
|
||||
return self.frame_chunk_size * self.action_per_frame
|
||||
|
||||
@property
|
||||
def n_action_steps(self) -> int:
|
||||
"""Number of actions executed before refilling (the whole chunk)."""
|
||||
return self.chunk_size
|
||||
|
||||
def validate_features(self) -> None:
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"LingBot-VA requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
if ACTION not in self.output_features:
|
||||
self.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(len(self.used_action_channel_ids),)
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
# Upstream uses a linear warmup followed by a constant LR (warmup_constant_lambda).
|
||||
from lerobot.optim.schedulers import ConstantWithWarmupSchedulerConfig
|
||||
|
||||
return ConstantWithWarmupSchedulerConfig(num_warmup_steps=self.scheduler_warmup_steps)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,87 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pre/post-processor pipelines for the LingBot-VA policy.
|
||||
|
||||
The preprocessor passes inputs through (IDENTITY) and the postprocessor maps the policy's
|
||||
``[-1, 1]`` actions back to physical units with the built-in ``UnnormalizerProcessorStep``
|
||||
(QUANTILES) using per-channel q01/q99 restored from the checkpoint.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import (
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
|
||||
|
||||
def make_lingbot_va_pre_post_processors(
|
||||
config: LingBotVAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build the pre/post processor pipelines for LingBot-VA."""
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
# Unnormalize actions from [-1, 1] to physical units (QUANTILES) using q01/q99 restored from the checkpoint.
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.QUANTILES},
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
|
||||
config: dict[str, Any] = {
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
return pipeline_config
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
return cls(
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
# 3. Load step state if available
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
return steps, remaining_override_keys
|
||||
|
||||
steps.append(step_instance)
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
return steps, override_keys
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig as DistributionalVFConfig,
|
||||
)
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
@@ -26,6 +29,7 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"DistributionalVFConfig",
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
|
||||
+11
-1
@@ -1,4 +1,4 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -11,3 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .modeling_distributional_value_function import DistributionalVFRewardModel
|
||||
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"DistributionalVFConfig",
|
||||
"DistributionalVFRewardModel",
|
||||
"make_distributional_vf_pre_post_processors",
|
||||
]
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
Weights initialized from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("distributional_value_function")
|
||||
@dataclass
|
||||
class DistributionalVFConfig(RewardModelConfig):
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
|
||||
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
|
||||
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
|
||||
(Eq. 1 in the paper).
|
||||
|
||||
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
|
||||
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
|
||||
about 670M params and initialize from the PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
# Backbone
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
num_hidden_layers: int = 6
|
||||
num_vision_layers: int = 13
|
||||
|
||||
# Distributional head
|
||||
num_value_bins: int = 201
|
||||
value_support_min: float = -1.0
|
||||
value_support_max: float = 0.0
|
||||
hl_gauss_sigma_ratio: float = 5.0
|
||||
|
||||
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
|
||||
target_method: str = "hl_gauss"
|
||||
|
||||
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
|
||||
# When False, terminal states use the same target method as non-terminal states.
|
||||
use_one_hot_terminal: bool = True
|
||||
|
||||
# Image
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 64
|
||||
|
||||
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
|
||||
# Pass a PI05 checkpoint path or Hub repo_id here.
|
||||
# After training, load the value function with RewardModel.from_pretrained() instead.
|
||||
init_from_actor_path: str = ""
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=3e-4,
|
||||
weight_decay=1e-4,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.input_features:
|
||||
return
|
||||
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
|
||||
if not has_image:
|
||||
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
|
||||
+567
@@ -0,0 +1,567 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Modeling for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Inputs: single image observation + task text prompt ("Task: {task}.")
|
||||
Outputs: softmax distribution over value bins; expected value E[V] for inference.
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
|
||||
Weight initialization: vision tower, multi-modal projector, token embeddings, and
|
||||
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaRMSNorm,
|
||||
_gated_residual,
|
||||
_get_pi_gemma_decoder_layer_base,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
PiGemmaRMSNorm = None
|
||||
_gated_residual = None
|
||||
_get_pi_gemma_decoder_layer_base = None
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257152
|
||||
|
||||
|
||||
class DistributionalVFRewardModel(PreTrainedRewardModel):
|
||||
"""Distributional value function model for RECAP.
|
||||
|
||||
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
|
||||
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
|
||||
per-task normalized Monte Carlo returns.
|
||||
|
||||
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
|
||||
causal attention, [CLS] token, and Linear(D, num_bins) value head.
|
||||
The expected value is E[V] = sum(softmax(logits) * bin_centers).
|
||||
"""
|
||||
|
||||
name = "distributional_value_function"
|
||||
config_class = DistributionalVFConfig
|
||||
|
||||
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
|
||||
require_package("transformers", extra="recap")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
|
||||
|
||||
# Get base dimensions from the paligemma variant (OpenPI config format)
|
||||
base_config = get_gemma_config(config.paligemma_variant)
|
||||
hidden_dim = base_config.width
|
||||
mlp_dim = base_config.mlp_dim
|
||||
num_layers = config.num_hidden_layers
|
||||
|
||||
# HuggingFace GemmaConfig for transformer layers
|
||||
gemma_config = CONFIG_MAPPING["gemma"](
|
||||
head_dim=base_config.head_dim,
|
||||
hidden_size=hidden_dim,
|
||||
intermediate_size=mlp_dim,
|
||||
num_attention_heads=base_config.num_heads,
|
||||
num_hidden_layers=num_layers,
|
||||
num_key_value_heads=base_config.num_kv_heads,
|
||||
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
)
|
||||
self.gemma_config = gemma_config
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_value_bins = config.num_value_bins
|
||||
|
||||
# Single learned [CLS] token for value prediction
|
||||
self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
|
||||
|
||||
# Value projection head: Linear(hidden_dim, num_bins)
|
||||
self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
|
||||
|
||||
# Transformer layers (overwritten by _initialize_from_actor on first run)
|
||||
self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
|
||||
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||
self.layers = nn.ModuleList(
|
||||
[pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
|
||||
)
|
||||
self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
|
||||
|
||||
# Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
|
||||
# PaliGemmaConfig wraps both vision and text configs into a single model
|
||||
paligemma_config = CONFIG_MAPPING["paligemma"]()
|
||||
paligemma_config.text_config = gemma_config
|
||||
paligemma_config.vision_config.image_size = config.image_resolution[0]
|
||||
paligemma_config.vision_config.intermediate_size = 4304
|
||||
paligemma_config.vision_config.projection_dim = 2048
|
||||
paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
|
||||
|
||||
paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
|
||||
self.vision_tower = paligemma_full.model.vision_tower
|
||||
self.multi_modal_projector = paligemma_full.model.multi_modal_projector
|
||||
self.token_embedding = paligemma_full.model.language_model.embed_tokens
|
||||
del paligemma_full
|
||||
|
||||
# Truncate vision tower to num_vision_layers
|
||||
if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
|
||||
vision_encoder = self.vision_tower.vision_model.encoder
|
||||
vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
|
||||
|
||||
# Bin support: evenly spaced centers from value_support_min to value_support_max
|
||||
bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
|
||||
self.register_buffer("bin_centers", bin_centers, persistent=False)
|
||||
bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
|
||||
self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
|
||||
|
||||
# Overwrite with pre-trained PI05 actor weights (first training run only)
|
||||
if config.init_from_actor_path:
|
||||
self._initialize_from_actor()
|
||||
|
||||
def _initialize_from_actor(self) -> None:
|
||||
"""Overwrite weights from a pre-trained PI05 actor checkpoint.
|
||||
|
||||
Called on first training run only (when init_from_actor_path is set).
|
||||
"""
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
|
||||
actor_model = actor_policy.model
|
||||
|
||||
paligemma_model = actor_model.paligemma_with_expert.paligemma
|
||||
source_language_model = paligemma_model.model.language_model
|
||||
|
||||
# Transformer components
|
||||
self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
|
||||
num_layers = self.gemma_config.num_hidden_layers
|
||||
for i in range(num_layers):
|
||||
self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
|
||||
self.norm.load_state_dict(source_language_model.norm.state_dict())
|
||||
|
||||
# Vision tower (truncate source first, then copy)
|
||||
source_vision_tower = paligemma_model.model.vision_tower
|
||||
if hasattr(source_vision_tower, "vision_model") and hasattr(
|
||||
source_vision_tower.vision_model, "encoder"
|
||||
):
|
||||
source_encoder = source_vision_tower.vision_model.encoder
|
||||
source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
|
||||
self.vision_tower.load_state_dict(source_vision_tower.state_dict())
|
||||
|
||||
# Multi-modal projector
|
||||
self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())
|
||||
|
||||
# Token embedding table
|
||||
self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())
|
||||
|
||||
del actor_policy
|
||||
|
||||
def embed_image(self, image: Tensor) -> Tensor:
|
||||
"""Embed images using the value function's SigLIP vision tower.
|
||||
|
||||
Args:
|
||||
image: [batch_size, channels, height, width] preprocessed images in [-1, 1].
|
||||
|
||||
Returns:
|
||||
[batch_size, num_patches, hidden_dim] projected image features.
|
||||
"""
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
image_outputs = self.vision_tower(image, return_dict=True)
|
||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||
image_features = image_features / (self.hidden_dim**0.5)
|
||||
|
||||
if image_features.dtype != out_dtype:
|
||||
image_features = image_features.to(out_dtype)
|
||||
return image_features
|
||||
|
||||
def embed_text(self, token_ids: Tensor) -> Tensor:
|
||||
"""Embed text token IDs using the value function's token embedding table.
|
||||
|
||||
Args:
|
||||
token_ids: [batch_size, seq_len] integer token IDs
|
||||
|
||||
Returns:
|
||||
[batch_size, seq_len, hidden_dim] text embeddings
|
||||
"""
|
||||
return self.token_embedding(token_ids)
|
||||
|
||||
def _get_cls_embedding(self, batch_size: int) -> Tensor:
|
||||
"""Get [CLS] token embedding expanded to batch size.
|
||||
|
||||
Args:
|
||||
batch_size: number of samples in the batch.
|
||||
|
||||
Returns:
|
||||
[batch_size, 1, hidden_dim] learned [CLS] embedding.
|
||||
"""
|
||||
return self.cls_embedding.expand(batch_size, -1, -1)
|
||||
|
||||
def forward_value(
|
||||
self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
|
||||
) -> dict[str, Tensor]:
|
||||
"""Core forward pass through the distributional value function.
|
||||
|
||||
Args:
|
||||
vision_features: [batch_size, num_patches, hidden_dim]
|
||||
text_embeddings: [batch_size, seq_len, hidden_dim]
|
||||
text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
|
||||
|
||||
Returns:
|
||||
logits: [batch_size, num_value_bins]
|
||||
probs: [batch_size, num_value_bins]
|
||||
value: [batch_size, 1]
|
||||
"""
|
||||
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE
|
||||
|
||||
batch_size = text_embeddings.shape[0]
|
||||
device = text_embeddings.device
|
||||
|
||||
# Build sequence: [vision, text, CLS]
|
||||
cls_embedding = self._get_cls_embedding(batch_size)
|
||||
hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)
|
||||
|
||||
# Build causal attention mask
|
||||
vision_len = vision_features.shape[1]
|
||||
vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
|
||||
cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
|
||||
full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
|
||||
|
||||
full_seq_len = full_padding_mask.shape[1]
|
||||
|
||||
# Causal mask
|
||||
causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
|
||||
# Combine causal mask with padding mask
|
||||
padding_mask_4d = full_padding_mask[:, None, None, :].expand(
|
||||
batch_size, 1, full_seq_len, full_seq_len
|
||||
)
|
||||
attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
|
||||
attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
|
||||
position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
|
||||
cos, sin = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for layer in self.layers:
|
||||
norm_output = layer.input_layernorm(hidden_states, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
hidden_states_normed, gate = norm_output
|
||||
else:
|
||||
hidden_states_normed, gate = norm_output, None
|
||||
|
||||
input_shape = hidden_states_normed.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
attention_output, _ = modeling_gemma.eager_attention_forward(
|
||||
layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
layer.self_attn.scaling,
|
||||
)
|
||||
|
||||
attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
|
||||
if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
projected_attention = layer.self_attn.o_proj(attention_output)
|
||||
|
||||
if gate is not None:
|
||||
projected_attention = _gated_residual(hidden_states, projected_attention, gate)
|
||||
else:
|
||||
projected_attention = hidden_states + projected_attention
|
||||
|
||||
after_attention_residual = projected_attention.clone()
|
||||
|
||||
norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
mlp_input, gate = norm_output
|
||||
else:
|
||||
mlp_input, gate = norm_output, None
|
||||
|
||||
mlp_output = layer.mlp(mlp_input)
|
||||
|
||||
if gate is not None:
|
||||
hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
|
||||
else:
|
||||
hidden_states = after_attention_residual + mlp_output
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
# Extract [CLS] token (last position in the sequence)
|
||||
cls_hidden_state = hidden_states[:, -1, :] # [batch_size, hidden_dim]
|
||||
|
||||
# Value head: Linear(hidden_dim, num_bins) -> logits
|
||||
value_logits = self.value_head(cls_hidden_state) # [batch_size, num_value_bins]
|
||||
value_probs = F.softmax(value_logits, dim=-1)
|
||||
predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
|
||||
|
||||
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
|
||||
"""HL-Gauss soft target distribution.
|
||||
|
||||
Places a Gaussian N(target, sigma^2) over the bin support and computes
|
||||
per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.
|
||||
|
||||
Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
|
||||
Functions via Classification for Scalable Deep RL", Section 3.1.
|
||||
arXiv:2403.03950
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
# Bin edges: half a bin-width outside the first/last center
|
||||
bin_width = (self.config.value_support_max - self.config.value_support_min) / (
|
||||
self.num_value_bins - 1
|
||||
)
|
||||
support_edges = torch.linspace(
|
||||
self.config.value_support_min - bin_width / 2,
|
||||
self.config.value_support_max + bin_width / 2,
|
||||
self.num_value_bins + 1,
|
||||
device=target_value.device,
|
||||
dtype=target_value.dtype,
|
||||
)
|
||||
|
||||
# CDF of N(target, sigma^2) evaluated at each edge
|
||||
cdf_at_edges = 0.5 * (
|
||||
1.0
|
||||
+ torch.erf(
|
||||
(support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
|
||||
/ (self.hl_gauss_sigma * math.sqrt(2))
|
||||
)
|
||||
) # [batch_size, num_bins + 1]
|
||||
|
||||
# Normalize: z = cdf(max_edge) - cdf(min_edge)
|
||||
normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)
|
||||
|
||||
# Bin probabilities = differences of consecutive CDF values, normalized
|
||||
bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant
|
||||
|
||||
return bin_probabilities
|
||||
|
||||
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
|
||||
"""Dirac delta (C51) projection: split probability between two nearest bins.
|
||||
|
||||
Standard distributional RL projection from Bellemare et al. 2017.
|
||||
"A Distributional Perspective on Reinforcement Learning"
|
||||
arXiv:1707.06887
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
bin_width = self.bin_centers[1] - self.bin_centers[0]
|
||||
normalized_position = (target_value - self.config.value_support_min) / bin_width
|
||||
lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
|
||||
upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)
|
||||
|
||||
weight_upper = normalized_position - lower_bin_idx.float()
|
||||
weight_lower = upper_bin_idx.float() - normalized_position
|
||||
|
||||
same_bin = lower_bin_idx == upper_bin_idx
|
||||
weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
|
||||
weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)
|
||||
|
||||
batch_size = target_value.shape[0]
|
||||
target_distribution = torch.zeros(batch_size, self.num_value_bins, device=target_value.device)
|
||||
batch_indices = torch.arange(batch_size, device=target_value.device)
|
||||
target_distribution[batch_indices, lower_bin_idx] += weight_lower
|
||||
target_distribution[batch_indices, upper_bin_idx] += weight_upper
|
||||
|
||||
return target_distribution
|
||||
|
||||
def one_hot_target(self, target_value: Tensor) -> Tensor:
|
||||
"""One-hot target for terminal states (exact return, no smoothing).
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] one-hot distribution at the nearest bin.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
nearest_bin_idx = torch.argmin(
|
||||
torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
|
||||
)
|
||||
return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)
|
||||
|
||||
def compute_target_distribution(
|
||||
self,
|
||||
target_value: Tensor,
|
||||
is_terminal: Tensor,
|
||||
method: str = "hl_gauss",
|
||||
use_one_hot_terminal: bool = True,
|
||||
) -> Tensor:
|
||||
"""Compute target distribution using configured method.
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] scalar return targets
|
||||
is_terminal: [batch_size] boolean terminal flags
|
||||
method: "hl_gauss" or "dirac_delta"
|
||||
use_one_hot_terminal: if True, terminal states get one-hot targets
|
||||
(exact return, no smoothing). If False, all states use the same method.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution
|
||||
"""
|
||||
if method == "hl_gauss":
|
||||
base_distribution = self.hl_gauss_target(target_value)
|
||||
elif method == "dirac_delta":
|
||||
base_distribution = self.dirac_delta_target(target_value)
|
||||
else:
|
||||
raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")
|
||||
|
||||
if not use_one_hot_terminal:
|
||||
return base_distribution
|
||||
|
||||
terminal_distribution = self.one_hot_target(target_value)
|
||||
|
||||
return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Training forward pass — computes cross-entropy loss against MC return targets.
|
||||
|
||||
The batch is expected to be preprocessed by the processor pipeline.
|
||||
Keys expected in batch:
|
||||
- observation.images.*: [B, C, H, W] preprocessed images
|
||||
- observation.language_tokens: [B, seq_len] tokenized task prompt
|
||||
- observation.language_attention_mask: [B, seq_len] padding mask
|
||||
- mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
|
||||
- is_terminal: [B] boolean terminal flags
|
||||
|
||||
Returns:
|
||||
(loss, output_dict) where loss is scalar cross-entropy
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
# Get first image key from batch
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
mc_return = batch["mc_return"]
|
||||
is_terminal = batch["is_terminal"]
|
||||
|
||||
# Embed observations
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
# Forward through value function transformer
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
value_logits = vf_output["logits"]
|
||||
predicted_value = vf_output["value"]
|
||||
|
||||
# Compute target distribution
|
||||
target_distribution = self.compute_target_distribution(
|
||||
mc_return,
|
||||
is_terminal,
|
||||
method=self.config.target_method,
|
||||
use_one_hot_terminal=self.config.use_one_hot_terminal,
|
||||
)
|
||||
|
||||
# Cross-entropy loss (Eq. 1 in pi*0.6 paper)
|
||||
log_probs = F.log_softmax(value_logits, dim=-1)
|
||||
loss = -(target_distribution * log_probs).sum(dim=-1).mean()
|
||||
|
||||
output_dict = {
|
||||
"loss": loss.item(),
|
||||
"predicted_value_mean": predicted_value.mean().item(),
|
||||
"mc_return_mean": mc_return.mean().item(),
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute V(s) for a batch of observations. Used for advantage scoring.
|
||||
|
||||
Args:
|
||||
batch: preprocessed batch with images and tokenized text
|
||||
|
||||
Returns:
|
||||
[batch_size] tensor of predicted values V(s)
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
return vf_output["value"].squeeze(-1) # [batch_size]
|
||||
+235
@@ -0,0 +1,235 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Processor for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
|
||||
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
|
||||
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
|
||||
|
||||
Training targets (mc_return, is_terminal) are NOT routed through the processor.
|
||||
They are dataset columns read directly from the batch in the model's forward().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
|
||||
@dataclass
|
||||
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
|
||||
"""Format the task string for the distributional value function.
|
||||
|
||||
The value function receives only visual observations and task text.
|
||||
Builds prompt: "Task: {task}."
|
||||
"""
|
||||
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
tasks = complementary_data.get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
full_prompts = []
|
||||
for task in tasks:
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
full_prompts.append(f"Task: {cleaned_text}.")
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[self.task_key] = full_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"task_key": self.task_key}
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
|
||||
@dataclass
|
||||
class DistributionalVFImagePreprocessorStep(ProcessorStep):
|
||||
"""Resize and normalize images for the value function's SigLIP vision tower.
|
||||
|
||||
Expects float images in [0, 1].
|
||||
- Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
|
||||
- Scale to [-1, 1] for SigLIP
|
||||
"""
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
image_keys: tuple[str, ...] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")
|
||||
|
||||
image_keys = self.image_keys or tuple(
|
||||
key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
|
||||
)
|
||||
if not image_keys:
|
||||
raise KeyError(
|
||||
f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
|
||||
)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for image_key in image_keys:
|
||||
image = new_observation[image_key]
|
||||
if not isinstance(image, Tensor):
|
||||
image = to_tensor(image)
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
is_channels_first = image.ndim == 4 and image.shape[1] == 3
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != self.image_resolution:
|
||||
image = resize_with_pad_torch(image, *self.image_resolution)
|
||||
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2)
|
||||
|
||||
new_observation[image_key] = image
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"image_resolution": self.image_resolution,
|
||||
"image_keys": list(self.image_keys) if self.image_keys is not None else None,
|
||||
}
|
||||
|
||||
|
||||
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
|
||||
return tuple(
|
||||
feature_name
|
||||
for feature_name, feature in config.input_features.items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
)
|
||||
|
||||
|
||||
def make_distributional_vf_pre_post_processors(
|
||||
config: DistributionalVFConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create pre/post processors for the distributional value function.
|
||||
|
||||
Preprocessor steps:
|
||||
1. Rename observations (no-op by default)
|
||||
2. Add a batch dimension
|
||||
3. Normalize features (images use identity, so they stay in [0, 1])
|
||||
4. Format task prompt: "Task: {task}."
|
||||
5. Tokenize with the PaliGemma tokenizer
|
||||
6. Resize-with-pad and scale images to [-1, 1] for SigLIP
|
||||
7. Move tensors to the configured device
|
||||
|
||||
Training targets (mc_return, is_terminal) are not processed here.
|
||||
The model reads them directly from the batch in forward().
|
||||
|
||||
The postprocessor is a no-op because the value function does not need
|
||||
action postprocessing.
|
||||
"""
|
||||
image_keys = _visual_image_keys(config)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DistributionalVFPrepareTaskPromptStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DistributionalVFImagePreprocessorStep(
|
||||
image_resolution=config.image_resolution,
|
||||
image_keys=image_keys or None,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -24,6 +24,7 @@ from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
@@ -63,6 +64,12 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
elif name == "distributional_value_function":
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -96,6 +103,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RobometerConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
elif reward_type == "distributional_value_function":
|
||||
return DistributionalVFConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -191,6 +200,16 @@ def make_reward_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, DistributionalVFConfig):
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
|
||||
return make_distributional_vf_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
|
||||
@@ -105,7 +105,6 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
predicted_frames_callback: Callable[[PreTrainedPolicy], None] | None = None,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -135,9 +134,6 @@ def rollout(
|
||||
are returned optionally because they typically take more memory to cache. Defaults to False.
|
||||
render_callback: Optional rendering callback to be used after the environments are reset, and after
|
||||
every step.
|
||||
predicted_frames_callback: Optional callback invoked after every ``select_action`` with the policy
|
||||
itself. World-model policies (e.g. LingBot-VA) stash their decoded predicted video frames on
|
||||
``policy.last_predicted_frames``; this lets the caller collect them to save predicted-video MP4s.
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
@@ -188,8 +184,6 @@ def rollout(
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
if predicted_frames_callback is not None:
|
||||
predicted_frames_callback(policy)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {ACTION: action}
|
||||
@@ -279,7 +273,6 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
save_predicted_video: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -298,11 +291,6 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0 and not videos_dir:
|
||||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
# World-model policies (e.g. LingBot-VA) opt into predicted-video saving via their config.
|
||||
save_predicted_video = save_predicted_video or bool(
|
||||
getattr(getattr(policy, "config", None), "save_predicted_video", False)
|
||||
)
|
||||
|
||||
if not isinstance(policy, PreTrainedPolicy):
|
||||
exc = ValueError(
|
||||
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
|
||||
@@ -346,21 +334,6 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
if save_predicted_video:
|
||||
if not videos_dir:
|
||||
raise ValueError("If save_predicted_video is True, videos_dir must be provided.")
|
||||
predicted_video_paths: list[str] = []
|
||||
n_predicted_rendered = 0
|
||||
|
||||
# Collects the policy's decoded predicted-video frames across a rollout (world-model policies only).
|
||||
def collect_predicted_frames(policy: PreTrainedPolicy):
|
||||
frames = getattr(policy, "last_predicted_frames", None)
|
||||
if frames is not None:
|
||||
pred_frames.append(
|
||||
np.asarray(frames.detach().to("cpu")) if hasattr(frames, "detach") else np.asarray(frames)
|
||||
)
|
||||
policy.last_predicted_frames = None
|
||||
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
@@ -372,9 +345,6 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
if save_predicted_video:
|
||||
pred_frames: list[np.ndarray] = []
|
||||
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
else:
|
||||
@@ -391,7 +361,6 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
predicted_frames_callback=collect_predicted_frames if save_predicted_video else None,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -457,25 +426,6 @@ def eval_policy(
|
||||
threads.append(thread)
|
||||
n_episodes_rendered += 1
|
||||
|
||||
# Maybe save the policy's predicted (imagined) video for this batch's rollout.
|
||||
if save_predicted_video and len(pred_frames) > 0:
|
||||
# pred_frames is a list of [F, H, W, C] uint8 stacks emitted on chunk refills; concat over time.
|
||||
predicted_video = np.concatenate(pred_frames, axis=0)
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
predicted_video_path = videos_dir / f"pred_episode_{n_predicted_rendered}.mp4"
|
||||
predicted_video_paths.append(str(predicted_video_path))
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(
|
||||
str(predicted_video_path),
|
||||
predicted_video,
|
||||
env.unwrapped.metadata["render_fps"],
|
||||
),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
n_predicted_rendered += 1
|
||||
|
||||
progbar.set_postfix(
|
||||
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
|
||||
)
|
||||
@@ -519,9 +469,6 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
info["video_paths"] = video_paths
|
||||
|
||||
if save_predicted_video:
|
||||
info["predicted_video_paths"] = predicted_video_paths
|
||||
|
||||
return info
|
||||
|
||||
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
|
||||
def make_config(**overrides) -> LingBotVAConfig:
|
||||
kwargs = {"device": "cpu"}
|
||||
kwargs.update(overrides)
|
||||
return LingBotVAConfig(**kwargs)
|
||||
|
||||
|
||||
def test_registered_in_choice_registry() -> None:
|
||||
assert "lingbot_va" in PreTrainedConfig.get_known_choices()
|
||||
assert PreTrainedConfig.get_choice_class("lingbot_va") is LingBotVAConfig
|
||||
|
||||
|
||||
def test_type_property() -> None:
|
||||
assert make_config().type == "lingbot_va"
|
||||
|
||||
|
||||
def test_chunk_size_and_action_steps() -> None:
|
||||
cfg = make_config(frame_chunk_size=4, action_per_frame=4)
|
||||
assert cfg.chunk_size == 16
|
||||
assert cfg.n_action_steps == 16
|
||||
assert cfg.action_delta_indices == list(range(16))
|
||||
assert cfg.observation_delta_indices is None
|
||||
assert cfg.reward_delta_indices is None
|
||||
|
||||
|
||||
def test_optimizer_and_scheduler_presets() -> None:
|
||||
cfg = make_config()
|
||||
opt = cfg.get_optimizer_preset()
|
||||
assert opt.lr == cfg.optimizer_lr
|
||||
sched = cfg.get_scheduler_preset()
|
||||
assert sched.num_warmup_steps == cfg.scheduler_warmup_steps
|
||||
|
||||
|
||||
def test_validate_features_sets_action_feature() -> None:
|
||||
cfg = make_config()
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
|
||||
cfg.output_features = {}
|
||||
cfg.validate_features()
|
||||
assert ACTION in cfg.output_features
|
||||
assert cfg.output_features[ACTION].shape == (len(cfg.used_action_channel_ids),)
|
||||
|
||||
|
||||
def test_validate_features_no_visual_raises() -> None:
|
||||
cfg = make_config()
|
||||
cfg.input_features = {}
|
||||
cfg.output_features = {}
|
||||
with pytest.raises(ValueError, match="at least one visual input feature"):
|
||||
cfg.validate_features()
|
||||
|
||||
|
||||
def test_invalid_attn_mode_raises() -> None:
|
||||
with pytest.raises(ValueError, match="attn_mode"):
|
||||
make_config(attn_mode="banana")
|
||||
@@ -1,38 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
|
||||
|
||||
def test_make_policy_config_returns_lingbot_va() -> None:
|
||||
cfg = make_policy_config("lingbot_va", device="cpu")
|
||||
assert isinstance(cfg, LingBotVAConfig)
|
||||
|
||||
|
||||
def test_get_policy_class_resolves_lazily() -> None:
|
||||
# Importing the policy class pulls in diffusers (Wan2.2 stack); skip if unavailable.
|
||||
pytest.importorskip("diffusers")
|
||||
pytest.importorskip("transformers")
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
|
||||
cls = get_policy_class("lingbot_va")
|
||||
assert cls.name == "lingbot_va"
|
||||
assert cls.config_class is LingBotVAConfig
|
||||
@@ -1,131 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the vendored LingBot-VA helper code (scheduler + grid utilities)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("diffusers") # the model code lives in modeling_lingbot_va, which imports diffusers
|
||||
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import ( # noqa: E402
|
||||
FlowMatchScheduler,
|
||||
data_seq_to_patch,
|
||||
get_mesh_id,
|
||||
)
|
||||
|
||||
|
||||
def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
assert sch.timesteps.shape == (20,)
|
||||
diffs = sch.timesteps[1:] - sch.timesteps[:-1]
|
||||
assert torch.all(diffs <= 0) # decreasing
|
||||
|
||||
|
||||
def test_flow_match_scheduler_step_preserves_shape() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
sample = torch.zeros(1, 48, 4, 8, 16)
|
||||
out = sch.step(torch.ones_like(sample), sch.timesteps[0], sample)
|
||||
assert out.shape == sample.shape
|
||||
|
||||
|
||||
def test_flow_match_scheduler_add_noise() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
sample = torch.randn(1, 48, 4, 8, 16)
|
||||
noise = torch.randn_like(sample)
|
||||
noisy = sch.add_noise(sample, noise, sch.timesteps[:4], t_dim=2)
|
||||
assert noisy.shape == sample.shape
|
||||
|
||||
|
||||
def test_get_mesh_id_latent_shape() -> None:
|
||||
grid = get_mesh_id(4, 8, 16, 0, 1, 0)
|
||||
assert grid.shape == (4, 4 * 8 * 16) # (f, h, w, stream) x tokens
|
||||
|
||||
|
||||
def test_get_mesh_id_action_shape() -> None:
|
||||
grid = get_mesh_id(4, 4, 1, 1, 1, 0, action=True)
|
||||
assert grid.shape == (4, 4 * 4 * 1)
|
||||
# Action rows for h/w are sentinel -1.
|
||||
assert torch.all(grid[1] < 0)
|
||||
assert torch.all(grid[2] < 0)
|
||||
|
||||
|
||||
def test_data_seq_to_patch_roundtrip_shape() -> None:
|
||||
b, f, h, w, c = 1, 4, 8, 16, 48
|
||||
seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
|
||||
out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
|
||||
assert out.shape == (b, c, f, h, w)
|
||||
|
||||
|
||||
def test_training_step_reduces_loss_tiny_flex() -> None:
|
||||
"""End-to-end single training step (flow-matching loss -> backward -> AdamW) on a tiny config.
|
||||
|
||||
Exercises the flex-attention training path; requires a CUDA GPU with flex-attention support.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
import pytest
|
||||
|
||||
pytest.skip("training step test requires a CUDA GPU (flex-attention)")
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
cfg = LingBotVAConfig(
|
||||
attn_mode="flex",
|
||||
dtype="bfloat16",
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
action_dim=8,
|
||||
text_dim=32,
|
||||
freq_dim=64,
|
||||
ffn_dim=64,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=24,
|
||||
num_layers=2,
|
||||
frame_chunk_size=2,
|
||||
action_per_frame=4,
|
||||
used_action_channel_ids=[0, 1, 2, 3],
|
||||
obs_cam_keys=[f"{OBS_IMAGES}.image"],
|
||||
device="cuda",
|
||||
)
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64))}
|
||||
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,))}
|
||||
cfg.validate_features()
|
||||
|
||||
policy = LingBotVAPolicy(cfg).to("cuda")
|
||||
policy.train()
|
||||
opt = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4)
|
||||
|
||||
b, fc, apf = 1, cfg.frame_chunk_size, cfg.action_per_frame
|
||||
latents = torch.randn(b, cfg.in_channels, fc, 4, 4, device="cuda", dtype=torch.bfloat16)
|
||||
actions = torch.randn(b, cfg.action_dim, fc, apf, 1, device="cuda", dtype=torch.bfloat16)
|
||||
amask = torch.zeros(cfg.action_dim, device="cuda")
|
||||
amask[cfg.used_action_channel_ids] = 1.0
|
||||
actions_mask = amask.view(1, -1, 1, 1, 1).expand_as(actions)
|
||||
text_emb = torch.randn(b, cfg.max_sequence_length, cfg.text_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
loss, metrics = policy.training_loss_from_streams(latents, actions, actions_mask, text_emb)
|
||||
assert torch.isfinite(loss) and {"latent_loss", "action_loss"} <= set(metrics)
|
||||
loss.backward()
|
||||
assert any(p.grad is not None and torch.isfinite(p.grad).all() for p in policy.get_optim_params())
|
||||
opt.step()
|
||||
@@ -1,88 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline, UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
def _make_config() -> LingBotVAConfig:
|
||||
cfg = LingBotVAConfig(device="cpu")
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
|
||||
cfg.output_features = {}
|
||||
cfg.validate_features()
|
||||
return cfg
|
||||
|
||||
|
||||
def test_make_pre_post_processors_names_and_steps() -> None:
|
||||
cfg = _make_config()
|
||||
pre, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
assert pre.name == POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
assert post.name == POLICY_POSTPROCESSOR_DEFAULT_NAME
|
||||
# Actions are unnormalized by the standard built-in quantile unnormalizer.
|
||||
assert any(isinstance(s, UnnormalizerProcessorStep) for s in post.steps)
|
||||
|
||||
|
||||
def test_freshly_built_postprocessor_is_identity() -> None:
|
||||
# Without action stats the quantile unnormalizer is a no-op (identity passthrough): the real
|
||||
# per-benchmark q01/q99 are restored from the saved checkpoint on load, not hardcoded here.
|
||||
cfg = _make_config()
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
normed = torch.tensor([[0.3, -0.5, 1.0, -1.0, 0.0, 0.7, -0.2]])
|
||||
assert torch.allclose(post(normed), normed, atol=1e-6)
|
||||
|
||||
|
||||
def test_postprocessor_quantile_unnormalization() -> None:
|
||||
# QUANTILES unnormalize maps [-1, 1] -> [q01, q99]: -1 -> q01, +1 -> q99.
|
||||
cfg = _make_config()
|
||||
q01 = [-1.0, -0.5, 0.0, -1.0, -1.0, -1.0, -1.0]
|
||||
q99 = [1.0, 0.5, 2.0, 1.0, 1.0, 1.0, 1.0]
|
||||
stats = {ACTION: {"q01": q01, "q99": q99}}
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=stats)
|
||||
out_lo = post(torch.full((1, 7), -1.0))
|
||||
out_hi = post(torch.full((1, 7), 1.0))
|
||||
assert torch.allclose(out_lo, torch.tensor(q01).unsqueeze(0), atol=1e-4)
|
||||
assert torch.allclose(out_hi, torch.tensor(q99).unsqueeze(0), atol=1e-4)
|
||||
|
||||
|
||||
def test_postprocessor_stats_survive_save_load(tmp_path) -> None:
|
||||
# Regression guard for the Hub mechanism: the q01/q99 stats live in the saved post-processor
|
||||
# state and must round-trip through save_pretrained / from_pretrained.
|
||||
cfg = _make_config()
|
||||
q01 = [-0.6, -0.8, -0.9, -0.1, -0.15, -0.25, -1.0]
|
||||
q99 = [0.9, 0.85, 0.9, 0.17, 0.18, 0.34, 1.0]
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats={ACTION: {"q01": q01, "q99": q99}})
|
||||
post.save_pretrained(tmp_path)
|
||||
loaded = PolicyProcessorPipeline.from_pretrained(
|
||||
tmp_path,
|
||||
config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
out = loaded(torch.full((1, 7), -1.0))
|
||||
assert torch.allclose(out, torch.tensor(q01).unsqueeze(0), atol=1e-4)
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
class MockLazyTensorStateStep(ProcessorStep):
|
||||
"""Mock step whose tensor state is not present in constructor config."""
|
||||
|
||||
def __init__(
|
||||
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
|
||||
):
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
self.tensor_state: torch.Tensor | None = None
|
||||
|
||||
if initial_value is not None:
|
||||
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Return the transition unchanged."""
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return constructor config while intentionally omitting tensor state."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"scale": self.scale,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return tensor state only after it has been initialized or loaded."""
|
||||
if self.tensor_state is None:
|
||||
return {}
|
||||
|
||||
return {"tensor_state": self.tensor_state}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensor state."""
|
||||
self.tensor_state = state["tensor_state"].clone()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Return features unchanged."""
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
|
||||
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
|
||||
"""Registered lazy tensor state step for registry-based serialization tests."""
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
def test_get_config_matches_saved_json():
|
||||
"""Test that in-memory config matches the config written by save_pretrained."""
|
||||
stateless_step = MockStep(name="stateless")
|
||||
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
|
||||
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
|
||||
|
||||
in_memory_config = pipeline.get_config()
|
||||
|
||||
assert pipeline.get_config() == in_memory_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
config_path = Path(tmp_dir) / "memory_pipeline.json"
|
||||
with open(config_path) as file_pointer:
|
||||
saved_config = json.load(file_pointer)
|
||||
|
||||
assert in_memory_config == saved_config
|
||||
assert "state_file" not in in_memory_config["steps"][0]
|
||||
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||
|
||||
|
||||
def test_state_dict_matches_saved_safetensors():
|
||||
"""Test that in-memory state matches the safetensors written by save_pretrained."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
|
||||
|
||||
in_memory_state_dict = pipeline.state_dict()
|
||||
state_filename = "stateful_pipeline_step_0.safetensors"
|
||||
state_key = "stateful_pipeline_step_0"
|
||||
|
||||
assert set(in_memory_state_dict) == {state_key}
|
||||
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
|
||||
|
||||
in_memory_state_dict[state_key]["tensor_state"].add_(1)
|
||||
assert stateful_step.tensor_state is not None
|
||||
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
|
||||
|
||||
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||
|
||||
|
||||
def test_save_pretrained_still_writes_expected_serialization_files():
|
||||
"""Test that save_pretrained keeps the existing config and state filenames."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
save_path = Path(tmp_dir)
|
||||
assert (save_path / "policy_preprocessor.json").exists()
|
||||
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
|
||||
|
||||
|
||||
def test_from_config_round_trips_stateful_pipeline():
|
||||
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
|
||||
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
|
||||
|
||||
|
||||
def test_from_config_round_trips_registered_stateful_pipeline():
|
||||
"""Test that from_config resolves registry steps and loads their named tensor state."""
|
||||
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
|
||||
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
|
||||
|
||||
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
|
||||
assert config["steps"][0]["state_file"] == state_filename
|
||||
assert set(pipeline_state_dict) == {state_key}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
|
||||
assert loaded_step.tensor_state is not None
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
|
||||
|
||||
|
||||
def test_from_config_preserves_state_metadata_for_empty_initial_state():
|
||||
"""Test in-memory loading when rebuilt steps start without tensor state."""
|
||||
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.state_dict() == {}
|
||||
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
|
||||
|
||||
loaded_pipeline.load_state_dict(pipeline_state_dict)
|
||||
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
|
||||
|
||||
|
||||
def test_from_config_applies_overrides_before_state_loading():
|
||||
"""Test that constructor overrides and tensor state loading are separate operations."""
|
||||
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(
|
||||
config,
|
||||
state_dict=pipeline_state_dict,
|
||||
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
|
||||
)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.scale == 5.0
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_missing_expected_state():
|
||||
"""Test loading raises when serialized config expects missing state."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
|
||||
|
||||
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
|
||||
loaded_pipeline.load_state_dict({})
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_unexpected_extra_state():
|
||||
"""Test loading raises on unexpected top-level state keys."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
|
||||
|
||||
with pytest.raises(KeyError, match="extra"):
|
||||
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
|
||||
|
||||
|
||||
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
|
||||
"""Test stateless in-memory serialization and loading."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
|
||||
config = pipeline.get_config()
|
||||
config_without_name = {"steps": config["steps"]}
|
||||
|
||||
assert pipeline.state_dict() == {}
|
||||
assert all("state_file" not in step_entry for step_entry in config["steps"])
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
|
||||
|
||||
assert loaded_pipeline.name == "DataProcessorPipeline"
|
||||
assert loaded_pipeline.state_dict() == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
|
||||
def test_from_config_rejects_non_dict_config(invalid_config):
|
||||
"""Test from_config reports invalid top-level config values cleanly."""
|
||||
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
|
||||
@@ -0,0 +1,518 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RECAP's distributional value function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
BATCH_SIZE = 4
|
||||
NUM_BINS = 201
|
||||
IMAGE_KEY = f"{OBS_IMAGES}.top"
|
||||
|
||||
|
||||
def _make_config(**overrides) -> DistributionalVFConfig:
|
||||
defaults = {
|
||||
"init_from_actor_path": "",
|
||||
"device": "cpu",
|
||||
"image_resolution": (224, 224),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
config = DistributionalVFConfig(**defaults)
|
||||
config.input_features = {
|
||||
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def _make_model():
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel(_make_config())
|
||||
|
||||
|
||||
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
return {
|
||||
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
|
||||
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
|
||||
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
|
||||
"mc_return": torch.rand(batch_size, device=device) * -1.0,
|
||||
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
|
||||
def test_config_registered_in_reward_model_registry():
|
||||
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
|
||||
known = RewardModelConfig.get_known_choices()
|
||||
assert "distributional_value_function" in known
|
||||
|
||||
|
||||
def test_factory_returns_correct_class():
|
||||
"""get_reward_model_class returns DistributionalVFRewardModel."""
|
||||
from lerobot.rewards.factory import get_reward_model_class
|
||||
|
||||
cls = get_reward_model_class("distributional_value_function")
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
assert cls is DistributionalVFRewardModel
|
||||
|
||||
|
||||
def test_make_reward_model_config_factory():
|
||||
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
|
||||
from lerobot.rewards.factory import make_reward_model_config
|
||||
|
||||
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
|
||||
assert isinstance(config, DistributionalVFConfig)
|
||||
assert config.num_value_bins == 101
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_sums_to_one():
|
||||
"""HL-Gauss target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_non_negative():
|
||||
"""HL-Gauss target probabilities are all non-negative."""
|
||||
model = _make_model()
|
||||
targets = torch.linspace(-1.0, 0.0, 10)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert (dist >= 0).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_expected_value_matches():
|
||||
"""E[V] under HL-Gauss distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_handles_2d_input():
|
||||
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (2, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_sums_to_one():
|
||||
"""Dirac delta target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
assert dist.shape == (5, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_at_most_two_nonzero():
|
||||
"""Dirac delta places probability on at most two adjacent bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.7523, -0.0013])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
for i in range(2):
|
||||
assert (dist[i] > 0).sum() <= 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_expected_value_matches():
|
||||
"""E[V] under Dirac delta distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_boundary_values_clamped():
|
||||
"""Values outside support are clamped to boundary bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-1.5, 0.5])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
|
||||
assert dist[0, 0] == 1.0
|
||||
assert dist[1, -1] == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_single_nonzero():
|
||||
"""One-hot target has exactly one non-zero bin per sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
for i in range(4):
|
||||
assert (dist[i] > 0).sum() == 1
|
||||
assert dist[i].sum() == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_nearest_bin():
|
||||
"""One-hot target activates the bin closest to the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
hot_idx = dist[0].argmax()
|
||||
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_terminal_gets_one_hot():
|
||||
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
|
||||
is_terminal = torch.tensor([False, True, False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
|
||||
)
|
||||
|
||||
for i in range(4):
|
||||
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
|
||||
assert (dist[1] > 0).sum() == 1
|
||||
assert (dist[3] > 0).sum() == 1
|
||||
assert (dist[0] > 0).sum() > 2
|
||||
assert (dist[2] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_no_terminal_override_when_disabled():
|
||||
"""When use_one_hot_terminal=False, terminal states use the base method."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3])
|
||||
is_terminal = torch.tensor([False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
|
||||
)
|
||||
|
||||
assert (dist[1] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_has_expected_components():
|
||||
"""Model scaffold contains all architectural components."""
|
||||
model = _make_model()
|
||||
|
||||
assert hasattr(model, "vision_tower")
|
||||
assert hasattr(model, "multi_modal_projector")
|
||||
assert hasattr(model, "token_embedding")
|
||||
assert hasattr(model, "layers")
|
||||
assert hasattr(model, "value_head")
|
||||
assert hasattr(model, "cls_embedding")
|
||||
assert hasattr(model, "norm")
|
||||
assert hasattr(model, "rotary_emb")
|
||||
assert hasattr(model, "bin_centers")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_bin_centers_shape():
|
||||
"""Bin centers buffer has shape (num_value_bins,)."""
|
||||
model = _make_model()
|
||||
assert model.bin_centers.shape == (NUM_BINS,)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_layer_count():
|
||||
"""Transformer has num_hidden_layers (6) layers."""
|
||||
model = _make_model()
|
||||
assert len(model.layers) == 6
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_value_head_output_dim():
|
||||
"""Value head outputs num_value_bins logits."""
|
||||
model = _make_model()
|
||||
assert model.value_head.out_features == NUM_BINS
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_returns_loss_and_dict():
|
||||
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, output_dict = model.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert "loss" in output_dict
|
||||
assert "predicted_value_mean" in output_dict
|
||||
assert "mc_return_mean" in output_dict
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_loss_is_positive():
|
||||
"""Cross-entropy loss is strictly positive for random weights."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
|
||||
assert loss.item() > 0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_returns_correct_shape():
|
||||
"""compute_reward returns [batch_size] tensor of finite float32 values."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=3)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert values.shape == (3,)
|
||||
assert values.dtype == torch.float32
|
||||
assert torch.isfinite(values).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_values_in_support_range():
|
||||
"""Predicted values lie within [value_support_min, value_support_max]."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=8)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert (values >= -1.0 - 0.01).all()
|
||||
assert (values <= 0.0 + 0.01).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_processor_pipeline_produces_expected_keys():
|
||||
"""Full preprocessor pipeline produces tokenized text and processed images."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
config = _make_config()
|
||||
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
|
||||
|
||||
raw_batch = {
|
||||
IMAGE_KEY: torch.rand(3, 224, 224),
|
||||
"task": "pick up the cup",
|
||||
}
|
||||
|
||||
processed = preprocessor(raw_batch)
|
||||
|
||||
assert OBS_LANGUAGE_TOKENS in processed
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed
|
||||
assert IMAGE_KEY in processed
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_value_head():
|
||||
"""Backprop produces non-zero gradients on the value head."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.value_head.weight.grad is not None
|
||||
assert not torch.all(model.value_head.weight.grad == 0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_cls_embedding():
|
||||
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.cls_embedding.grad is not None
|
||||
assert not torch.all(model.cls_embedding.grad == 0)
|
||||
|
||||
|
||||
def test_config_requires_visual_feature():
|
||||
"""validate_features raises if no VISUAL feature is present."""
|
||||
config = DistributionalVFConfig(init_from_actor_path="")
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="VISUAL"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_config_passes_with_visual_feature():
|
||||
"""validate_features succeeds when a VISUAL feature is present."""
|
||||
config = _make_config()
|
||||
config.validate_features()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_save_load_pretrained_roundtrip(tmp_path):
|
||||
"""Saved model can be loaded back with identical weights."""
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
model = _make_model()
|
||||
model._save_pretrained(tmp_path)
|
||||
|
||||
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
orig_sd = model.state_dict()
|
||||
loaded_sd = loaded.state_dict()
|
||||
|
||||
assert set(orig_sd.keys()) == set(loaded_sd.keys())
|
||||
for key in orig_sd:
|
||||
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_normalizes_to_minus_one_one():
|
||||
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 224, 224, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.min() >= -1.0 - 1e-5
|
||||
assert image.max() <= 1.0 + 1e-5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_resizes_with_pad():
|
||||
"""Image preprocessor resizes non-square images to target resolution."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 480, 640, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.shape[1:3] == (224, 224)
|
||||
|
||||
|
||||
def test_task_prompt_formats_correctly():
|
||||
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: pick up the cup."
|
||||
|
||||
|
||||
def test_task_prompt_handles_string_input():
|
||||
"""Task prompt step accepts a plain string (not just a list)."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: open drawer."
|
||||
|
||||
|
||||
def test_task_prompt_raises_on_missing_task():
|
||||
"""Task prompt step raises ValueError when task key is absent."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="No task found"):
|
||||
step(transition)
|
||||
Reference in New Issue
Block a user