feat(lingbot_va): implement training / fine-tuning (flow-matching loss)

- Implement LingBotVAPolicy.forward(): dual-stream flow-matching training loss
  (latent + action, timestep-weighted, action-masked) ported from upstream train.py;
  VAE-encodes camera clips, UMT5-encodes the task, noises both streams, runs the
  block-causal flex-attention training pass (forward_train).
- training_loss_from_streams() core + _build_training_streams() data prep (action
  scatter into the 30-d space, multi-frame VAE encode incl. robotwin_tshape).
- get_optim_params returns only trainable transformer params (LoRA/PEFT friendly);
  VAE/UMT5 stay frozen. Training needs attn_mode='flex'.
- Add a tiny-config single-training-step test (forward->loss->backward->AdamW) and a
  Training/fine-tuning section in the docs.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-06 15:38:41 +02:00
parent e3deff00ad
commit 71aacda05e
3 changed files with 305 additions and 16 deletions
+28 -4
View File
@@ -32,10 +32,8 @@ fed back into the KV cache as the chunk is executed (closed-loop world modeling)
- Autoregressive dual-stream inference behind the standard `select_action` interface
(single-environment eval, `--eval.batch_size=1`).
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
- Evaluation with `lerobot-eval` on the LIBERO benchmark.
Training (the flow-matching dual-stream loss + latent dataset) is part of a follow-up
training port and is not yet wired into `lerobot-train`.
- Evaluation with `lerobot-eval` on LIBERO and RoboTwin.
- Training / fine-tuning via the dual-stream flow-matching loss (`policy.forward`), see below.
## Installation
@@ -105,6 +103,32 @@ Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicte
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
The same flag works for the periodic eval during `lerobot-train`.
## Training / fine-tuning
`LingBotVAPolicy.forward(batch)` implements the dual-stream **flow-matching** loss
(`latent_loss + action_loss`, timestep-weighted, action-masked) from the paper: it VAE-encodes
the camera clips into video latents, UMT5-encodes the task, noises both streams, runs the
transformer's block-causal training pass and returns `(loss, metrics)`. Optimizer preset is AdamW
with a linear-warmup-then-constant schedule (matching upstream).
Requirements:
- The block-causal masks use PyTorch **flex-attention**, so build the policy with
`--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only).
- The full 5B DiT does not fit a single 2432 GB GPU under AdamW; fine-tune with **LoRA**
(`--policy.use_peft=true`) and/or optimizer offload. `get_optim_params` returns only the
trainable (e.g. adapter) parameters; the VAE + UMT5 text encoder stay frozen.
```bash
lerobot-train \
--policy.path=pepijn223/lingbot_va_libero_long --policy.attn_mode=flex \
--policy.use_peft=true \
--dataset.repo_id=<your LeRobot-format dataset> \
--batch_size=1 --steps=... --output_dir=outputs/train/lingbot_va
```
The dataset must provide camera clips (a temporal window per camera, VAE-encoded to
`frame_chunk_size` latent frames) and `frame_chunk_size * action_per_frame` action steps per item.
## Inference Hyperparameters (LIBERO)
| Key | Value |
@@ -51,6 +51,7 @@ from einops import rearrange
from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import require_package
from .configuration_lingbot_va import LingBotVAConfig
@@ -1132,6 +1133,17 @@ def _torch_dtype(name: str) -> torch.dtype:
return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[name]
def _sample_timestep_id(
batch_size: int = 1,
min_timestep_bd: float = 0.0,
max_timestep_bd: float = 1.0,
num_train_timesteps: int = 1000,
) -> torch.Tensor:
"""Sample per-frame flow-matching timestep ids (upstream ``utils.sample_timestep_id``)."""
u = torch.rand(size=[batch_size]) * (max_timestep_bd - min_timestep_bd) + min_timestep_bd
return (u * num_train_timesteps).clamp(min=0, max=num_train_timesteps - 1).to(torch.int64)
class LingBotVAPolicy(PreTrainedPolicy):
"""LeRobot wrapper for the LingBot-VA autoregressive video-action world model."""
@@ -1226,8 +1238,9 @@ class LingBotVAPolicy(PreTrainedPolicy):
# PreTrainedPolicy API
# ------------------------------------------------------------------
def get_optim_params(self) -> dict:
# Only the transformer is trainable; the VAE / text encoder stay frozen.
return self.transformer.parameters()
# Only the transformer is trainable; the VAE / text encoder stay frozen (kept outside the
# nn.Module registry). With PEFT/LoRA this naturally returns just the adapter params.
return [p for p in self.transformer.parameters() if p.requires_grad]
def reset(self):
"""Reset all per-episode streaming state (KV cache, queues, frame counter)."""
@@ -1270,19 +1283,215 @@ class LingBotVAPolicy(PreTrainedPolicy):
if "streaming_vae_half" in self._frozen:
self._frozen["streaming_vae_half"].clear_cache()
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Training loss. Implemented in the LingBot-VA training PR (Phase 7).
# ------------------------------------------------------------------
# Training (flow-matching dual-stream loss). Requires attn_mode="flex".
# ------------------------------------------------------------------
def _ensure_train_schedulers(self):
if getattr(self, "_train_sched_latent", None) is None:
cfg = self.config
self._train_sched_latent = FlowMatchScheduler(
shift=cfg.snr_shift, sigma_min=0.0, extra_one_step=True
)
self._train_sched_latent.set_timesteps(1000, training=True)
self._train_sched_action = FlowMatchScheduler(
shift=cfg.action_snr_shift, sigma_min=0.0, extra_one_step=True
)
self._train_sched_action.set_timesteps(1000, training=True)
The flow-matching dual-stream loss needs the pre-extracted latent dataset
(see ``LatentLeRobotDataset`` upstream) and ``attn_mode='flex'``; it is intentionally
not part of this inference-focused integration.
"""
raise NotImplementedError(
"LingBot-VA training (flow-matching dual-stream loss) is part of the training port "
"(Phase 7 / PR #2) and is not implemented in this inference integration. "
"Use this policy for evaluation / inference."
@torch.no_grad()
def _add_noise_stream(self, latent, scheduler, action_mask, action_mode, noisy_cond_prob):
"""Flow-matching noising of one stream (port of upstream ``Trainer._add_noise``)."""
device = latent.device
B, _C, F, _H, _W = latent.shape
p = self.config.patch_size
patch_f, patch_h, patch_w = (1, 1, 1) if action_mode else (p[0], p[1], p[2])
ts_ids = _sample_timestep_id(F, num_train_timesteps=scheduler.num_train_timesteps)
noise = torch.zeros_like(latent).normal_()
timesteps = scheduler.timesteps[ts_ids].to(device)
noisy_latents = scheduler.add_noise(latent, noise, timesteps, t_dim=2)
targets = scheduler.training_target(latent, noise, timesteps)
grid_id = (
get_mesh_id(
latent.shape[-3] // patch_f,
latent.shape[-2] // patch_h,
latent.shape[-1] // patch_w,
t=1 if action_mode else 0,
f_w=1,
f_shift=0,
action=action_mode,
)
.to(device)[None]
.repeat(B, 1, 1)
)
if torch.rand(1).item() < noisy_cond_prob:
cond_ids = _sample_timestep_id(
F, min_timestep_bd=0.5, max_timestep_bd=1.0, num_train_timesteps=scheduler.num_train_timesteps
)
cond_noise = torch.zeros_like(latent).normal_()
cond_timesteps = scheduler.timesteps[cond_ids].to(device)
latent = scheduler.add_noise(latent, cond_noise, cond_timesteps, t_dim=2)
else:
cond_timesteps = torch.zeros_like(timesteps)
if action_mask is not None:
noisy_latents = noisy_latents * action_mask.float()
targets = targets * action_mask.float()
latent = latent * action_mask.float()
return {
"timesteps": timesteps[None].repeat(B, 1),
"noisy_latents": noisy_latents,
"targets": targets,
"latent": latent,
"cond_timesteps": cond_timesteps[None].repeat(B, 1),
"grid_id": grid_id,
}
def _flow_matching_loss(self, input_dict, pred):
"""Dual-stream flow-matching loss (port of upstream ``Trainer.compute_loss``)."""
latent_pred, action_pred = pred
ld, ad = input_dict["latent_dict"], input_dict["action_dict"]
action_pred = rearrange(action_pred, "b (f n) c -> b c f n 1", f=ad["targets"].shape[-3])
latent_pred = data_seq_to_patch(
self.config.patch_size,
latent_pred,
ld["targets"].shape[-3],
ld["targets"].shape[-2],
ld["targets"].shape[-1],
batch_size=latent_pred.shape[0],
)
Bn, Fn = ld["timesteps"].shape
lw = self._train_sched_latent.training_weight(ld["timesteps"].flatten()).reshape(Bn, Fn)
aw = self._train_sched_action.training_weight(ad["timesteps"].flatten()).reshape(Bn, Fn)
latent_loss = F.mse_loss(latent_pred.float(), ld["targets"].float().detach(), reduction="none")
latent_loss = (
(latent_loss * lw[:, None, :, None, None]).permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1)
)
latent_loss = (latent_loss.sum(dim=1) / (torch.ones_like(latent_loss).sum(dim=1) + 1e-6)).mean()
amask = ad["actions_mask"].float()
action_loss = F.mse_loss(action_pred.float(), ad["targets"].float().detach(), reduction="none")
action_loss = (
(action_loss * aw[:, None, :, None, None] * amask).permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1)
)
amask_f = amask.permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1)
action_loss = (action_loss.sum(dim=1) / (amask_f.sum(dim=1) + 1e-6)).mean()
return latent_loss, action_loss
def training_loss_from_streams(self, latents, actions, actions_mask, text_emb):
"""Core dual-stream training loss given prepared latents / actions / text embeddings.
``latents``: ``[B, in_channels, F, h, w]`` (normalized video latents).
``actions`` / ``actions_mask``: ``[B, action_dim, F, action_per_frame, 1]``.
``text_emb``: ``[B, seq_len, text_dim]``. Returns ``(loss, {latent_loss, action_loss})``.
"""
if self.config.attn_mode != "flex":
raise ValueError(
"LingBot-VA training requires attn_mode='flex' (block-causal flow-matching masks). "
"Load/convert the policy with --policy.attn_mode=flex for training/fine-tuning."
)
self._ensure_train_schedulers()
latent_dict = self._add_noise_stream(
latents, self._train_sched_latent, action_mask=None, action_mode=False, noisy_cond_prob=0.5
)
action_dict = self._add_noise_stream(
actions, self._train_sched_action, action_mask=actions_mask, action_mode=True, noisy_cond_prob=0.0
)
latent_dict["text_emb"] = text_emb
action_dict["text_emb"] = text_emb
action_dict["actions_mask"] = actions_mask
input_dict = {
"latent_dict": latent_dict,
"action_dict": action_dict,
"chunk_size": int(torch.randint(1, 5, (1,)).item()),
"window_size": int(torch.randint(4, 65, (1,)).item()),
}
pred = self.transformer(input_dict, train_mode=True)
latent_loss, action_loss = self._flow_matching_loss(input_dict, pred)
loss = latent_loss + action_loss
return loss, {"latent_loss": latent_loss.detach(), "action_loss": action_loss.detach()}
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Training forward: dual-stream flow-matching loss.
Builds the (video-latent, action, text) training streams from a LeRobot batch
(VAE-encoding the camera frames and UMT5-encoding the task), then runs the flow-matching
dual-stream loss. Requires the policy to be built with ``attn_mode='flex'``.
"""
self._ensure_frozen_modules()
latents, actions, actions_mask, text_emb = self._build_training_streams(batch)
return self.training_loss_from_streams(latents, actions, actions_mask, text_emb)
@torch.no_grad()
def _build_training_streams(self, batch):
"""Build (latents, actions, actions_mask, text_emb) from a LeRobot training batch.
Camera frames per ``obs_cam_keys`` are expected as a temporal clip ``[B, C, T, H, W]`` (or
``[B, T, C, H, W]``); they are VAE-encoded into ``F = T / temporal_downsample`` latent frames.
Actions ``[B, F*action_per_frame, n_used]`` are scattered into the model's ``action_dim`` space.
"""
cfg = self.config
device = cfg.device
# ---- text embeddings ----
task = batch.get("task")
if isinstance(task, str):
task = [task]
text_emb = self._get_t5_prompt_embeds(list(task), cfg.max_sequence_length)
# ---- video latents (VAE-encode the camera clips) ----
latents = self._encode_training_latents(batch)
# ---- actions -> [B, action_dim, F, action_per_frame, 1] ----
act = batch[ACTION].to(device) # [B, F*apf, n_used]
B = act.shape[0]
used = cfg.used_action_channel_ids
apf, Fc = cfg.action_per_frame, cfg.frame_chunk_size
act = act[:, : Fc * apf].reshape(B, Fc, apf, len(used)).permute(0, 3, 1, 2) # [B, n_used, F, apf]
full = act.new_zeros(B, cfg.action_dim, Fc, apf)
idx = torch.as_tensor(used, device=device)
full[:, idx] = act
actions = full.unsqueeze(-1).to(self.dtype) # [B, action_dim, F, apf, 1]
mask = torch.zeros(cfg.action_dim, device=device, dtype=self.dtype)
mask[idx] = 1.0
actions_mask = mask.view(1, -1, 1, 1, 1).expand_as(actions)
return latents, actions, actions_mask, text_emb
@torch.no_grad()
def _encode_training_latents(self, batch) -> Tensor:
"""VAE-encode the per-camera training clips into normalized video latents [B, C, F, h, w]."""
vae_device = next(self._vae.parameters()).device
def _clip(key):
x = batch[key].to(vae_device)
if x.dim() == 4: # [B, C, H, W] -> single frame clip
x = x.unsqueeze(2)
elif x.shape[1] not in (1, 3) and x.shape[2] in (1, 3): # [B, T, C, H, W] -> [B, C, T, H, W]
x = x.permute(0, 2, 1, 3, 4)
return x.contiguous()
def _encode(x, size):
b, c, t = x.shape[:3]
x = F.interpolate(x.flatten(0, 1).float(), size=size, mode="bilinear", align_corners=False)
x = (x.view(b, c, t, *size) * 2.0 - 1.0).to(self.dtype)
mu = self._vae.encode(x).latent_dist.mode() # [B, z_dim, F, h, w]
mean = torch.tensor(self._vae.config.latents_mean).view(1, -1, 1, 1, 1).to(mu.device)
inv_std = (1.0 / torch.tensor(self._vae.config.latents_std)).view(1, -1, 1, 1, 1).to(mu.device)
return ((mu.float() - mean) * inv_std).to(mu)
keys = self.config.obs_cam_keys
if self.config.camera_layout == "robotwin_tshape":
h, w = self.config.height, self.config.width
head = _encode(_clip(keys[0]), (h, w))
left = _encode(_clip(keys[1]), (h // 2, w // 2))
right = _encode(_clip(keys[2]), (h // 2, w // 2))
return torch.cat([torch.cat([left, right], dim=-1), head], dim=-2).to(self.config.device)
per_cam = [_encode(_clip(k), (self.config.height, self.config.width)) for k in keys]
return torch.cat(per_cam, dim=-1).to(self.config.device)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Return one action, refilling the chunk (and feeding back observed keyframes) as needed.
+56
View File
@@ -73,3 +73,59 @@ def test_data_seq_to_patch_roundtrip_shape() -> None:
seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
assert out.shape == (b, c, f, h, w)
def test_training_step_reduces_loss_tiny_flex() -> None:
"""End-to-end single training step (flow-matching loss -> backward -> AdamW) on a tiny config.
Exercises the flex-attention training path; requires a CUDA GPU with flex-attention support.
"""
if not torch.cuda.is_available():
import pytest
pytest.skip("training step test requires a CUDA GPU (flex-attention)")
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES
cfg = LingBotVAConfig(
attn_mode="flex",
dtype="bfloat16",
in_channels=16,
out_channels=16,
action_dim=8,
text_dim=32,
freq_dim=64,
ffn_dim=64,
num_attention_heads=2,
attention_head_dim=24,
num_layers=2,
frame_chunk_size=2,
action_per_frame=4,
used_action_channel_ids=[0, 1, 2, 3],
obs_cam_keys=[f"{OBS_IMAGES}.image"],
device="cuda",
)
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64))}
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,))}
cfg.validate_features()
policy = LingBotVAPolicy(cfg).to("cuda")
policy.train()
opt = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4)
b, fc, apf = 1, cfg.frame_chunk_size, cfg.action_per_frame
latents = torch.randn(b, cfg.in_channels, fc, 4, 4, device="cuda", dtype=torch.bfloat16)
actions = torch.randn(b, cfg.action_dim, fc, apf, 1, device="cuda", dtype=torch.bfloat16)
amask = torch.zeros(cfg.action_dim, device="cuda")
amask[cfg.used_action_channel_ids] = 1.0
actions_mask = amask.view(1, -1, 1, 1, 1).expand_as(actions)
text_emb = torch.randn(b, cfg.max_sequence_length, cfg.text_dim, device="cuda", dtype=torch.bfloat16)
loss, metrics = policy.training_loss_from_streams(latents, actions, actions_mask, text_emb)
assert torch.isfinite(loss) and {"latent_loss", "action_loss"} <= set(metrics)
loss.backward()
assert any(p.grad is not None and torch.isfinite(p.grad).all() for p in policy.get_optim_params())
opt.step()