mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 02:29:47 +00:00
make default params more aligned with paper and pretrained models
- adding possibility of freezing qwen backbone and world model - added tests for weight loading
This commit is contained in:
@@ -82,26 +82,15 @@ policy = VLAJEPAPolicy.from_pretrained("lerobot/VLA-JEPA-LIBERO")
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------------------------- | ---------------------------------- | ------------------------------------------------------------------- |
|
||||
| `qwen_model_name` | `"Qwen/Qwen3-VL-2B-Instruct"` | Qwen3-VL backbone variant |
|
||||
| `jepa_encoder_name` | `"facebook/vjepa2-vitl-fpc64-256"` | V-JEPA2 video encoder |
|
||||
| `chunk_size` | 16 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 16 | Video clip length fed to the world model |
|
||||
| `jepa_tubelet_size` | 2 | Temporal patch size of the video encoder (must match encoder) |
|
||||
| `action_model_type` | `"DiT-B"` | DiT preset — controls hidden dim, heads, head dim |
|
||||
| `action_hidden_size` | 1024 | DiT output projection size (and action decoder input size) |
|
||||
| `action_num_layers` | 12 | Number of DiT transformer blocks |
|
||||
| `num_target_vision_tokens` | 32 | Learned future-vision query tokens prepended to the action sequence |
|
||||
| `action_max_seq_len` | 1024 | Max length of the positional embedding table in the action head |
|
||||
| `num_action_tokens_per_timestep` | 4 | Special action tokens per temporal step (used for WM conditioning) |
|
||||
| `num_embodied_action_tokens_per_instruction` | 8 | Instruction-level embodied tokens appended to the Qwen sequence |
|
||||
| `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `predictor_depth` | 6 | Number of transformer blocks in the video predictor |
|
||||
| `repeated_diffusion_steps` | 4 | Independent noise draws per batch item (CogACT-style augmentation) |
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | -------------------------------------------------------------- |
|
||||
| `chunk_size` | 16 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 16 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
|
||||
---
|
||||
|
||||
@@ -120,10 +109,19 @@ lerobot-train \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
dataset.repo_id=your_org/your_dataset \
|
||||
policy.freeze_qwen=true
|
||||
```
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
@@ -138,14 +136,6 @@ TODO(Maxime):
|
||||
lerobot-train \
|
||||
policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
dataset.repo_id=lerobot/libero_10 \
|
||||
policy.chunk_size=7 \
|
||||
policy.n_action_steps=7 \
|
||||
policy.future_action_window_size=6 \
|
||||
policy.num_video_frames=8 \
|
||||
policy.num_action_tokens_per_timestep=8 \
|
||||
policy.num_embodied_action_tokens_per_instruction=32 \
|
||||
policy.action_num_layers=16 \
|
||||
policy.predictor_depth=12 \
|
||||
training.num_steps=50000 \
|
||||
env.type=libero \
|
||||
env.task=libero_10
|
||||
|
||||
@@ -12,8 +12,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 16
|
||||
n_action_steps: int = 16
|
||||
chunk_size: int = 7
|
||||
n_action_steps: int = 7
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -25,26 +25,26 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = (
|
||||
"{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
|
||||
)
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 8
|
||||
future_action_window_size: int = 15
|
||||
future_action_window_size: int = 6
|
||||
past_action_window_size: int = 0
|
||||
num_action_tokens_per_timestep: int = 4
|
||||
num_embodied_action_tokens_per_instruction: int = 8
|
||||
num_inference_timesteps: int = 10
|
||||
num_action_tokens_per_timestep: int = 8
|
||||
num_embodied_action_tokens_per_instruction: int = 32
|
||||
num_inference_timesteps: int = 4
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 12
|
||||
action_dropout: float = 0.1
|
||||
action_num_layers: int = 16
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
@@ -53,15 +53,14 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 16
|
||||
predictor_depth: int = 6
|
||||
num_video_frames: int = 8
|
||||
predictor_depth: int = 12
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
enable_world_model: bool = True
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 4 # independent noise draws per batch item (CogACT-style)
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
torch_dtype: str = "bfloat16"
|
||||
@@ -77,6 +76,9 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.freeze_qwen and self.enable_world_model:
|
||||
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||
self.enable_world_model = False
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.future_action_window_size + 1 > self.chunk_size:
|
||||
|
||||
@@ -86,6 +86,9 @@ class VLAJEPAModel(nn.Module):
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
if config.freeze_qwen:
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
|
||||
# This matches the number of context temporal positions after tubelet compression.
|
||||
|
||||
@@ -36,6 +36,11 @@ from lerobot.utils.constants import ACTION # noqa: E402
|
||||
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||
|
||||
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
|
||||
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
|
||||
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
|
||||
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core training / inference tests
|
||||
@@ -259,39 +264,94 @@ def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) ->
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pretrained checkpoint
|
||||
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_pretrained_checkpoint_loads_from_hf_cache() -> None:
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
|
||||
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
|
||||
return batch
|
||||
|
||||
repo_id = os.environ.get("VLA_JEPA_PRETRAINED_REPO_ID", PRETRAINED_REPO_ID)
|
||||
subfolder = os.environ.get("VLA_JEPA_PRETRAINED_SUBFOLDER", PRETRAINED_SUBFOLDER).strip("/")
|
||||
checkpoint_filename = os.environ.get(
|
||||
"VLA_JEPA_PRETRAINED_CHECKPOINT",
|
||||
f"{subfolder}/checkpoints/VLA-JEPA-{subfolder}.pt",
|
||||
)
|
||||
|
||||
try:
|
||||
checkpoint_path = hf_hub_download(
|
||||
repo_id=repo_id, filename=checkpoint_filename, local_files_only=True
|
||||
)
|
||||
except LocalEntryNotFoundError:
|
||||
pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.")
|
||||
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
|
||||
return batch
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
|
||||
except TypeError:
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
state_dict = (
|
||||
checkpoint.get("state_dict")
|
||||
or checkpoint.get("model_state_dict")
|
||||
or checkpoint.get("model")
|
||||
or checkpoint
|
||||
)
|
||||
assert isinstance(state_dict, dict)
|
||||
assert len(state_dict) > 0
|
||||
assert all(isinstance(k, str) for k in list(state_dict)[:10])
|
||||
_CP_ROOT = "lerobot" # TODO: upload converted checkpoints
|
||||
|
||||
# Each tuple: (repo_id, enable_world_model)
|
||||
_HUB_VARIANTS = [
|
||||
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
|
||||
]
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loads from the converted safetensors checkpoint on the Hub."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
assert policy.config.enable_world_model == enable_world_model
|
||||
assert sum(p.numel() for p in policy.parameters()) > 0
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
policy.train()
|
||||
|
||||
batch = _make_hub_train_batch(policy)
|
||||
loss, logs = policy.forward(batch)
|
||||
assert torch.isfinite(loss)
|
||||
assert "action_loss" in logs
|
||||
if enable_world_model:
|
||||
assert "wm_loss" in logs
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_freeze_qwen_disables_world_model() -> None:
|
||||
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
qwen_params = list(policy.model.qwen.parameters())
|
||||
assert all(not p.requires_grad for p in qwen_params)
|
||||
assert any(p.requires_grad for p in policy.model.action_model.parameters())
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_disable_world_model_loads_simpler_env() -> None:
|
||||
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
assert policy.model.video_encoder is None
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_libero_inference_shape() -> None:
|
||||
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
|
||||
policy.eval()
|
||||
batch = _make_hub_inference_batch(policy)
|
||||
action = policy.select_action(batch)
|
||||
assert action.shape[-1] == policy.config.action_dim
|
||||
|
||||
Reference in New Issue
Block a user