diff --git a/docs/source/lingbot_va.mdx b/docs/source/lingbot_va.mdx index 4fef37230..5d7fd3304 100644 --- a/docs/source/lingbot_va.mdx +++ b/docs/source/lingbot_va.mdx @@ -32,10 +32,8 @@ fed back into the KV cache as the chunk is executed (closed-loop world modeling) - Autoregressive dual-stream inference behind the standard `select_action` interface (single-environment eval, `--eval.batch_size=1`). - Opt-in saving of the policy's **predicted (imagined) videos** during eval / training. -- Evaluation with `lerobot-eval` on the LIBERO benchmark. - -Training (the flow-matching dual-stream loss + latent dataset) is part of a follow-up -training port and is not yet wired into `lerobot-train`. +- Evaluation with `lerobot-eval` on LIBERO and RoboTwin. +- Training / fine-tuning via the dual-stream flow-matching loss (`policy.forward`), see below. ## Installation @@ -105,6 +103,32 @@ Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicte latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos. The same flag works for the periodic eval during `lerobot-train`. +## Training / fine-tuning + +`LingBotVAPolicy.forward(batch)` implements the dual-stream **flow-matching** loss +(`latent_loss + action_loss`, timestep-weighted, action-masked) from the paper: it VAE-encodes +the camera clips into video latents, UMT5-encodes the task, noises both streams, runs the +transformer's block-causal training pass and returns `(loss, metrics)`. Optimizer preset is AdamW +with a linear-warmup-then-constant schedule (matching upstream). + +Requirements: +- The block-causal masks use PyTorch **flex-attention**, so build the policy with + `--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only). +- The full 5B DiT does not fit a single 24–32 GB GPU under AdamW; fine-tune with **LoRA** + (`--policy.use_peft=true`) and/or optimizer offload. `get_optim_params` returns only the + trainable (e.g. adapter) parameters; the VAE + UMT5 text encoder stay frozen. + +```bash +lerobot-train \ + --policy.path=pepijn223/lingbot_va_libero_long --policy.attn_mode=flex \ + --policy.use_peft=true \ + --dataset.repo_id= \ + --batch_size=1 --steps=... --output_dir=outputs/train/lingbot_va +``` + +The dataset must provide camera clips (a temporal window per camera, VAE-encoded to +`frame_chunk_size` latent frames) and `frame_chunk_size * action_per_frame` action steps per item. + ## Inference Hyperparameters (LIBERO) | Key | Value | diff --git a/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py b/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py index 57734b1c7..de39dfe92 100644 --- a/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py +++ b/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py @@ -51,6 +51,7 @@ from einops import rearrange from torch import Tensor from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION from lerobot.utils.import_utils import require_package from .configuration_lingbot_va import LingBotVAConfig @@ -1132,6 +1133,17 @@ def _torch_dtype(name: str) -> torch.dtype: return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[name] +def _sample_timestep_id( + batch_size: int = 1, + min_timestep_bd: float = 0.0, + max_timestep_bd: float = 1.0, + num_train_timesteps: int = 1000, +) -> torch.Tensor: + """Sample per-frame flow-matching timestep ids (upstream ``utils.sample_timestep_id``).""" + u = torch.rand(size=[batch_size]) * (max_timestep_bd - min_timestep_bd) + min_timestep_bd + return (u * num_train_timesteps).clamp(min=0, max=num_train_timesteps - 1).to(torch.int64) + + class LingBotVAPolicy(PreTrainedPolicy): """LeRobot wrapper for the LingBot-VA autoregressive video-action world model.""" @@ -1226,8 +1238,9 @@ class LingBotVAPolicy(PreTrainedPolicy): # PreTrainedPolicy API # ------------------------------------------------------------------ def get_optim_params(self) -> dict: - # Only the transformer is trainable; the VAE / text encoder stay frozen. - return self.transformer.parameters() + # Only the transformer is trainable; the VAE / text encoder stay frozen (kept outside the + # nn.Module registry). With PEFT/LoRA this naturally returns just the adapter params. + return [p for p in self.transformer.parameters() if p.requires_grad] def reset(self): """Reset all per-episode streaming state (KV cache, queues, frame counter).""" @@ -1270,19 +1283,215 @@ class LingBotVAPolicy(PreTrainedPolicy): if "streaming_vae_half" in self._frozen: self._frozen["streaming_vae_half"].clear_cache() - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: - """Training loss. Implemented in the LingBot-VA training PR (Phase 7). + # ------------------------------------------------------------------ + # Training (flow-matching dual-stream loss). Requires attn_mode="flex". + # ------------------------------------------------------------------ + def _ensure_train_schedulers(self): + if getattr(self, "_train_sched_latent", None) is None: + cfg = self.config + self._train_sched_latent = FlowMatchScheduler( + shift=cfg.snr_shift, sigma_min=0.0, extra_one_step=True + ) + self._train_sched_latent.set_timesteps(1000, training=True) + self._train_sched_action = FlowMatchScheduler( + shift=cfg.action_snr_shift, sigma_min=0.0, extra_one_step=True + ) + self._train_sched_action.set_timesteps(1000, training=True) - The flow-matching dual-stream loss needs the pre-extracted latent dataset - (see ``LatentLeRobotDataset`` upstream) and ``attn_mode='flex'``; it is intentionally - not part of this inference-focused integration. - """ - raise NotImplementedError( - "LingBot-VA training (flow-matching dual-stream loss) is part of the training port " - "(Phase 7 / PR #2) and is not implemented in this inference integration. " - "Use this policy for evaluation / inference." + @torch.no_grad() + def _add_noise_stream(self, latent, scheduler, action_mask, action_mode, noisy_cond_prob): + """Flow-matching noising of one stream (port of upstream ``Trainer._add_noise``).""" + device = latent.device + B, _C, F, _H, _W = latent.shape + p = self.config.patch_size + patch_f, patch_h, patch_w = (1, 1, 1) if action_mode else (p[0], p[1], p[2]) + + ts_ids = _sample_timestep_id(F, num_train_timesteps=scheduler.num_train_timesteps) + noise = torch.zeros_like(latent).normal_() + timesteps = scheduler.timesteps[ts_ids].to(device) + noisy_latents = scheduler.add_noise(latent, noise, timesteps, t_dim=2) + targets = scheduler.training_target(latent, noise, timesteps) + + grid_id = ( + get_mesh_id( + latent.shape[-3] // patch_f, + latent.shape[-2] // patch_h, + latent.shape[-1] // patch_w, + t=1 if action_mode else 0, + f_w=1, + f_shift=0, + action=action_mode, + ) + .to(device)[None] + .repeat(B, 1, 1) ) + if torch.rand(1).item() < noisy_cond_prob: + cond_ids = _sample_timestep_id( + F, min_timestep_bd=0.5, max_timestep_bd=1.0, num_train_timesteps=scheduler.num_train_timesteps + ) + cond_noise = torch.zeros_like(latent).normal_() + cond_timesteps = scheduler.timesteps[cond_ids].to(device) + latent = scheduler.add_noise(latent, cond_noise, cond_timesteps, t_dim=2) + else: + cond_timesteps = torch.zeros_like(timesteps) + + if action_mask is not None: + noisy_latents = noisy_latents * action_mask.float() + targets = targets * action_mask.float() + latent = latent * action_mask.float() + + return { + "timesteps": timesteps[None].repeat(B, 1), + "noisy_latents": noisy_latents, + "targets": targets, + "latent": latent, + "cond_timesteps": cond_timesteps[None].repeat(B, 1), + "grid_id": grid_id, + } + + def _flow_matching_loss(self, input_dict, pred): + """Dual-stream flow-matching loss (port of upstream ``Trainer.compute_loss``).""" + latent_pred, action_pred = pred + ld, ad = input_dict["latent_dict"], input_dict["action_dict"] + action_pred = rearrange(action_pred, "b (f n) c -> b c f n 1", f=ad["targets"].shape[-3]) + latent_pred = data_seq_to_patch( + self.config.patch_size, + latent_pred, + ld["targets"].shape[-3], + ld["targets"].shape[-2], + ld["targets"].shape[-1], + batch_size=latent_pred.shape[0], + ) + Bn, Fn = ld["timesteps"].shape + lw = self._train_sched_latent.training_weight(ld["timesteps"].flatten()).reshape(Bn, Fn) + aw = self._train_sched_action.training_weight(ad["timesteps"].flatten()).reshape(Bn, Fn) + + latent_loss = F.mse_loss(latent_pred.float(), ld["targets"].float().detach(), reduction="none") + latent_loss = ( + (latent_loss * lw[:, None, :, None, None]).permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1) + ) + latent_loss = (latent_loss.sum(dim=1) / (torch.ones_like(latent_loss).sum(dim=1) + 1e-6)).mean() + + amask = ad["actions_mask"].float() + action_loss = F.mse_loss(action_pred.float(), ad["targets"].float().detach(), reduction="none") + action_loss = ( + (action_loss * aw[:, None, :, None, None] * amask).permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1) + ) + amask_f = amask.permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1) + action_loss = (action_loss.sum(dim=1) / (amask_f.sum(dim=1) + 1e-6)).mean() + return latent_loss, action_loss + + def training_loss_from_streams(self, latents, actions, actions_mask, text_emb): + """Core dual-stream training loss given prepared latents / actions / text embeddings. + + ``latents``: ``[B, in_channels, F, h, w]`` (normalized video latents). + ``actions`` / ``actions_mask``: ``[B, action_dim, F, action_per_frame, 1]``. + ``text_emb``: ``[B, seq_len, text_dim]``. Returns ``(loss, {latent_loss, action_loss})``. + """ + if self.config.attn_mode != "flex": + raise ValueError( + "LingBot-VA training requires attn_mode='flex' (block-causal flow-matching masks). " + "Load/convert the policy with --policy.attn_mode=flex for training/fine-tuning." + ) + self._ensure_train_schedulers() + latent_dict = self._add_noise_stream( + latents, self._train_sched_latent, action_mask=None, action_mode=False, noisy_cond_prob=0.5 + ) + action_dict = self._add_noise_stream( + actions, self._train_sched_action, action_mask=actions_mask, action_mode=True, noisy_cond_prob=0.0 + ) + latent_dict["text_emb"] = text_emb + action_dict["text_emb"] = text_emb + action_dict["actions_mask"] = actions_mask + input_dict = { + "latent_dict": latent_dict, + "action_dict": action_dict, + "chunk_size": int(torch.randint(1, 5, (1,)).item()), + "window_size": int(torch.randint(4, 65, (1,)).item()), + } + pred = self.transformer(input_dict, train_mode=True) + latent_loss, action_loss = self._flow_matching_loss(input_dict, pred) + loss = latent_loss + action_loss + return loss, {"latent_loss": latent_loss.detach(), "action_loss": action_loss.detach()} + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: + """Training forward: dual-stream flow-matching loss. + + Builds the (video-latent, action, text) training streams from a LeRobot batch + (VAE-encoding the camera frames and UMT5-encoding the task), then runs the flow-matching + dual-stream loss. Requires the policy to be built with ``attn_mode='flex'``. + """ + self._ensure_frozen_modules() + latents, actions, actions_mask, text_emb = self._build_training_streams(batch) + return self.training_loss_from_streams(latents, actions, actions_mask, text_emb) + + @torch.no_grad() + def _build_training_streams(self, batch): + """Build (latents, actions, actions_mask, text_emb) from a LeRobot training batch. + + Camera frames per ``obs_cam_keys`` are expected as a temporal clip ``[B, C, T, H, W]`` (or + ``[B, T, C, H, W]``); they are VAE-encoded into ``F = T / temporal_downsample`` latent frames. + Actions ``[B, F*action_per_frame, n_used]`` are scattered into the model's ``action_dim`` space. + """ + cfg = self.config + device = cfg.device + # ---- text embeddings ---- + task = batch.get("task") + if isinstance(task, str): + task = [task] + text_emb = self._get_t5_prompt_embeds(list(task), cfg.max_sequence_length) + + # ---- video latents (VAE-encode the camera clips) ---- + latents = self._encode_training_latents(batch) + + # ---- actions -> [B, action_dim, F, action_per_frame, 1] ---- + act = batch[ACTION].to(device) # [B, F*apf, n_used] + B = act.shape[0] + used = cfg.used_action_channel_ids + apf, Fc = cfg.action_per_frame, cfg.frame_chunk_size + act = act[:, : Fc * apf].reshape(B, Fc, apf, len(used)).permute(0, 3, 1, 2) # [B, n_used, F, apf] + full = act.new_zeros(B, cfg.action_dim, Fc, apf) + idx = torch.as_tensor(used, device=device) + full[:, idx] = act + actions = full.unsqueeze(-1).to(self.dtype) # [B, action_dim, F, apf, 1] + mask = torch.zeros(cfg.action_dim, device=device, dtype=self.dtype) + mask[idx] = 1.0 + actions_mask = mask.view(1, -1, 1, 1, 1).expand_as(actions) + return latents, actions, actions_mask, text_emb + + @torch.no_grad() + def _encode_training_latents(self, batch) -> Tensor: + """VAE-encode the per-camera training clips into normalized video latents [B, C, F, h, w].""" + vae_device = next(self._vae.parameters()).device + + def _clip(key): + x = batch[key].to(vae_device) + if x.dim() == 4: # [B, C, H, W] -> single frame clip + x = x.unsqueeze(2) + elif x.shape[1] not in (1, 3) and x.shape[2] in (1, 3): # [B, T, C, H, W] -> [B, C, T, H, W] + x = x.permute(0, 2, 1, 3, 4) + return x.contiguous() + + def _encode(x, size): + b, c, t = x.shape[:3] + x = F.interpolate(x.flatten(0, 1).float(), size=size, mode="bilinear", align_corners=False) + x = (x.view(b, c, t, *size) * 2.0 - 1.0).to(self.dtype) + mu = self._vae.encode(x).latent_dist.mode() # [B, z_dim, F, h, w] + mean = torch.tensor(self._vae.config.latents_mean).view(1, -1, 1, 1, 1).to(mu.device) + inv_std = (1.0 / torch.tensor(self._vae.config.latents_std)).view(1, -1, 1, 1, 1).to(mu.device) + return ((mu.float() - mean) * inv_std).to(mu) + + keys = self.config.obs_cam_keys + if self.config.camera_layout == "robotwin_tshape": + h, w = self.config.height, self.config.width + head = _encode(_clip(keys[0]), (h, w)) + left = _encode(_clip(keys[1]), (h // 2, w // 2)) + right = _encode(_clip(keys[2]), (h // 2, w // 2)) + return torch.cat([torch.cat([left, right], dim=-1), head], dim=-2).to(self.config.device) + per_cam = [_encode(_clip(k), (self.config.height, self.config.width)) for k in keys] + return torch.cat(per_cam, dim=-1).to(self.config.device) + @torch.no_grad() def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: """Return one action, refilling the chunk (and feeding back observed keyframes) as needed. diff --git a/tests/policies/lingbot_va/test_modules.py b/tests/policies/lingbot_va/test_modules.py index b0d32cd9e..c169835ee 100644 --- a/tests/policies/lingbot_va/test_modules.py +++ b/tests/policies/lingbot_va/test_modules.py @@ -73,3 +73,59 @@ def test_data_seq_to_patch_roundtrip_shape() -> None: seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c) out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b) assert out.shape == (b, c, f, h, w) + + +def test_training_step_reduces_loss_tiny_flex() -> None: + """End-to-end single training step (flow-matching loss -> backward -> AdamW) on a tiny config. + + Exercises the flex-attention training path; requires a CUDA GPU with flex-attention support. + """ + if not torch.cuda.is_available(): + import pytest + + pytest.skip("training step test requires a CUDA GPU (flex-attention)") + + from lerobot.configs.types import FeatureType, PolicyFeature + from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig + from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy + from lerobot.utils.constants import ACTION, OBS_IMAGES + + cfg = LingBotVAConfig( + attn_mode="flex", + dtype="bfloat16", + in_channels=16, + out_channels=16, + action_dim=8, + text_dim=32, + freq_dim=64, + ffn_dim=64, + num_attention_heads=2, + attention_head_dim=24, + num_layers=2, + frame_chunk_size=2, + action_per_frame=4, + used_action_channel_ids=[0, 1, 2, 3], + obs_cam_keys=[f"{OBS_IMAGES}.image"], + device="cuda", + ) + cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64))} + cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,))} + cfg.validate_features() + + policy = LingBotVAPolicy(cfg).to("cuda") + policy.train() + opt = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4) + + b, fc, apf = 1, cfg.frame_chunk_size, cfg.action_per_frame + latents = torch.randn(b, cfg.in_channels, fc, 4, 4, device="cuda", dtype=torch.bfloat16) + actions = torch.randn(b, cfg.action_dim, fc, apf, 1, device="cuda", dtype=torch.bfloat16) + amask = torch.zeros(cfg.action_dim, device="cuda") + amask[cfg.used_action_channel_ids] = 1.0 + actions_mask = amask.view(1, -1, 1, 1, 1).expand_as(actions) + text_emb = torch.randn(b, cfg.max_sequence_length, cfg.text_dim, device="cuda", dtype=torch.bfloat16) + + loss, metrics = policy.training_loss_from_streams(latents, actions, actions_mask, text_emb) + assert torch.isfinite(loss) and {"latent_loss", "action_loss"} <= set(metrics) + loss.backward() + assert any(p.grad is not None and torch.isfinite(p.grad).all() for p in policy.get_optim_params()) + opt.step()