mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
feat(lingbot_va): implement training / fine-tuning (flow-matching loss)
- Implement LingBotVAPolicy.forward(): dual-stream flow-matching training loss (latent + action, timestep-weighted, action-masked) ported from upstream train.py; VAE-encodes camera clips, UMT5-encodes the task, noises both streams, runs the block-causal flex-attention training pass (forward_train). - training_loss_from_streams() core + _build_training_streams() data prep (action scatter into the 30-d space, multi-frame VAE encode incl. robotwin_tshape). - get_optim_params returns only trainable transformer params (LoRA/PEFT friendly); VAE/UMT5 stay frozen. Training needs attn_mode='flex'. - Add a tiny-config single-training-step test (forward->loss->backward->AdamW) and a Training/fine-tuning section in the docs. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -32,10 +32,8 @@ fed back into the KV cache as the chunk is executed (closed-loop world modeling)
|
||||
- Autoregressive dual-stream inference behind the standard `select_action` interface
|
||||
(single-environment eval, `--eval.batch_size=1`).
|
||||
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
|
||||
- Evaluation with `lerobot-eval` on the LIBERO benchmark.
|
||||
|
||||
Training (the flow-matching dual-stream loss + latent dataset) is part of a follow-up
|
||||
training port and is not yet wired into `lerobot-train`.
|
||||
- Evaluation with `lerobot-eval` on LIBERO and RoboTwin.
|
||||
- Training / fine-tuning via the dual-stream flow-matching loss (`policy.forward`), see below.
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -105,6 +103,32 @@ Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicte
|
||||
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
|
||||
The same flag works for the periodic eval during `lerobot-train`.
|
||||
|
||||
## Training / fine-tuning
|
||||
|
||||
`LingBotVAPolicy.forward(batch)` implements the dual-stream **flow-matching** loss
|
||||
(`latent_loss + action_loss`, timestep-weighted, action-masked) from the paper: it VAE-encodes
|
||||
the camera clips into video latents, UMT5-encodes the task, noises both streams, runs the
|
||||
transformer's block-causal training pass and returns `(loss, metrics)`. Optimizer preset is AdamW
|
||||
with a linear-warmup-then-constant schedule (matching upstream).
|
||||
|
||||
Requirements:
|
||||
- The block-causal masks use PyTorch **flex-attention**, so build the policy with
|
||||
`--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only).
|
||||
- The full 5B DiT does not fit a single 24–32 GB GPU under AdamW; fine-tune with **LoRA**
|
||||
(`--policy.use_peft=true`) and/or optimizer offload. `get_optim_params` returns only the
|
||||
trainable (e.g. adapter) parameters; the VAE + UMT5 text encoder stay frozen.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=pepijn223/lingbot_va_libero_long --policy.attn_mode=flex \
|
||||
--policy.use_peft=true \
|
||||
--dataset.repo_id=<your LeRobot-format dataset> \
|
||||
--batch_size=1 --steps=... --output_dir=outputs/train/lingbot_va
|
||||
```
|
||||
|
||||
The dataset must provide camera clips (a temporal window per camera, VAE-encoded to
|
||||
`frame_chunk_size` latent frames) and `frame_chunk_size * action_per_frame` action steps per item.
|
||||
|
||||
## Inference Hyperparameters (LIBERO)
|
||||
|
||||
| Key | Value |
|
||||
|
||||
@@ -51,6 +51,7 @@ from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
@@ -1132,6 +1133,17 @@ def _torch_dtype(name: str) -> torch.dtype:
|
||||
return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[name]
|
||||
|
||||
|
||||
def _sample_timestep_id(
|
||||
batch_size: int = 1,
|
||||
min_timestep_bd: float = 0.0,
|
||||
max_timestep_bd: float = 1.0,
|
||||
num_train_timesteps: int = 1000,
|
||||
) -> torch.Tensor:
|
||||
"""Sample per-frame flow-matching timestep ids (upstream ``utils.sample_timestep_id``)."""
|
||||
u = torch.rand(size=[batch_size]) * (max_timestep_bd - min_timestep_bd) + min_timestep_bd
|
||||
return (u * num_train_timesteps).clamp(min=0, max=num_train_timesteps - 1).to(torch.int64)
|
||||
|
||||
|
||||
class LingBotVAPolicy(PreTrainedPolicy):
|
||||
"""LeRobot wrapper for the LingBot-VA autoregressive video-action world model."""
|
||||
|
||||
@@ -1226,8 +1238,9 @@ class LingBotVAPolicy(PreTrainedPolicy):
|
||||
# PreTrainedPolicy API
|
||||
# ------------------------------------------------------------------
|
||||
def get_optim_params(self) -> dict:
|
||||
# Only the transformer is trainable; the VAE / text encoder stay frozen.
|
||||
return self.transformer.parameters()
|
||||
# Only the transformer is trainable; the VAE / text encoder stay frozen (kept outside the
|
||||
# nn.Module registry). With PEFT/LoRA this naturally returns just the adapter params.
|
||||
return [p for p in self.transformer.parameters() if p.requires_grad]
|
||||
|
||||
def reset(self):
|
||||
"""Reset all per-episode streaming state (KV cache, queues, frame counter)."""
|
||||
@@ -1270,19 +1283,215 @@ class LingBotVAPolicy(PreTrainedPolicy):
|
||||
if "streaming_vae_half" in self._frozen:
|
||||
self._frozen["streaming_vae_half"].clear_cache()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""Training loss. Implemented in the LingBot-VA training PR (Phase 7).
|
||||
# ------------------------------------------------------------------
|
||||
# Training (flow-matching dual-stream loss). Requires attn_mode="flex".
|
||||
# ------------------------------------------------------------------
|
||||
def _ensure_train_schedulers(self):
|
||||
if getattr(self, "_train_sched_latent", None) is None:
|
||||
cfg = self.config
|
||||
self._train_sched_latent = FlowMatchScheduler(
|
||||
shift=cfg.snr_shift, sigma_min=0.0, extra_one_step=True
|
||||
)
|
||||
self._train_sched_latent.set_timesteps(1000, training=True)
|
||||
self._train_sched_action = FlowMatchScheduler(
|
||||
shift=cfg.action_snr_shift, sigma_min=0.0, extra_one_step=True
|
||||
)
|
||||
self._train_sched_action.set_timesteps(1000, training=True)
|
||||
|
||||
The flow-matching dual-stream loss needs the pre-extracted latent dataset
|
||||
(see ``LatentLeRobotDataset`` upstream) and ``attn_mode='flex'``; it is intentionally
|
||||
not part of this inference-focused integration.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"LingBot-VA training (flow-matching dual-stream loss) is part of the training port "
|
||||
"(Phase 7 / PR #2) and is not implemented in this inference integration. "
|
||||
"Use this policy for evaluation / inference."
|
||||
@torch.no_grad()
|
||||
def _add_noise_stream(self, latent, scheduler, action_mask, action_mode, noisy_cond_prob):
|
||||
"""Flow-matching noising of one stream (port of upstream ``Trainer._add_noise``)."""
|
||||
device = latent.device
|
||||
B, _C, F, _H, _W = latent.shape
|
||||
p = self.config.patch_size
|
||||
patch_f, patch_h, patch_w = (1, 1, 1) if action_mode else (p[0], p[1], p[2])
|
||||
|
||||
ts_ids = _sample_timestep_id(F, num_train_timesteps=scheduler.num_train_timesteps)
|
||||
noise = torch.zeros_like(latent).normal_()
|
||||
timesteps = scheduler.timesteps[ts_ids].to(device)
|
||||
noisy_latents = scheduler.add_noise(latent, noise, timesteps, t_dim=2)
|
||||
targets = scheduler.training_target(latent, noise, timesteps)
|
||||
|
||||
grid_id = (
|
||||
get_mesh_id(
|
||||
latent.shape[-3] // patch_f,
|
||||
latent.shape[-2] // patch_h,
|
||||
latent.shape[-1] // patch_w,
|
||||
t=1 if action_mode else 0,
|
||||
f_w=1,
|
||||
f_shift=0,
|
||||
action=action_mode,
|
||||
)
|
||||
.to(device)[None]
|
||||
.repeat(B, 1, 1)
|
||||
)
|
||||
|
||||
if torch.rand(1).item() < noisy_cond_prob:
|
||||
cond_ids = _sample_timestep_id(
|
||||
F, min_timestep_bd=0.5, max_timestep_bd=1.0, num_train_timesteps=scheduler.num_train_timesteps
|
||||
)
|
||||
cond_noise = torch.zeros_like(latent).normal_()
|
||||
cond_timesteps = scheduler.timesteps[cond_ids].to(device)
|
||||
latent = scheduler.add_noise(latent, cond_noise, cond_timesteps, t_dim=2)
|
||||
else:
|
||||
cond_timesteps = torch.zeros_like(timesteps)
|
||||
|
||||
if action_mask is not None:
|
||||
noisy_latents = noisy_latents * action_mask.float()
|
||||
targets = targets * action_mask.float()
|
||||
latent = latent * action_mask.float()
|
||||
|
||||
return {
|
||||
"timesteps": timesteps[None].repeat(B, 1),
|
||||
"noisy_latents": noisy_latents,
|
||||
"targets": targets,
|
||||
"latent": latent,
|
||||
"cond_timesteps": cond_timesteps[None].repeat(B, 1),
|
||||
"grid_id": grid_id,
|
||||
}
|
||||
|
||||
def _flow_matching_loss(self, input_dict, pred):
|
||||
"""Dual-stream flow-matching loss (port of upstream ``Trainer.compute_loss``)."""
|
||||
latent_pred, action_pred = pred
|
||||
ld, ad = input_dict["latent_dict"], input_dict["action_dict"]
|
||||
action_pred = rearrange(action_pred, "b (f n) c -> b c f n 1", f=ad["targets"].shape[-3])
|
||||
latent_pred = data_seq_to_patch(
|
||||
self.config.patch_size,
|
||||
latent_pred,
|
||||
ld["targets"].shape[-3],
|
||||
ld["targets"].shape[-2],
|
||||
ld["targets"].shape[-1],
|
||||
batch_size=latent_pred.shape[0],
|
||||
)
|
||||
Bn, Fn = ld["timesteps"].shape
|
||||
lw = self._train_sched_latent.training_weight(ld["timesteps"].flatten()).reshape(Bn, Fn)
|
||||
aw = self._train_sched_action.training_weight(ad["timesteps"].flatten()).reshape(Bn, Fn)
|
||||
|
||||
latent_loss = F.mse_loss(latent_pred.float(), ld["targets"].float().detach(), reduction="none")
|
||||
latent_loss = (
|
||||
(latent_loss * lw[:, None, :, None, None]).permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1)
|
||||
)
|
||||
latent_loss = (latent_loss.sum(dim=1) / (torch.ones_like(latent_loss).sum(dim=1) + 1e-6)).mean()
|
||||
|
||||
amask = ad["actions_mask"].float()
|
||||
action_loss = F.mse_loss(action_pred.float(), ad["targets"].float().detach(), reduction="none")
|
||||
action_loss = (
|
||||
(action_loss * aw[:, None, :, None, None] * amask).permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1)
|
||||
)
|
||||
amask_f = amask.permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1)
|
||||
action_loss = (action_loss.sum(dim=1) / (amask_f.sum(dim=1) + 1e-6)).mean()
|
||||
return latent_loss, action_loss
|
||||
|
||||
def training_loss_from_streams(self, latents, actions, actions_mask, text_emb):
|
||||
"""Core dual-stream training loss given prepared latents / actions / text embeddings.
|
||||
|
||||
``latents``: ``[B, in_channels, F, h, w]`` (normalized video latents).
|
||||
``actions`` / ``actions_mask``: ``[B, action_dim, F, action_per_frame, 1]``.
|
||||
``text_emb``: ``[B, seq_len, text_dim]``. Returns ``(loss, {latent_loss, action_loss})``.
|
||||
"""
|
||||
if self.config.attn_mode != "flex":
|
||||
raise ValueError(
|
||||
"LingBot-VA training requires attn_mode='flex' (block-causal flow-matching masks). "
|
||||
"Load/convert the policy with --policy.attn_mode=flex for training/fine-tuning."
|
||||
)
|
||||
self._ensure_train_schedulers()
|
||||
latent_dict = self._add_noise_stream(
|
||||
latents, self._train_sched_latent, action_mask=None, action_mode=False, noisy_cond_prob=0.5
|
||||
)
|
||||
action_dict = self._add_noise_stream(
|
||||
actions, self._train_sched_action, action_mask=actions_mask, action_mode=True, noisy_cond_prob=0.0
|
||||
)
|
||||
latent_dict["text_emb"] = text_emb
|
||||
action_dict["text_emb"] = text_emb
|
||||
action_dict["actions_mask"] = actions_mask
|
||||
input_dict = {
|
||||
"latent_dict": latent_dict,
|
||||
"action_dict": action_dict,
|
||||
"chunk_size": int(torch.randint(1, 5, (1,)).item()),
|
||||
"window_size": int(torch.randint(4, 65, (1,)).item()),
|
||||
}
|
||||
pred = self.transformer(input_dict, train_mode=True)
|
||||
latent_loss, action_loss = self._flow_matching_loss(input_dict, pred)
|
||||
loss = latent_loss + action_loss
|
||||
return loss, {"latent_loss": latent_loss.detach(), "action_loss": action_loss.detach()}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""Training forward: dual-stream flow-matching loss.
|
||||
|
||||
Builds the (video-latent, action, text) training streams from a LeRobot batch
|
||||
(VAE-encoding the camera frames and UMT5-encoding the task), then runs the flow-matching
|
||||
dual-stream loss. Requires the policy to be built with ``attn_mode='flex'``.
|
||||
"""
|
||||
self._ensure_frozen_modules()
|
||||
latents, actions, actions_mask, text_emb = self._build_training_streams(batch)
|
||||
return self.training_loss_from_streams(latents, actions, actions_mask, text_emb)
|
||||
|
||||
@torch.no_grad()
|
||||
def _build_training_streams(self, batch):
|
||||
"""Build (latents, actions, actions_mask, text_emb) from a LeRobot training batch.
|
||||
|
||||
Camera frames per ``obs_cam_keys`` are expected as a temporal clip ``[B, C, T, H, W]`` (or
|
||||
``[B, T, C, H, W]``); they are VAE-encoded into ``F = T / temporal_downsample`` latent frames.
|
||||
Actions ``[B, F*action_per_frame, n_used]`` are scattered into the model's ``action_dim`` space.
|
||||
"""
|
||||
cfg = self.config
|
||||
device = cfg.device
|
||||
# ---- text embeddings ----
|
||||
task = batch.get("task")
|
||||
if isinstance(task, str):
|
||||
task = [task]
|
||||
text_emb = self._get_t5_prompt_embeds(list(task), cfg.max_sequence_length)
|
||||
|
||||
# ---- video latents (VAE-encode the camera clips) ----
|
||||
latents = self._encode_training_latents(batch)
|
||||
|
||||
# ---- actions -> [B, action_dim, F, action_per_frame, 1] ----
|
||||
act = batch[ACTION].to(device) # [B, F*apf, n_used]
|
||||
B = act.shape[0]
|
||||
used = cfg.used_action_channel_ids
|
||||
apf, Fc = cfg.action_per_frame, cfg.frame_chunk_size
|
||||
act = act[:, : Fc * apf].reshape(B, Fc, apf, len(used)).permute(0, 3, 1, 2) # [B, n_used, F, apf]
|
||||
full = act.new_zeros(B, cfg.action_dim, Fc, apf)
|
||||
idx = torch.as_tensor(used, device=device)
|
||||
full[:, idx] = act
|
||||
actions = full.unsqueeze(-1).to(self.dtype) # [B, action_dim, F, apf, 1]
|
||||
mask = torch.zeros(cfg.action_dim, device=device, dtype=self.dtype)
|
||||
mask[idx] = 1.0
|
||||
actions_mask = mask.view(1, -1, 1, 1, 1).expand_as(actions)
|
||||
return latents, actions, actions_mask, text_emb
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_training_latents(self, batch) -> Tensor:
|
||||
"""VAE-encode the per-camera training clips into normalized video latents [B, C, F, h, w]."""
|
||||
vae_device = next(self._vae.parameters()).device
|
||||
|
||||
def _clip(key):
|
||||
x = batch[key].to(vae_device)
|
||||
if x.dim() == 4: # [B, C, H, W] -> single frame clip
|
||||
x = x.unsqueeze(2)
|
||||
elif x.shape[1] not in (1, 3) and x.shape[2] in (1, 3): # [B, T, C, H, W] -> [B, C, T, H, W]
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
return x.contiguous()
|
||||
|
||||
def _encode(x, size):
|
||||
b, c, t = x.shape[:3]
|
||||
x = F.interpolate(x.flatten(0, 1).float(), size=size, mode="bilinear", align_corners=False)
|
||||
x = (x.view(b, c, t, *size) * 2.0 - 1.0).to(self.dtype)
|
||||
mu = self._vae.encode(x).latent_dist.mode() # [B, z_dim, F, h, w]
|
||||
mean = torch.tensor(self._vae.config.latents_mean).view(1, -1, 1, 1, 1).to(mu.device)
|
||||
inv_std = (1.0 / torch.tensor(self._vae.config.latents_std)).view(1, -1, 1, 1, 1).to(mu.device)
|
||||
return ((mu.float() - mean) * inv_std).to(mu)
|
||||
|
||||
keys = self.config.obs_cam_keys
|
||||
if self.config.camera_layout == "robotwin_tshape":
|
||||
h, w = self.config.height, self.config.width
|
||||
head = _encode(_clip(keys[0]), (h, w))
|
||||
left = _encode(_clip(keys[1]), (h // 2, w // 2))
|
||||
right = _encode(_clip(keys[2]), (h // 2, w // 2))
|
||||
return torch.cat([torch.cat([left, right], dim=-1), head], dim=-2).to(self.config.device)
|
||||
per_cam = [_encode(_clip(k), (self.config.height, self.config.width)) for k in keys]
|
||||
return torch.cat(per_cam, dim=-1).to(self.config.device)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
"""Return one action, refilling the chunk (and feeding back observed keyframes) as needed.
|
||||
|
||||
@@ -73,3 +73,59 @@ def test_data_seq_to_patch_roundtrip_shape() -> None:
|
||||
seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
|
||||
out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
|
||||
assert out.shape == (b, c, f, h, w)
|
||||
|
||||
|
||||
def test_training_step_reduces_loss_tiny_flex() -> None:
|
||||
"""End-to-end single training step (flow-matching loss -> backward -> AdamW) on a tiny config.
|
||||
|
||||
Exercises the flex-attention training path; requires a CUDA GPU with flex-attention support.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
import pytest
|
||||
|
||||
pytest.skip("training step test requires a CUDA GPU (flex-attention)")
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
cfg = LingBotVAConfig(
|
||||
attn_mode="flex",
|
||||
dtype="bfloat16",
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
action_dim=8,
|
||||
text_dim=32,
|
||||
freq_dim=64,
|
||||
ffn_dim=64,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=24,
|
||||
num_layers=2,
|
||||
frame_chunk_size=2,
|
||||
action_per_frame=4,
|
||||
used_action_channel_ids=[0, 1, 2, 3],
|
||||
obs_cam_keys=[f"{OBS_IMAGES}.image"],
|
||||
device="cuda",
|
||||
)
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64))}
|
||||
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,))}
|
||||
cfg.validate_features()
|
||||
|
||||
policy = LingBotVAPolicy(cfg).to("cuda")
|
||||
policy.train()
|
||||
opt = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4)
|
||||
|
||||
b, fc, apf = 1, cfg.frame_chunk_size, cfg.action_per_frame
|
||||
latents = torch.randn(b, cfg.in_channels, fc, 4, 4, device="cuda", dtype=torch.bfloat16)
|
||||
actions = torch.randn(b, cfg.action_dim, fc, apf, 1, device="cuda", dtype=torch.bfloat16)
|
||||
amask = torch.zeros(cfg.action_dim, device="cuda")
|
||||
amask[cfg.used_action_channel_ids] = 1.0
|
||||
actions_mask = amask.view(1, -1, 1, 1, 1).expand_as(actions)
|
||||
text_emb = torch.randn(b, cfg.max_sequence_length, cfg.text_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
loss, metrics = policy.training_loss_from_streams(latents, actions, actions_mask, text_emb)
|
||||
assert torch.isfinite(loss) and {"latent_loss", "action_loss"} <= set(metrics)
|
||||
loss.backward()
|
||||
assert any(p.grad is not None and torch.isfinite(p.grad).all() for p in policy.get_optim_params())
|
||||
opt.step()
|
||||
|
||||
Reference in New Issue
Block a user