mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-12 05:59:53 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c1332ac37e | |||
| 31ddb8f493 | |||
| 877847c90e |
+3
-3
@@ -216,7 +216,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -231,9 +231,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
|
||||
@@ -49,19 +49,8 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
|
||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||
|
||||
|
||||
def save_training_step(
|
||||
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
|
||||
) -> None:
|
||||
state: dict = {"step": step}
|
||||
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
|
||||
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
|
||||
# produced `step`, since both scale how many sampler positions a step consumes (see
|
||||
# compute_sampler_state).
|
||||
if num_processes is not None:
|
||||
state["num_processes"] = num_processes
|
||||
if batch_size is not None:
|
||||
state["batch_size"] = batch_size
|
||||
write_json(state, save_dir / TRAINING_STEP)
|
||||
def save_training_step(step: int, save_dir: Path) -> None:
|
||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
||||
|
||||
|
||||
def load_training_step(save_dir: Path) -> int:
|
||||
@@ -69,16 +58,6 @@ def load_training_step(save_dir: Path) -> int:
|
||||
return training_step["step"]
|
||||
|
||||
|
||||
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
|
||||
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
|
||||
|
||||
|
||||
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
|
||||
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||
if last_checkpoint_dir.is_symlink():
|
||||
@@ -96,8 +75,6 @@ def save_checkpoint(
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -123,10 +100,6 @@ def save_checkpoint(
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
@@ -139,9 +112,7 @@ def save_checkpoint(
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(
|
||||
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
|
||||
)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
def save_training_state(
|
||||
@@ -149,8 +120,6 @@ def save_training_state(
|
||||
train_step: int,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||
@@ -162,12 +131,10 @@ def save_training_state(
|
||||
Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||
Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||
"""
|
||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||
save_training_step(train_step, save_dir)
|
||||
save_rng_state(save_dir)
|
||||
if optimizer is not None:
|
||||
save_optimizer_state(optimizer, save_dir)
|
||||
|
||||
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler, compute_sampler_state
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
@@ -82,7 +82,6 @@ __all__ = [
|
||||
"aggregate_stats",
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"compute_sampler_state",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
|
||||
+32
-122
@@ -14,36 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
"""Sampler over episode frames that stores only per-episode boundaries.
|
||||
|
||||
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
|
||||
instead of materializing a Python list of every frame index.
|
||||
|
||||
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
|
||||
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
|
||||
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
|
||||
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
|
||||
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
|
||||
resumed epoch, `__len__` still reports the full length.
|
||||
|
||||
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
|
||||
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
|
||||
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
|
||||
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
|
||||
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
|
||||
starts (e.g. on resume or in tests).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
@@ -52,125 +30,57 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
):
|
||||
"""
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
Args:
|
||||
dataset_from_indices: Start index of each episode in the dataset.
|
||||
dataset_to_indices: End index of each episode in the dataset.
|
||||
episode_indices_to_use: Episode indices to use; None means all.
|
||||
drop_n_first_frames: Frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Frames to drop from the end of each episode.
|
||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
seed: Seed the permutation is derived from (together with the epoch).
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
|
||||
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
|
||||
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
|
||||
if from_indices.shape != to_indices.shape:
|
||||
raise ValueError(
|
||||
f"dataset_from_indices and dataset_to_indices must have the same length, "
|
||||
f"got {len(from_indices)} and {len(to_indices)}"
|
||||
)
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
used = np.ones(len(from_indices), dtype=bool)
|
||||
if episode_indices_to_use is not None:
|
||||
used = np.zeros(len(from_indices), dtype=bool)
|
||||
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
|
||||
|
||||
starts = from_indices + drop_n_first_frames
|
||||
lengths = to_indices - drop_n_last_frames - starts
|
||||
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
to_indices[episode_idx] - from_indices[episode_idx],
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
used &= lengths > 0
|
||||
if not used.any():
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self._starts = starts[used]
|
||||
self._cum_lengths = np.cumsum(lengths[used])
|
||||
self._num_frames = int(self._cum_lengths[-1])
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self._epoch = 0
|
||||
self._start_index = 0
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
|
||||
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {"epoch": self._epoch, "start_index": self._start_index}
|
||||
|
||||
def load_state_dict(self, state: dict) -> None:
|
||||
self._epoch = state["epoch"]
|
||||
self._start_index = state["start_index"]
|
||||
|
||||
def _epoch_generator(self, epoch: int) -> torch.Generator:
|
||||
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
|
||||
# and reproduces identically on every rank without touching the global RNG.
|
||||
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
|
||||
return torch.Generator().manual_seed(epoch_seed)
|
||||
|
||||
def _frame_index(self, position: int) -> int:
|
||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||
return int(self._starts[episode]) + position_in_episode
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||
epoch, start = self._epoch, self._start_index
|
||||
self._epoch += 1
|
||||
self._start_index = 0
|
||||
return self._iter_epoch(epoch, start)
|
||||
|
||||
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(int(order[k]))
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(k)
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._num_frames
|
||||
|
||||
|
||||
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
|
||||
|
||||
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
|
||||
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||
per epoch (`even_batches` padding included). The start index provably stays below
|
||||
`num_frames`; the `min` is defensive.
|
||||
|
||||
Assumptions (resume is only sample-exact when they hold):
|
||||
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
|
||||
many positions a step consumes, so the epoch/offset are wrong if either changed. The
|
||||
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
|
||||
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
|
||||
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
|
||||
the boundary is off.
|
||||
"""
|
||||
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
|
||||
return {"epoch": epoch, "start_index": start_index}
|
||||
return len(self.indices)
|
||||
|
||||
@@ -17,12 +17,10 @@ from __future__ import annotations
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
@@ -55,12 +53,13 @@ class VLAJEPAModel(nn.Module):
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
Inputs are batched tensors kept on the model device
|
||||
- images: List[List[Tensor [C, H, W]]] (float [0,1]) — per sample, per view (Qwen messages)
|
||||
- instructions: List[str]
|
||||
- videos: Tensor [B, V, T, C, H, W] (float [0,1], world model only)
|
||||
- actions: Tensor [B, T, action_dim] (optional, training only)
|
||||
- state: Tensor [B, 1, state_dim] (optional)
|
||||
- action_is_pad: Tensor [B, T] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
@@ -161,166 +160,123 @@ class VLAJEPAModel(nn.Module):
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# Adjust number of views for the world model:
|
||||
# - fewer views than expected: duplicate the first view to fill up
|
||||
# - more views than expected: keep only the first num_views_world_model views
|
||||
num_views_world_model = self.config.jepa_tubelet_size
|
||||
if batch_videos.shape[1] < num_views_world_model:
|
||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
||||
elif batch_videos.shape[1] > num_views_world_model:
|
||||
batch_videos = batch_videos[:, :num_views_world_model]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
def _encode_qwen(
|
||||
self, images: list[list[Tensor]], instructions: list[str], *, need_action_tokens: bool
|
||||
) -> tuple[Tensor, Tensor, Tensor | None]:
|
||||
"""Run Qwen and gather the embodied-action (and optionally action) token hidden states."""
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
images=images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
input_ids = qwen_inputs["input_ids"]
|
||||
embodied_idx = (input_ids == self.embodied_action_token_id).nonzero(as_tuple=True)
|
||||
action_idx = None
|
||||
if need_action_tokens:
|
||||
action_mask = torch.isin(input_ids, torch.tensor(self.action_token_ids, device=input_ids.device))
|
||||
action_idx = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
|
||||
action_tokens = (
|
||||
last_hidden[action_idx[0], action_idx[1], :].view(b, -1, h)
|
||||
if action_idx is not None
|
||||
else None
|
||||
)
|
||||
return last_hidden, embodied_action_tokens, action_tokens
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
def _world_model_loss(self, videos: Tensor, action_tokens: Tensor) -> Tensor:
|
||||
"""JEPA encode + predictor L1 loss. `videos` is [B, V, T, C, H, W] float in [0, 1]."""
|
||||
# Match the world model's expected view count: pad with the first view, or trim extras.
|
||||
num_views = self.config.jepa_tubelet_size
|
||||
if videos.shape[1] < num_views:
|
||||
missing = num_views - videos.shape[1]
|
||||
videos = torch.cat([videos, videos[:, :1].repeat(1, missing, 1, 1, 1, 1)], dim=1)
|
||||
elif videos.shape[1] > num_views:
|
||||
videos = videos[:, :num_views]
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
b, v, t_frames, c, h_img, w_img = videos.shape
|
||||
flat = videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
# Fast (torchvision) video processor on-device, do_rescale=False (frames already in [0, 1]).
|
||||
video_pixels = self.video_processor(
|
||||
videos=list(flat),
|
||||
return_tensors="pt",
|
||||
device=self.video_encoder.device,
|
||||
do_rescale=False,
|
||||
)["pixel_values_videos"] # [B*V, T, C, H, W]
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
if t_enc_total < 2:
|
||||
return torch.zeros((), device=video_embeddings.device)
|
||||
|
||||
# Shift-by-one JEPA split: input_states = positions 0..T-2, gt_states = positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(), action_tokens[:, :expected_actions].float()
|
||||
)
|
||||
return F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
def _action_loss(
|
||||
self,
|
||||
embodied_action_tokens: Tensor,
|
||||
actions: Tensor,
|
||||
state: Tensor | None,
|
||||
action_is_pad: Tensor | None,
|
||||
) -> Tensor:
|
||||
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
|
||||
device_type = next(self.parameters()).device.type
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
r = self.config.repeated_diffusion_steps
|
||||
horizon = self.config.chunk_size
|
||||
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
|
||||
embodied = embodied_action_tokens.repeat(r, 1, 1)
|
||||
state_rep = state.to(embodied_action_tokens.dtype).repeat(r, 1, 1) if state is not None else None
|
||||
pad_rep = action_is_pad[:, -horizon:].repeat(r, 1) if action_is_pad is not None else None
|
||||
return self.action_model(embodied, actions_target, state_rep, pad_rep)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: list[list[Tensor]],
|
||||
instructions: list[str],
|
||||
videos: Tensor | None = None,
|
||||
actions: Tensor | None = None,
|
||||
state: Tensor | None = None,
|
||||
action_is_pad: Tensor | None = None,
|
||||
) -> dict[str, Tensor]:
|
||||
"""Native forward: Qwen encode → optional world-model loss → optional action-head loss."""
|
||||
last_hidden, embodied_action_tokens, action_tokens = self._encode_qwen(
|
||||
images, instructions, need_action_tokens=self.config.enable_world_model
|
||||
)
|
||||
|
||||
if self.config.enable_world_model:
|
||||
wm_loss = self._world_model_loss(videos, action_tokens)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
wm_loss = torch.zeros((), device=last_hidden.device)
|
||||
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
if actions is None:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
action_loss = self._action_loss(embodied_action_tokens, actions, state, action_is_pad)
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
@@ -328,58 +284,24 @@ class VLAJEPAModel(nn.Module):
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
images: list[list[Tensor]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
state: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Predict an action chunk. `images` is per-sample, per-view float [0,1] [C, H, W] tensors."""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
images = [
|
||||
[F.interpolate(img[None], size=(height, width), mode="area")[0] for img in views]
|
||||
for views in images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
_, embodied_action_tokens, _ = self._encode_qwen(images, instructions, need_action_tokens=False)
|
||||
state = state.to(embodied_action_tokens.dtype) if state is not None else None
|
||||
return self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state.float() if state is not None else None
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
@@ -390,9 +312,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the batched tensors
|
||||
the native model expects (keeping everything on-device), calls the native model, and
|
||||
converts outputs back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
@@ -419,9 +341,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Any]:
|
||||
"""Convert a LeRobot batch to the model's batched, on-device inputs.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
@@ -431,65 +352,25 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
Returns the kwargs for `VLAJEPAModel.forward` / `.predict_action` (everything stays
|
||||
on the batch device; no per-sample shredding): `images` (per-sample, per-view list for
|
||||
Qwen messages), `instructions`, and the batched `videos` / `actions` / `state` /
|
||||
`action_is_pad` when present.
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
batch_size = batch[image_keys[0]].shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
# Current-frame image per view ([B, C, H, W]); regroup per sample for Qwen messages.
|
||||
frames = []
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
t = batch[key]
|
||||
if t.ndim == 5: # [B, T, C, H, W] -> current observation (delta=0)
|
||||
t = t[:, 0]
|
||||
frames.append(self.model.qwen.to_pixel_values(t))
|
||||
images = [[frame[b] for frame in frames] for b in range(batch_size)]
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
@@ -498,52 +379,32 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
inputs: dict[str, Any] = {"images": images, "instructions": instructions}
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
# Videos [B, V, T, C, H, W] - only assembled when the world model consumes them.
|
||||
if self.model.config.enable_world_model:
|
||||
views = [batch[k].unsqueeze(1) if batch[k].ndim == 4 else batch[k] for k in image_keys]
|
||||
inputs["videos"] = self.model.qwen.to_pixel_values(torch.stack(views, dim=1))
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
actions = batch.get(ACTION)
|
||||
if actions is not None:
|
||||
inputs["actions"] = (actions.unsqueeze(1) if actions.ndim == 2 else actions).float()
|
||||
if (pad := batch.get("action_is_pad")) is not None:
|
||||
inputs["action_is_pad"] = pad
|
||||
|
||||
return examples
|
||||
state = batch.get(OBS_STATE)
|
||||
if state is not None:
|
||||
if state.ndim > 2:
|
||||
state = state[:, -1, :]
|
||||
inputs["state"] = (state.unsqueeze(1) if state.ndim == 2 else state).float() # [B, 1, dim]
|
||||
|
||||
return inputs
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
native_output = self.model.forward(**self._prepare_model_inputs(batch))
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
@@ -561,16 +422,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
inputs = self._prepare_model_inputs(batch)
|
||||
actions = self.model.predict_action(inputs["images"], inputs["instructions"], inputs.get("state"))
|
||||
return actions.to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
|
||||
@@ -17,9 +17,7 @@ from __future__ import annotations
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -78,7 +76,7 @@ class Qwen3VLInterface(torch.nn.Module):
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
images: Sequence[Sequence[torch.Tensor]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
@@ -94,24 +92,42 @@ class Qwen3VLInterface(torch.nn.Module):
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
# The Qwen image processor is a torchvision-backed fast processor: passing the
|
||||
# images as GPU tensors (with `device`) keeps the whole vision pipeline on-device
|
||||
# and avoids a GPU->CPU->GPU roundtrip. The image tensors are forwarded through
|
||||
# apply_chat_template untouched into Qwen3VLProcessor.__call__.
|
||||
# do_rescale=False: images already arrive as float in [0, 1] (the dataset decoder
|
||||
# yields float32/255 and VISUAL normalization is IDENTITY), so we skip the
|
||||
# processor's /255 rescale instead of round-tripping through uint8.
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
processor_kwargs={
|
||||
"padding": True,
|
||||
"return_tensors": "pt",
|
||||
"device": self.model.device,
|
||||
"do_rescale": False,
|
||||
},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
def to_pixel_values(image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Prepare an image/video tensor for the fast processors (used with do_rescale=False).
|
||||
|
||||
The dataset decoder yields float32 in [0, 1] (channels-first) and VISUAL
|
||||
normalization is IDENTITY, so the tensor already arrives in [0, 1]; we pass it
|
||||
through as float and let the processors normalize (no rescale, no uint8
|
||||
quantization). A single channel is expanded to 3 to match the RGB processors.
|
||||
|
||||
Works for any channels-first layout (channel dim is -3): [C, H, W], [B, C, H, W],
|
||||
[T, C, H, W], [B, V, T, C, H, W], ...
|
||||
"""
|
||||
image = image_tensor.detach().float()
|
||||
if image.shape[-3] == 1:
|
||||
repeats = [1] * image.ndim
|
||||
repeats[-3] = 3
|
||||
image = image.repeat(*repeats)
|
||||
return image
|
||||
|
||||
@@ -36,8 +36,6 @@ from tqdm import tqdm
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
@@ -45,7 +43,7 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -234,16 +232,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: the global main process downloads once to the shared
|
||||
# dataset root, then a barrier lets every other rank read the already-populated copy.
|
||||
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Other ranks read from the shared copy populated by the main process.
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
@@ -388,47 +384,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if not cfg.dataset.streaming:
|
||||
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
|
||||
# The order is a pure function of (seed, epoch), so every rank independently produces the
|
||||
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
|
||||
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
seed=cfg.seed if cfg.seed is not None else 0,
|
||||
)
|
||||
if cfg.resume and step > 0:
|
||||
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
||||
# use the values recorded in the checkpoint (falling back to the current ones for older
|
||||
# ckpts that did not store them).
|
||||
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
|
||||
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
|
||||
ckpt_num_processes = saved_num_processes or accelerator.num_processes
|
||||
ckpt_batch_size = saved_batch_size or cfg.batch_size
|
||||
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
|
||||
logging.warning(
|
||||
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
|
||||
f"written with num_processes={saved_num_processes}. The data order resumes at the "
|
||||
"right epoch/offset, but per-rank sample-exactness requires the same world size."
|
||||
)
|
||||
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
|
||||
logging.warning(
|
||||
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
|
||||
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
|
||||
"but per-rank sample-exactness requires the same batch size."
|
||||
)
|
||||
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
|
||||
sampler.load_state_dict(sampler_state)
|
||||
if is_main_process:
|
||||
logging.info(
|
||||
f"Resuming data order at epoch {sampler_state['epoch']}, "
|
||||
f"sample {sampler_state['start_index']}"
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
@@ -547,8 +511,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
num_processes=accelerator.num_processes,
|
||||
batch_size=cfg.batch_size,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
|
||||
@@ -114,19 +114,6 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_is_reproducible_across_instances():
|
||||
# The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
|
||||
# produce the same permutation without any generator synchronization.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
epoch_0 = list(sampler_a)
|
||||
assert list(sampler_b) == epoch_0
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
assert list(sampler_c) == epoch_0
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
@@ -150,87 +137,3 @@ def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
|
||||
# --- seeded (seed, epoch) shuffling, resume, and state ---
|
||||
|
||||
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
|
||||
|
||||
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
|
||||
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
|
||||
for seed in (0, 1, 1234):
|
||||
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
|
||||
assert sorted(sampler) == list(range(num_frames))
|
||||
|
||||
|
||||
def test_deterministic_sampler_epochs_reproduce_and_differ():
|
||||
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
epoch_0 = list(sampler_a)
|
||||
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
|
||||
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
|
||||
assert epoch_1 != epoch_0
|
||||
assert sorted(epoch_1) == sorted(epoch_0)
|
||||
sampler_a.set_epoch(0)
|
||||
assert list(sampler_a) == epoch_0
|
||||
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_mid_epoch():
|
||||
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
epoch_0 = list(reference)
|
||||
epoch_1 = list(reference)
|
||||
for start in (0, 1, 4, len(epoch_0)):
|
||||
resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||
assert list(resumed) == epoch_0[start:]
|
||||
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
|
||||
assert list(resumed) == epoch_1
|
||||
|
||||
|
||||
def test_deterministic_sampler_construction_stores_only_boundaries():
|
||||
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
|
||||
# instantiates from just its boundaries without materializing a per-frame index list.
|
||||
num_frames = 1_000_000
|
||||
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
assert len(sampler) == num_frames
|
||||
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_is_exact_at_scale():
|
||||
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
|
||||
# permutation and slicing from the saved offset reproduces the remaining order exactly.
|
||||
num_frames = 100_000
|
||||
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
epoch_0 = list(reference)
|
||||
assert sorted(epoch_0) == list(range(num_frames))
|
||||
start = num_frames - 5
|
||||
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||
assert list(resumed) == epoch_0[start:]
|
||||
|
||||
|
||||
def test_compute_sampler_state():
|
||||
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
|
||||
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
|
||||
"epoch": 0,
|
||||
"start_index": 0,
|
||||
}
|
||||
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
|
||||
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
|
||||
"epoch": 1,
|
||||
"start_index": 40,
|
||||
}
|
||||
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
|
||||
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
|
||||
"epoch": 2,
|
||||
"start_index": 40,
|
||||
}
|
||||
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
|
||||
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
|
||||
"epoch": 1,
|
||||
"start_index": 100,
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ from types import SimpleNamespace
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -191,7 +190,7 @@ class _FakeQwenInterface(nn.Module):
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: list[list[Image.Image]],
|
||||
images: list[list[Tensor]],
|
||||
instructions: list[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
@@ -214,12 +213,13 @@ class _FakeQwenInterface(nn.Module):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
||||
return Image.fromarray(image)
|
||||
def to_pixel_values(image_tensor: Tensor) -> Tensor:
|
||||
image = image_tensor.detach().float()
|
||||
if image.shape[-3] == 1:
|
||||
repeats = [1] * image.ndim
|
||||
repeats[-3] = 3
|
||||
image = image.repeat(*repeats)
|
||||
return image
|
||||
|
||||
|
||||
class _FakeVideoEncoder(nn.Module):
|
||||
@@ -242,12 +242,14 @@ class _FakeVideoEncoder(nn.Module):
|
||||
|
||||
|
||||
class _FakeVideoProcessor:
|
||||
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
|
||||
def __call__(self, videos, return_tensors: str, device=None, **kwargs) -> dict[str, Tensor]:
|
||||
assert return_tensors == "pt"
|
||||
if isinstance(videos, list):
|
||||
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
|
||||
else:
|
||||
pixel_values = torch.as_tensor(videos).unsqueeze(0)
|
||||
if device is not None:
|
||||
pixel_values = pixel_values.to(device)
|
||||
return {"pixel_values_videos": pixel_values}
|
||||
|
||||
|
||||
|
||||
@@ -211,40 +211,42 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None
|
||||
|
||||
|
||||
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
||||
from PIL import Image
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
examples = policy._prepare_model_inputs(make_train_batch())
|
||||
inputs = policy._prepare_model_inputs(make_train_batch())
|
||||
|
||||
assert len(examples) == BATCH_SIZE
|
||||
for ex in examples:
|
||||
assert set(ex) >= {"image", "video", "lang", "action", "state"}
|
||||
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
|
||||
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
|
||||
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
|
||||
assert ex["state"].shape == (1, STATE_DIM)
|
||||
assert set(inputs) >= {"images", "instructions", "videos", "actions", "state"}
|
||||
# images: per-sample, per-view [C, H, W] float tensors (kept as a list for Qwen messages)
|
||||
assert len(inputs["images"]) == BATCH_SIZE and len(inputs["images"][0]) == 1
|
||||
img = inputs["images"][0][0]
|
||||
assert isinstance(img, torch.Tensor) and img.dtype == torch.float32 and img.ndim == 3
|
||||
assert len(inputs["instructions"]) == BATCH_SIZE
|
||||
# videos: batched [B, V, T, C, H, W] float
|
||||
assert inputs["videos"].ndim == 6 and inputs["videos"].shape[0] == BATCH_SIZE
|
||||
assert inputs["videos"].dtype == torch.float32
|
||||
assert inputs["actions"].shape == (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
assert inputs["state"].shape == (BATCH_SIZE, 1, STATE_DIM)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
for ex in policy._prepare_model_inputs(make_inference_batch()):
|
||||
assert "action" not in ex
|
||||
assert "image" in ex and "video" in ex and "lang" in ex
|
||||
inputs = policy._prepare_model_inputs(make_inference_batch())
|
||||
assert "actions" not in inputs and "action_is_pad" not in inputs
|
||||
assert {"images", "instructions", "state"} <= set(inputs)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch["task"]
|
||||
examples = policy._prepare_model_inputs(batch)
|
||||
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
|
||||
instructions = policy._prepare_model_inputs(batch)["instructions"]
|
||||
assert all(isinstance(s, str) and len(s) > 0 for s in instructions)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
batch["task"] = "open the drawer"
|
||||
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
|
||||
assert policy._prepare_model_inputs(batch)["instructions"] == ["open the drawer"] * BATCH_SIZE
|
||||
|
||||
|
||||
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
||||
@@ -253,7 +255,7 @@ def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: N
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch[OBS_STATE]
|
||||
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
|
||||
assert "state" not in policy._prepare_model_inputs(batch)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -446,14 +448,14 @@ def test_postprocessor_applied_after_predict_action_chunk(
|
||||
"""
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
|
||||
raw_actions = torch.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=torch.float32)
|
||||
|
||||
cfg = make_config()
|
||||
cfg.clip_normalized_actions = False
|
||||
cfg.binarize_gripper_action = False
|
||||
policy = VLAJEPAPolicy(cfg)
|
||||
policy.eval()
|
||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
|
||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.clone())
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
||||
@@ -564,9 +566,9 @@ def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_model
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
def __call__(self, videos: list, return_tensors: str, **kwargs) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=1))
|
||||
@@ -587,9 +589,9 @@ def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: No
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
def __call__(self, videos: list, return_tensors: str, **kwargs) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=3))
|
||||
|
||||
@@ -20,8 +20,6 @@ from unittest.mock import Mock, patch
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
@@ -65,28 +63,6 @@ def test_load_training_step(tmp_path):
|
||||
assert loaded_step == step
|
||||
|
||||
|
||||
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
|
||||
assert load_training_num_processes(tmp_path) == 4
|
||||
|
||||
|
||||
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||
# Checkpoints written before the world size was recorded must still load (back-compat).
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
assert load_training_num_processes(tmp_path) is None
|
||||
|
||||
|
||||
def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32)
|
||||
assert load_training_batch_size(tmp_path) == 32
|
||||
|
||||
|
||||
def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||
# Checkpoints written before the batch size was recorded must still load (back-compat).
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
assert load_training_batch_size(tmp_path) is None
|
||||
|
||||
|
||||
def test_update_last_checkpoint(tmp_path):
|
||||
checkpoint = tmp_path / "0005"
|
||||
checkpoint.mkdir()
|
||||
|
||||
@@ -1764,7 +1764,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "gym-aloha"
|
||||
version = "0.1.4"
|
||||
version = "0.1.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dm-control" },
|
||||
@@ -1772,14 +1772,14 @@ dependencies = [
|
||||
{ name = "imageio", extra = ["ffmpeg"] },
|
||||
{ name = "mujoco" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym-hil"
|
||||
version = "0.1.14"
|
||||
version = "0.1.13"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "gymnasium" },
|
||||
@@ -1789,9 +1789,9 @@ dependencies = [
|
||||
{ name = "pygame" },
|
||||
{ name = "pynput" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1881,7 +1881,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9
|
||||
|
||||
[[package]]
|
||||
name = "hf-libero"
|
||||
version = "0.1.4"
|
||||
version = "0.1.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "bddl", marker = "sys_platform == 'linux'" },
|
||||
@@ -1902,10 +1902,7 @@ dependencies = [
|
||||
{ name = "transformers", marker = "sys_platform == 'linux'" },
|
||||
{ name = "wandb", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" }
|
||||
|
||||
[[package]]
|
||||
name = "hf-xet"
|
||||
@@ -3093,12 +3090,12 @@ requires-dist = [
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" },
|
||||
{ name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" },
|
||||
{ name = "gymnasium", specifier = ">=1.1.1,<2.0.0" },
|
||||
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user