mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(lingbot_va): CI quality gate + fast-test collection
- Add tests/policies/lingbot_va/__init__.py so the test files don't clash by basename with tests/policies/vla_jepa/* under pytest's default import mode (fast-test collection error).
- Fix vendored typos flagged by the typos hook (pach_scale -> patch_scale, total_tolen -> total_token_len, stablized -> stabilized) and a mypy union-attr error in RoboTwinEnv._read_eef_pose.
- Apply Prettier formatting to docs/source/lingbot_va.mdx.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+19
-18
@@ -14,11 +14,11 @@ LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream
|
||||
text conditioning. Actions are produced by the dedicated `action_proj_out` head — they are
|
||||
**not** decoded from predicted pixels, though video and action are co-trained.
|
||||
|
||||
| Component | Class | Role |
|
||||
|---|---|---|
|
||||
| Component | Class | Role |
|
||||
| ------------------------ | ----------------------- | -------------------------------------------------------------------------------------- |
|
||||
| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer (the only weights stored in the LeRobot checkpoint). |
|
||||
| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. |
|
||||
| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |
|
||||
| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. |
|
||||
| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |
|
||||
|
||||
At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent
|
||||
stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent
|
||||
@@ -50,11 +50,11 @@ pip install -e ".[lingbot_va,libero]"
|
||||
|
||||
The released upstream checkpoints have been converted to LeRobot format and pushed to the Hub:
|
||||
|
||||
| Variant | LeRobot checkpoint |
|
||||
|---|---|
|
||||
| Variant | LeRobot checkpoint |
|
||||
| ---------------------- | ---------------------------------- |
|
||||
| LIBERO-Long post-train | `pepijn223/lingbot_va_libero_long` |
|
||||
| RoboTwin post-train | `pepijn223/lingbot_va_robotwin` |
|
||||
| Pretrained base | `pepijn223/lingbot_va_base` |
|
||||
| RoboTwin post-train | `pepijn223/lingbot_va_robotwin` |
|
||||
| Pretrained base | `pepijn223/lingbot_va_base` |
|
||||
|
||||
**Packaging:** only the trainable ~5B transformer is stored in the LeRobot
|
||||
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are **lazily pulled** from
|
||||
@@ -112,6 +112,7 @@ transformer's block-causal training pass and returns `(loss, metrics)`. Optimize
|
||||
with a linear-warmup-then-constant schedule (matching upstream).
|
||||
|
||||
Requirements:
|
||||
|
||||
- The block-causal masks use PyTorch **flex-attention**, so build the policy with
|
||||
`--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only).
|
||||
- The full 5B DiT does not fit a single 24–32 GB GPU under AdamW; fine-tune with **LoRA**
|
||||
@@ -131,16 +132,16 @@ The dataset must provide camera clips (a temporal window per camera, VAE-encoded
|
||||
|
||||
## Inference Hyperparameters (LIBERO)
|
||||
|
||||
| Key | Value |
|
||||
|---|---|
|
||||
| height × width | 128 × 128 |
|
||||
| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
|
||||
| action channels used | 0–6 (7-DoF arm + gripper) |
|
||||
| action_per_frame / frame_chunk_size | 4 / 4 |
|
||||
| attn_window | 30 |
|
||||
| video / action denoising steps | 20 / 50 |
|
||||
| guidance_scale / action_guidance_scale | 5 / 1 |
|
||||
| snr_shift / action_snr_shift | 5.0 / 0.05 |
|
||||
| Key | Value |
|
||||
| -------------------------------------- | --------------------------------------------------------------------------------- |
|
||||
| height × width | 128 × 128 |
|
||||
| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
|
||||
| action channels used | 0–6 (7-DoF arm + gripper) |
|
||||
| action_per_frame / frame_chunk_size | 4 / 4 |
|
||||
| attn_window | 30 |
|
||||
| video / action denoising steps | 20 / 50 |
|
||||
| guidance_scale / action_guidance_scale | 5 / 1 |
|
||||
| snr_shift / action_snr_shift | 5.0 / 0.05 |
|
||||
|
||||
These are the defaults of `LingBotVAConfig`; override any of them via `--policy.<name>=...`.
|
||||
|
||||
|
||||
@@ -359,6 +359,7 @@ class RoboTwinEnv(gym.Env):
|
||||
|
||||
def _read_eef_pose(self) -> np.ndarray:
|
||||
"""Read the current 16-d dual-arm eef pose [left(xyz+quat)+grip, right(xyz+quat)+grip]."""
|
||||
assert self._env is not None, "_read_eef_pose called before _ensure_env()"
|
||||
ep = self._env.get_obs()["endpose"]
|
||||
pose = (
|
||||
list(ep["left_endpose"])
|
||||
|
||||
@@ -185,12 +185,12 @@ class FlowMatchScheduler:
|
||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||
return prev_sample
|
||||
|
||||
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||
def return_to_timestep(self, timestep, sample, sample_stabilized):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
model_output = (sample - sample_stablized) / sigma
|
||||
model_output = (sample - sample_stabilized) / sigma
|
||||
return model_output
|
||||
|
||||
def add_noise(self, original_samples, noise, timestep, t_dim=2):
|
||||
@@ -518,15 +518,15 @@ class WanAttention(nn.Module):
|
||||
return
|
||||
self.attn_caches[cache_name] = None
|
||||
|
||||
def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim, device, dtype, batch_size):
|
||||
def init_kv_cache(self, cache_name, total_token_len, num_head, head_dim, device, dtype, batch_size):
|
||||
if self.attn_caches is None:
|
||||
return
|
||||
self.attn_caches[cache_name] = {
|
||||
"k": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
|
||||
"v": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
|
||||
"id": torch.full((total_tolen,), -1, device=device),
|
||||
"mask": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
|
||||
"is_pred": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
|
||||
"k": torch.empty([batch_size, total_token_len, num_head, head_dim], device=device, dtype=dtype),
|
||||
"v": torch.empty([batch_size, total_token_len, num_head, head_dim], device=device, dtype=dtype),
|
||||
"id": torch.full((total_token_len,), -1, device=device),
|
||||
"mask": torch.zeros((total_token_len,), dtype=torch.bool, device=device),
|
||||
"is_pred": torch.zeros((total_token_len,), dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
def allocate_slots(self, cache_name, key_size):
|
||||
@@ -831,13 +831,13 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
dtype,
|
||||
batch_size,
|
||||
):
|
||||
total_tolen = (attn_window // 2) * latent_token_per_chunk + (
|
||||
total_token_len = (attn_window // 2) * latent_token_per_chunk + (
|
||||
attn_window // 2
|
||||
) * action_token_per_chunk
|
||||
for block in self.blocks:
|
||||
block.attn1.init_kv_cache(
|
||||
cache_name,
|
||||
total_tolen,
|
||||
total_token_len,
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
device,
|
||||
@@ -866,9 +866,9 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
return hidden_states
|
||||
|
||||
def _time_embed(self, timesteps, H, W, dtype, action_mode=False):
|
||||
pach_scale_h, pach_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
|
||||
patch_scale_h, patch_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
|
||||
latent_time_steps = torch.repeat_interleave(
|
||||
timesteps, (H // pach_scale_h) * (W // pach_scale_w), dim=1
|
||||
timesteps, (H // patch_scale_h) * (W // patch_scale_w), dim=1
|
||||
)
|
||||
current_condition_embedder = (
|
||||
self.condition_embedder_action if action_mode else self.condition_embedder
|
||||
@@ -1012,12 +1012,12 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
latent_grid_id = input_dict["grid_id"]
|
||||
rotary_emb = self.rope(latent_grid_id)[:, :, None] # 1 L 1 C
|
||||
pach_scale_h, pach_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
|
||||
patch_scale_h, patch_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
|
||||
|
||||
latent_time_steps = torch.repeat_interleave(
|
||||
input_dict["timesteps"],
|
||||
(input_dict["noisy_latents"].shape[-2] // pach_scale_h)
|
||||
* (input_dict["noisy_latents"].shape[-1] // pach_scale_w),
|
||||
(input_dict["noisy_latents"].shape[-2] // patch_scale_h)
|
||||
* (input_dict["noisy_latents"].shape[-1] // patch_scale_w),
|
||||
dim=1,
|
||||
) # L
|
||||
current_condition_embedder = (
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
Reference in New Issue
Block a user