mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
make default params more aligned with paper and pretrained models
- adding possibility of freezing qwen backbone and world model - added tests for weight loading
This commit is contained in:
@@ -82,26 +82,15 @@ policy = VLAJEPAPolicy.from_pretrained("lerobot/VLA-JEPA-LIBERO")
|
|||||||
|
|
||||||
Key parameters in `VLAJEPAConfig`:
|
Key parameters in `VLAJEPAConfig`:
|
||||||
|
|
||||||
| Parameter | Default | Description |
|
| Parameter | Default | Description |
|
||||||
| -------------------------------------------- | ---------------------------------- | ------------------------------------------------------------------- |
|
| ------------------------- | ------- | -------------------------------------------------------------- |
|
||||||
| `qwen_model_name` | `"Qwen/Qwen3-VL-2B-Instruct"` | Qwen3-VL backbone variant |
|
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||||
| `jepa_encoder_name` | `"facebook/vjepa2-vitl-fpc64-256"` | V-JEPA2 video encoder |
|
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||||
| `chunk_size` | 16 | Number of actions predicted per inference call |
|
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||||
| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning |
|
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||||
| `num_video_frames` | 16 | Video clip length fed to the world model |
|
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||||
| `jepa_tubelet_size` | 2 | Temporal patch size of the video encoder (must match encoder) |
|
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||||
| `action_model_type` | `"DiT-B"` | DiT preset — controls hidden dim, heads, head dim |
|
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||||
| `action_hidden_size` | 1024 | DiT output projection size (and action decoder input size) |
|
|
||||||
| `action_num_layers` | 12 | Number of DiT transformer blocks |
|
|
||||||
| `num_target_vision_tokens` | 32 | Learned future-vision query tokens prepended to the action sequence |
|
|
||||||
| `action_max_seq_len` | 1024 | Max length of the positional embedding table in the action head |
|
|
||||||
| `num_action_tokens_per_timestep` | 4 | Special action tokens per temporal step (used for WM conditioning) |
|
|
||||||
| `num_embodied_action_tokens_per_instruction` | 8 | Instruction-level embodied tokens appended to the Qwen sequence |
|
|
||||||
| `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
|
|
||||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
|
||||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
|
||||||
| `predictor_depth` | 6 | Number of transformer blocks in the video predictor |
|
|
||||||
| `repeated_diffusion_steps` | 4 | Independent noise draws per batch item (CogACT-style augmentation) |
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -120,10 +109,19 @@ lerobot-train \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
policy.path=lerobot/VLA-JEPA-LIBERO \
|
policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
dataset.repo_id=your_org/your_dataset
|
dataset.repo_id=your_org/your_dataset
|
||||||
```
|
```
|
||||||
|
|
||||||
|
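To go further and train only the action head, freeze the Qwen backbone by setting `policy.freeze_qwen=true`. Note that this also disables the world model: when `freeze_qwen` is set, `enable_world_model` is automatically forced to `False`.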
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
dataset.repo_id=your_org/your_dataset \
|
||||||
|
policy.freeze_qwen=true
|
||||||
|
```
|
||||||
|
|
||||||
### Reproducing the LIBERO results
|
### Reproducing the LIBERO results
|
||||||
|
|
||||||
**Training on LIBERO:**
|
**Training on LIBERO:**
|
||||||
@@ -138,14 +136,6 @@ TODO(Maxime):
|
|||||||
lerobot-train \
|
lerobot-train \
|
||||||
policy.path=lerobot/VLA-JEPA-Pretrain \
|
policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
dataset.repo_id=lerobot/libero_10 \
|
dataset.repo_id=lerobot/libero_10 \
|
||||||
policy.chunk_size=7 \
|
|
||||||
policy.n_action_steps=7 \
|
|
||||||
policy.future_action_window_size=6 \
|
|
||||||
policy.num_video_frames=8 \
|
|
||||||
policy.num_action_tokens_per_timestep=8 \
|
|
||||||
policy.num_embodied_action_tokens_per_instruction=32 \
|
|
||||||
policy.action_num_layers=16 \
|
|
||||||
policy.predictor_depth=12 \
|
|
||||||
training.num_steps=50000 \
|
training.num_steps=50000 \
|
||||||
env.type=libero \
|
env.type=libero \
|
||||||
env.task=libero_10
|
env.task=libero_10
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
|||||||
@dataclass
|
@dataclass
|
||||||
class VLAJEPAConfig(PreTrainedConfig):
|
class VLAJEPAConfig(PreTrainedConfig):
|
||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
chunk_size: int = 16
|
chunk_size: int = 7
|
||||||
n_action_steps: int = 16
|
n_action_steps: int = 7
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@@ -25,26 +25,26 @@ class VLAJEPAConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||||
|
freeze_qwen: bool = False
|
||||||
|
enable_world_model: bool = True
|
||||||
|
|
||||||
tokenizer_padding_side: str = "left"
|
tokenizer_padding_side: str = "left"
|
||||||
prompt_template: str = (
|
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||||
"{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
|
|
||||||
)
|
|
||||||
special_action_token: str = "<|action_{}|>"
|
special_action_token: str = "<|action_{}|>"
|
||||||
embodied_action_token: str = "<|embodied_action|>"
|
embodied_action_token: str = "<|embodied_action|>"
|
||||||
|
|
||||||
action_dim: int = 7
|
action_dim: int = 7
|
||||||
state_dim: int = 8
|
state_dim: int = 8
|
||||||
future_action_window_size: int = 15
|
future_action_window_size: int = 6
|
||||||
past_action_window_size: int = 0
|
past_action_window_size: int = 0
|
||||||
num_action_tokens_per_timestep: int = 4
|
num_action_tokens_per_timestep: int = 8
|
||||||
num_embodied_action_tokens_per_instruction: int = 8
|
num_embodied_action_tokens_per_instruction: int = 32
|
||||||
num_inference_timesteps: int = 10
|
num_inference_timesteps: int = 4
|
||||||
|
|
||||||
action_hidden_size: int = 1024
|
action_hidden_size: int = 1024
|
||||||
action_model_type: str = "DiT-B"
|
action_model_type: str = "DiT-B"
|
||||||
action_num_layers: int = 12
|
action_num_layers: int = 16
|
||||||
action_dropout: float = 0.1
|
action_dropout: float = 0.2
|
||||||
action_num_timestep_buckets: int = 1000
|
action_num_timestep_buckets: int = 1000
|
||||||
action_noise_beta_alpha: float = 1.5
|
action_noise_beta_alpha: float = 1.5
|
||||||
action_noise_beta_beta: float = 1.0
|
action_noise_beta_beta: float = 1.0
|
||||||
@@ -53,15 +53,14 @@ class VLAJEPAConfig(PreTrainedConfig):
|
|||||||
action_max_seq_len: int = 1024
|
action_max_seq_len: int = 1024
|
||||||
|
|
||||||
# total video frames loaded per sample
|
# total video frames loaded per sample
|
||||||
num_video_frames: int = 16
|
num_video_frames: int = 8
|
||||||
predictor_depth: int = 6
|
predictor_depth: int = 12
|
||||||
predictor_num_heads: int = 8
|
predictor_num_heads: int = 8
|
||||||
predictor_mlp_ratio: float = 4.0
|
predictor_mlp_ratio: float = 4.0
|
||||||
predictor_dropout: float = 0.0
|
predictor_dropout: float = 0.0
|
||||||
world_model_loss_weight: float = 0.1
|
world_model_loss_weight: float = 0.1
|
||||||
enable_world_model: bool = True
|
|
||||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||||
repeated_diffusion_steps: int = 4 # independent noise draws per batch item (CogACT-style)
|
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||||
|
|
||||||
resize_images_to: tuple[int, int] | None = None
|
resize_images_to: tuple[int, int] | None = None
|
||||||
torch_dtype: str = "bfloat16"
|
torch_dtype: str = "bfloat16"
|
||||||
@@ -77,6 +76,9 @@ class VLAJEPAConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
if self.freeze_qwen and self.enable_world_model:
|
||||||
|
# With the Qwen backbone frozen, no gradient flows into it from the world-model loss, so world-model training is pointless: disable it.
|
||||||
|
self.enable_world_model = False
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||||
if self.future_action_window_size + 1 > self.chunk_size:
|
if self.future_action_window_size + 1 > self.chunk_size:
|
||||||
|
|||||||
@@ -86,6 +86,9 @@ class VLAJEPAModel(nn.Module):
|
|||||||
self.video_processor = None
|
self.video_processor = None
|
||||||
self.video_predictor = None
|
self.video_predictor = None
|
||||||
|
|
||||||
|
if config.freeze_qwen:
|
||||||
|
self.qwen.requires_grad_(False)
|
||||||
|
|
||||||
# Build prompt placeholders.
|
# Build prompt placeholders.
|
||||||
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
|
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
|
||||||
# This matches the number of context temporal positions after tubelet compression.
|
# This matches the number of context temporal positions after tubelet compression.
|
||||||
|
|||||||
@@ -36,6 +36,11 @@ from lerobot.utils.constants import ACTION # noqa: E402
|
|||||||
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||||
PRETRAINED_SUBFOLDER = "LIBERO"
|
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||||
|
|
||||||
|
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
|
||||||
|
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
|
||||||
|
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
|
||||||
|
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Core training / inference tests
|
# Core training / inference tests
|
||||||
@@ -259,39 +264,94 @@ def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) ->
|
|||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Pretrained checkpoint
|
# Pretrained checkpoint
|
||||||
|
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def test_pretrained_checkpoint_loads_from_hf_cache() -> None:
|
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||||
import torch
|
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
|
||||||
from huggingface_hub import hf_hub_download
|
cfg = policy.config
|
||||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||||
|
for key, feat in cfg.image_features.items():
|
||||||
|
h, w = feat.shape[-2], feat.shape[-1]
|
||||||
|
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
|
||||||
|
if cfg.robot_state_feature is not None:
|
||||||
|
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
|
||||||
|
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
|
||||||
|
return batch
|
||||||
|
|
||||||
repo_id = os.environ.get("VLA_JEPA_PRETRAINED_REPO_ID", PRETRAINED_REPO_ID)
|
|
||||||
subfolder = os.environ.get("VLA_JEPA_PRETRAINED_SUBFOLDER", PRETRAINED_SUBFOLDER).strip("/")
|
|
||||||
checkpoint_filename = os.environ.get(
|
|
||||||
"VLA_JEPA_PRETRAINED_CHECKPOINT",
|
|
||||||
f"{subfolder}/checkpoints/VLA-JEPA-{subfolder}.pt",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||||
checkpoint_path = hf_hub_download(
|
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
|
||||||
repo_id=repo_id, filename=checkpoint_filename, local_files_only=True
|
cfg = policy.config
|
||||||
)
|
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||||
except LocalEntryNotFoundError:
|
for key, feat in cfg.image_features.items():
|
||||||
pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.")
|
h, w = feat.shape[-2], feat.shape[-1]
|
||||||
|
batch[key] = torch.rand(batch_size, 3, h, w)
|
||||||
|
if cfg.robot_state_feature is not None:
|
||||||
|
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
|
||||||
|
return batch
|
||||||
|
|
||||||
try:
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
|
|
||||||
except TypeError:
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
|
||||||
|
|
||||||
state_dict = (
|
_CP_ROOT = "lerobot" # TODO: upload converted checkpoints
|
||||||
checkpoint.get("state_dict")
|
|
||||||
or checkpoint.get("model_state_dict")
|
# Each tuple: (repo_id, enable_world_model)
|
||||||
or checkpoint.get("model")
|
_HUB_VARIANTS = [
|
||||||
or checkpoint
|
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
|
||||||
)
|
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
|
||||||
assert isinstance(state_dict, dict)
|
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
|
||||||
assert len(state_dict) > 0
|
]
|
||||||
assert all(isinstance(k, str) for k in list(state_dict)[:10])
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||||
|
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
|
||||||
|
"""Policy loads from the converted safetensors checkpoint on the Hub."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||||
|
assert policy.config.enable_world_model == enable_world_model
|
||||||
|
assert sum(p.numel() for p in policy.parameters()) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||||
|
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
|
||||||
|
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||||
|
policy.train()
|
||||||
|
|
||||||
|
batch = _make_hub_train_batch(policy)
|
||||||
|
loss, logs = policy.forward(batch)
|
||||||
|
assert torch.isfinite(loss)
|
||||||
|
assert "action_loss" in logs
|
||||||
|
if enable_world_model:
|
||||||
|
assert "wm_loss" in logs
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
def test_hub_freeze_qwen_disables_world_model() -> None:
|
||||||
|
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
|
||||||
|
assert not policy.config.enable_world_model
|
||||||
|
assert policy.model.video_predictor is None
|
||||||
|
qwen_params = list(policy.model.qwen.parameters())
|
||||||
|
assert all(not p.requires_grad for p in qwen_params)
|
||||||
|
assert any(p.requires_grad for p in policy.model.action_model.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
def test_hub_disable_world_model_loads_simpler_env() -> None:
|
||||||
|
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
|
||||||
|
assert not policy.config.enable_world_model
|
||||||
|
assert policy.model.video_predictor is None
|
||||||
|
assert policy.model.video_encoder is None
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
def test_hub_libero_inference_shape() -> None:
|
||||||
|
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
|
||||||
|
policy.eval()
|
||||||
|
batch = _make_hub_inference_batch(policy)
|
||||||
|
action = policy.select_action(batch)
|
||||||
|
assert action.shape[-1] == policy.config.action_dim
|
||||||
|
|||||||
Reference in New Issue
Block a user