mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-13 14:39:44 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c1332ac37e | |||
| 31ddb8f493 | |||
| 877847c90e | |||
| 49755a3d9e | |||
| 09808183ca |
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
### Episodic (`--strategy.type=episodic`)
|
||||
|
||||
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=episodic \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||
--dataset.num_episodes=20 \
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10 \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ----------- | -------------------------------- |
|
||||
| `→` (right) | End the current episode early |
|
||||
| `←` (left) | Discard episode and re-record it |
|
||||
| `ESC` | Stop the recording session |
|
||||
|
||||
| Flag | Description |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||
| `--dataset.num_episodes` | Number of episodes to record |
|
||||
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Teleoperator smooth handover helpers
|
||||
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||
# being a root module.
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleop_supports_feedback(teleop) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
|
||||
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||
the robot action key space directly.
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def follower_smooth_move_to(
|
||||
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||
the follower, the follower is brought to the teleop's current pose so the
|
||||
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
Both ``current`` and ``target`` must be in the robot action key space
|
||||
(i.e. the output of ``robot_action_processor``).
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
@@ -17,12 +17,10 @@ from __future__ import annotations
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
@@ -55,12 +53,13 @@ class VLAJEPAModel(nn.Module):
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
Inputs are batched tensors kept on the model device
|
||||
- images: List[List[Tensor [C, H, W]]] (float [0,1]) — per sample, per view (Qwen messages)
|
||||
- instructions: List[str]
|
||||
- videos: Tensor [B, V, T, C, H, W] (float [0,1], world model only)
|
||||
- actions: Tensor [B, T, action_dim] (optional, training only)
|
||||
- state: Tensor [B, 1, state_dim] (optional)
|
||||
- action_is_pad: Tensor [B, T] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
@@ -161,166 +160,123 @@ class VLAJEPAModel(nn.Module):
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# Adjust number of views for the world model:
|
||||
# - fewer views than expected: duplicate the first view to fill up
|
||||
# - more views than expected: keep only the first num_views_world_model views
|
||||
num_views_world_model = self.config.jepa_tubelet_size
|
||||
if batch_videos.shape[1] < num_views_world_model:
|
||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
||||
elif batch_videos.shape[1] > num_views_world_model:
|
||||
batch_videos = batch_videos[:, :num_views_world_model]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
def _encode_qwen(
|
||||
self, images: list[list[Tensor]], instructions: list[str], *, need_action_tokens: bool
|
||||
) -> tuple[Tensor, Tensor, Tensor | None]:
|
||||
"""Run Qwen and gather the embodied-action (and optionally action) token hidden states."""
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
images=images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
input_ids = qwen_inputs["input_ids"]
|
||||
embodied_idx = (input_ids == self.embodied_action_token_id).nonzero(as_tuple=True)
|
||||
action_idx = None
|
||||
if need_action_tokens:
|
||||
action_mask = torch.isin(input_ids, torch.tensor(self.action_token_ids, device=input_ids.device))
|
||||
action_idx = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
|
||||
action_tokens = (
|
||||
last_hidden[action_idx[0], action_idx[1], :].view(b, -1, h)
|
||||
if action_idx is not None
|
||||
else None
|
||||
)
|
||||
return last_hidden, embodied_action_tokens, action_tokens
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
def _world_model_loss(self, videos: Tensor, action_tokens: Tensor) -> Tensor:
|
||||
"""JEPA encode + predictor L1 loss. `videos` is [B, V, T, C, H, W] float in [0, 1]."""
|
||||
# Match the world model's expected view count: pad with the first view, or trim extras.
|
||||
num_views = self.config.jepa_tubelet_size
|
||||
if videos.shape[1] < num_views:
|
||||
missing = num_views - videos.shape[1]
|
||||
videos = torch.cat([videos, videos[:, :1].repeat(1, missing, 1, 1, 1, 1)], dim=1)
|
||||
elif videos.shape[1] > num_views:
|
||||
videos = videos[:, :num_views]
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
b, v, t_frames, c, h_img, w_img = videos.shape
|
||||
flat = videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
# Fast (torchvision) video processor on-device, do_rescale=False (frames already in [0, 1]).
|
||||
video_pixels = self.video_processor(
|
||||
videos=list(flat),
|
||||
return_tensors="pt",
|
||||
device=self.video_encoder.device,
|
||||
do_rescale=False,
|
||||
)["pixel_values_videos"] # [B*V, T, C, H, W]
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
if t_enc_total < 2:
|
||||
return torch.zeros((), device=video_embeddings.device)
|
||||
|
||||
# Shift-by-one JEPA split: input_states = positions 0..T-2, gt_states = positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(), action_tokens[:, :expected_actions].float()
|
||||
)
|
||||
return F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
def _action_loss(
|
||||
self,
|
||||
embodied_action_tokens: Tensor,
|
||||
actions: Tensor,
|
||||
state: Tensor | None,
|
||||
action_is_pad: Tensor | None,
|
||||
) -> Tensor:
|
||||
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
|
||||
device_type = next(self.parameters()).device.type
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
r = self.config.repeated_diffusion_steps
|
||||
horizon = self.config.chunk_size
|
||||
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
|
||||
embodied = embodied_action_tokens.repeat(r, 1, 1)
|
||||
state_rep = state.to(embodied_action_tokens.dtype).repeat(r, 1, 1) if state is not None else None
|
||||
pad_rep = action_is_pad[:, -horizon:].repeat(r, 1) if action_is_pad is not None else None
|
||||
return self.action_model(embodied, actions_target, state_rep, pad_rep)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: list[list[Tensor]],
|
||||
instructions: list[str],
|
||||
videos: Tensor | None = None,
|
||||
actions: Tensor | None = None,
|
||||
state: Tensor | None = None,
|
||||
action_is_pad: Tensor | None = None,
|
||||
) -> dict[str, Tensor]:
|
||||
"""Native forward: Qwen encode → optional world-model loss → optional action-head loss."""
|
||||
last_hidden, embodied_action_tokens, action_tokens = self._encode_qwen(
|
||||
images, instructions, need_action_tokens=self.config.enable_world_model
|
||||
)
|
||||
|
||||
if self.config.enable_world_model:
|
||||
wm_loss = self._world_model_loss(videos, action_tokens)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
wm_loss = torch.zeros((), device=last_hidden.device)
|
||||
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
if actions is None:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
action_loss = self._action_loss(embodied_action_tokens, actions, state, action_is_pad)
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
@@ -328,58 +284,24 @@ class VLAJEPAModel(nn.Module):
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
images: list[list[Tensor]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
state: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Predict an action chunk. `images` is per-sample, per-view float [0,1] [C, H, W] tensors."""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
images = [
|
||||
[F.interpolate(img[None], size=(height, width), mode="area")[0] for img in views]
|
||||
for views in images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
_, embodied_action_tokens, _ = self._encode_qwen(images, instructions, need_action_tokens=False)
|
||||
state = state.to(embodied_action_tokens.dtype) if state is not None else None
|
||||
return self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state.float() if state is not None else None
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
@@ -390,9 +312,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the batched tensors
|
||||
the native model expects (keeping everything on-device), calls the native model, and
|
||||
converts outputs back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
@@ -419,9 +341,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Any]:
|
||||
"""Convert a LeRobot batch to the model's batched, on-device inputs.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
@@ -431,65 +352,25 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
Returns the kwargs for `VLAJEPAModel.forward` / `.predict_action` (everything stays
|
||||
on the batch device; no per-sample shredding): `images` (per-sample, per-view list for
|
||||
Qwen messages), `instructions`, and the batched `videos` / `actions` / `state` /
|
||||
`action_is_pad` when present.
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
batch_size = batch[image_keys[0]].shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
# Current-frame image per view ([B, C, H, W]); regroup per sample for Qwen messages.
|
||||
frames = []
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
t = batch[key]
|
||||
if t.ndim == 5: # [B, T, C, H, W] -> current observation (delta=0)
|
||||
t = t[:, 0]
|
||||
frames.append(self.model.qwen.to_pixel_values(t))
|
||||
images = [[frame[b] for frame in frames] for b in range(batch_size)]
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
@@ -498,52 +379,32 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
inputs: dict[str, Any] = {"images": images, "instructions": instructions}
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
# Videos [B, V, T, C, H, W] - only assembled when the world model consumes them.
|
||||
if self.model.config.enable_world_model:
|
||||
views = [batch[k].unsqueeze(1) if batch[k].ndim == 4 else batch[k] for k in image_keys]
|
||||
inputs["videos"] = self.model.qwen.to_pixel_values(torch.stack(views, dim=1))
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
actions = batch.get(ACTION)
|
||||
if actions is not None:
|
||||
inputs["actions"] = (actions.unsqueeze(1) if actions.ndim == 2 else actions).float()
|
||||
if (pad := batch.get("action_is_pad")) is not None:
|
||||
inputs["action_is_pad"] = pad
|
||||
|
||||
return examples
|
||||
state = batch.get(OBS_STATE)
|
||||
if state is not None:
|
||||
if state.ndim > 2:
|
||||
state = state[:, -1, :]
|
||||
inputs["state"] = (state.unsqueeze(1) if state.ndim == 2 else state).float() # [B, 1, dim]
|
||||
|
||||
return inputs
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
native_output = self.model.forward(**self._prepare_model_inputs(batch))
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
@@ -561,16 +422,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
inputs = self._prepare_model_inputs(batch)
|
||||
actions = self.model.predict_action(inputs["images"], inputs["instructions"], inputs.get("state"))
|
||||
return actions.to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
|
||||
@@ -17,9 +17,7 @@ from __future__ import annotations
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -78,7 +76,7 @@ class Qwen3VLInterface(torch.nn.Module):
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
images: Sequence[Sequence[torch.Tensor]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
@@ -94,24 +92,42 @@ class Qwen3VLInterface(torch.nn.Module):
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
# The Qwen image processor is a torchvision-backed fast processor: passing the
|
||||
# images as GPU tensors (with `device`) keeps the whole vision pipeline on-device
|
||||
# and avoids a GPU->CPU->GPU roundtrip. The image tensors are forwarded through
|
||||
# apply_chat_template untouched into Qwen3VLProcessor.__call__.
|
||||
# do_rescale=False: images already arrive as float in [0, 1] (the dataset decoder
|
||||
# yields float32/255 and VISUAL normalization is IDENTITY), so we skip the
|
||||
# processor's /255 rescale instead of round-tripping through uint8.
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
processor_kwargs={
|
||||
"padding": True,
|
||||
"return_tensors": "pt",
|
||||
"device": self.model.device,
|
||||
"do_rescale": False,
|
||||
},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
def to_pixel_values(image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Prepare an image/video tensor for the fast processors (used with do_rescale=False).
|
||||
|
||||
The dataset decoder yields float32 in [0, 1] (channels-first) and VISUAL
|
||||
normalization is IDENTITY, so the tensor already arrives in [0, 1]; we pass it
|
||||
through as float and let the processors normalize (no rescale, no uint8
|
||||
quantization). A single channel is expanded to 3 to match the RGB processors.
|
||||
|
||||
Works for any channels-first layout (channel dim is -3): [C, H, W], [B, C, H, W],
|
||||
[T, C, H, W], [B, V, T, C, H, W], ...
|
||||
"""
|
||||
image = image_tensor.detach().float()
|
||||
if image.shape[-3] == 1:
|
||||
repeats = [1] * image.ndim
|
||||
repeats[-3] = 3
|
||||
image = image.repeat(*repeats)
|
||||
return image
|
||||
|
||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
|
||||
config: dict[str, Any] = {
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
return pipeline_config
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
return cls(
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
# 3. Load step state if available
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
return steps, remaining_override_keys
|
||||
|
||||
steps.append(step_instance)
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
return steps, override_keys
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -23,6 +23,7 @@ from .configs import (
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutStrategyConfig,
|
||||
@@ -49,6 +50,7 @@ from .inference import (
|
||||
from .strategies import (
|
||||
BaseStrategy,
|
||||
DAggerStrategy,
|
||||
EpisodicStrategy,
|
||||
HighlightStrategy,
|
||||
RolloutStrategy,
|
||||
SentryStrategy,
|
||||
@@ -66,6 +68,8 @@ __all__ = [
|
||||
"HardwareContext",
|
||||
"HighlightStrategy",
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
|
||||
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
|
||||
upload: str = "KEY_C"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@dataclass
|
||||
class EpisodicStrategyConfig(RolloutStrategyConfig):
|
||||
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
|
||||
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
|
||||
|
||||
Keyboard controls:
|
||||
Right arrow — end current episode or reset phase early
|
||||
Left arrow — discard current episode and re-record
|
||||
Escape — stop recording session
|
||||
|
||||
In between episodes:
|
||||
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
|
||||
- else, the robot is moved smoothly to the position of the teleop leader.
|
||||
"""
|
||||
|
||||
# This only applies if there are no teleop leaders specified.
|
||||
# When True (default), moves the robot back to the joint positions captured at startup.
|
||||
# Otherwise, leave the robot in its current position.
|
||||
reset_to_initial_position: bool = True
|
||||
|
||||
# Whether to turn on or off the leader -> follower smooth handover behavior.
|
||||
# When False, fallback to follower -> leader handover.
|
||||
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
|
||||
smooth_leader_to_follower_handover: bool = True
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
@@ -229,7 +258,13 @@ class RolloutConfig:
|
||||
|
||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||
needs_dataset = isinstance(
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
self.strategy,
|
||||
(
|
||||
SentryStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
),
|
||||
)
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"EpisodicStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
|
||||
@@ -56,10 +56,14 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
@@ -171,64 +174,6 @@ class DAggerEvents:
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
engine.pause()
|
||||
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
- Policy drives the robot during each recording episode.
|
||||
- An optional teleoperator can drive the robot during reset phases so the
|
||||
operator can bring the environment back to its starting configuration.
|
||||
If no teleop is connected the robot stays in its current position.
|
||||
- Keyboard controls:
|
||||
|
||||
Right arrow — end the current episode or reset phase early
|
||||
Left arrow — discard the current episode and re-record it
|
||||
Escape — stop the recording session
|
||||
|
||||
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
from ..configs import EpisodicStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicStrategy(RolloutStrategy):
|
||||
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
|
||||
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
|
||||
follows every episode (except the last) so the operator can manually
|
||||
reset the environment. During the reset phase, an optional teleoperator
|
||||
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
|
||||
|
||||
The policy state (hidden state, RTC queue, interpolator) is reset at
|
||||
the start of each recording episode.
|
||||
|
||||
Keyboard events:
|
||||
right arrow → end current episode or reset phase early
|
||||
left arrow → discard & re-record current episode
|
||||
ESC → stop the session
|
||||
"""
|
||||
|
||||
config: EpisodicStrategyConfig
|
||||
|
||||
def __init__(self, config: EpisodicStrategyConfig) -> None:
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._events: dict | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Start the inference engine and attach the keyboard listener."""
|
||||
self._init_engine(ctx)
|
||||
self._listener, self._events = init_keyboard_listener()
|
||||
logger.info("Episodic strategy ready")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Main multi-episode recording loop."""
|
||||
cfg = ctx.runtime.cfg
|
||||
dataset_cfg = cfg.dataset
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
fps = cfg.fps
|
||||
episode_time_s = dataset_cfg.episode_time_s
|
||||
reset_time_s = dataset_cfg.reset_time_s
|
||||
num_episodes = dataset_cfg.num_episodes
|
||||
single_task = dataset_cfg.single_task or cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
display_compressed = (
|
||||
True
|
||||
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
|
||||
else cfg.display_compressed_images
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < num_episodes and not events["stop_recording"]:
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
# Reset policy state at episode start (discard leftover hidden state / queue)
|
||||
self._engine.reset()
|
||||
self._interpolator.reset()
|
||||
self._engine.resume()
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
self._policy_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
events=events,
|
||||
features=features,
|
||||
fps=fps,
|
||||
control_time_s=episode_time_s,
|
||||
dataset=dataset,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
# Reset phase, skip after the last episode (but run when re-recording)
|
||||
if not events["stop_recording"] and (
|
||||
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
|
||||
if teleop:
|
||||
# Smooth handover so the transition to teleop control is jerk-free.
|
||||
# For actuated teleops: drive the leader arm to the follower's current
|
||||
# position so the operator takes over without fighting the arm.
|
||||
# For non-actuated teleops: slide the follower to the teleop's current
|
||||
# pose instead, since the leader cannot be driven.
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
if (
|
||||
teleop_supports_feedback(teleop)
|
||||
and self.config.smooth_leader_to_follower_handover
|
||||
):
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
|
||||
teleop.disable_torque()
|
||||
else:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
|
||||
|
||||
elif self.config.reset_to_initial_position:
|
||||
# No teleop: return the robot to its startup position.
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
self._reset_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=fps,
|
||||
control_time_s=reset_time_s,
|
||||
display_data=cfg.display_data,
|
||||
display_compressed=display_compressed,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# returns to its initial joint positions captured at startup
|
||||
if not teleop and self.config.reset_to_initial_position:
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Save any frames buffered in the current episode so an unexpected
|
||||
# exception or KeyboardInterrupt does not silently drop recorded data.
|
||||
# suppress: save_episode raises if the buffer is empty (nothing to lose).
|
||||
logger.info("Episodic control loop ended — saving any in-progress episode")
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
|
||||
def _policy_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
events: dict,
|
||||
features: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
dataset,
|
||||
single_task: str,
|
||||
) -> None:
|
||||
"""Policy-driven recording loop for a single episode."""
|
||||
interpolator = self._interpolator
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
|
||||
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
if sleep_t < 0:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
|
||||
"Dataset frames might be dropped and robot control might be unstable. "
|
||||
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
|
||||
"3) CPU starvation"
|
||||
)
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def _reset_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
teleop,
|
||||
events: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
display_data: bool,
|
||||
display_compressed: bool,
|
||||
) -> None:
|
||||
"""Reset-phase loop: teleop drives the robot if available, no recording."""
|
||||
processors = ctx.processors
|
||||
control_interval = 1.0 / fps
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
if teleop is not None:
|
||||
act = teleop.get_action()
|
||||
act_teleop = processors.teleop_action_processor((act, obs))
|
||||
robot_action = processors.robot_action_processor((act_teleop, obs))
|
||||
robot.send_action(robot_action)
|
||||
|
||||
if display_data:
|
||||
obs_processed = processors.robot_observation_processor(obs)
|
||||
log_rerun_data(
|
||||
observation=obs_processed,
|
||||
action=act_teleop,
|
||||
compress_images=display_compressed,
|
||||
)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
|
||||
cfg = ctx.runtime.cfg
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
|
||||
if not is_headless() and self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
|
||||
if (
|
||||
cfg.dataset is not None
|
||||
and cfg.dataset.push_to_hub
|
||||
and ctx.data.dataset is not None
|
||||
and safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=cfg.dataset.tags,
|
||||
private=cfg.dataset.private,
|
||||
)
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=cfg.return_to_initial_position,
|
||||
)
|
||||
log_say("Exiting", play_sounds)
|
||||
logger.info("Episodic strategy teardown complete")
|
||||
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy
|
||||
from .dagger import DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
return HighlightStrategy(config)
|
||||
if config.type == "dagger":
|
||||
return DAggerStrategy(config)
|
||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
||||
if config.type == "episodic":
|
||||
return EpisodicStrategy(config)
|
||||
raise ValueError(
|
||||
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ Strategies
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
@@ -111,6 +112,18 @@ Usage examples
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Episodic mode — episode-oriented recording with reset phases
|
||||
lerobot-rollout \\
|
||||
--strategy.type=episodic \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/rollout_episodic_data \\
|
||||
--dataset.num_episodes=20 \\
|
||||
--dataset.single_task="Grab the cube"
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
|
||||
@@ -8,7 +8,6 @@ from types import SimpleNamespace
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -191,7 +190,7 @@ class _FakeQwenInterface(nn.Module):
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: list[list[Image.Image]],
|
||||
images: list[list[Tensor]],
|
||||
instructions: list[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
@@ -214,12 +213,13 @@ class _FakeQwenInterface(nn.Module):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
||||
return Image.fromarray(image)
|
||||
def to_pixel_values(image_tensor: Tensor) -> Tensor:
|
||||
image = image_tensor.detach().float()
|
||||
if image.shape[-3] == 1:
|
||||
repeats = [1] * image.ndim
|
||||
repeats[-3] = 3
|
||||
image = image.repeat(*repeats)
|
||||
return image
|
||||
|
||||
|
||||
class _FakeVideoEncoder(nn.Module):
|
||||
@@ -242,12 +242,14 @@ class _FakeVideoEncoder(nn.Module):
|
||||
|
||||
|
||||
class _FakeVideoProcessor:
|
||||
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
|
||||
def __call__(self, videos, return_tensors: str, device=None, **kwargs) -> dict[str, Tensor]:
|
||||
assert return_tensors == "pt"
|
||||
if isinstance(videos, list):
|
||||
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
|
||||
else:
|
||||
pixel_values = torch.as_tensor(videos).unsqueeze(0)
|
||||
if device is not None:
|
||||
pixel_values = pixel_values.to(device)
|
||||
return {"pixel_values_videos": pixel_values}
|
||||
|
||||
|
||||
|
||||
@@ -211,40 +211,42 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None
|
||||
|
||||
|
||||
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
||||
from PIL import Image
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
examples = policy._prepare_model_inputs(make_train_batch())
|
||||
inputs = policy._prepare_model_inputs(make_train_batch())
|
||||
|
||||
assert len(examples) == BATCH_SIZE
|
||||
for ex in examples:
|
||||
assert set(ex) >= {"image", "video", "lang", "action", "state"}
|
||||
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
|
||||
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
|
||||
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
|
||||
assert ex["state"].shape == (1, STATE_DIM)
|
||||
assert set(inputs) >= {"images", "instructions", "videos", "actions", "state"}
|
||||
# images: per-sample, per-view [C, H, W] float tensors (kept as a list for Qwen messages)
|
||||
assert len(inputs["images"]) == BATCH_SIZE and len(inputs["images"][0]) == 1
|
||||
img = inputs["images"][0][0]
|
||||
assert isinstance(img, torch.Tensor) and img.dtype == torch.float32 and img.ndim == 3
|
||||
assert len(inputs["instructions"]) == BATCH_SIZE
|
||||
# videos: batched [B, V, T, C, H, W] float
|
||||
assert inputs["videos"].ndim == 6 and inputs["videos"].shape[0] == BATCH_SIZE
|
||||
assert inputs["videos"].dtype == torch.float32
|
||||
assert inputs["actions"].shape == (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
assert inputs["state"].shape == (BATCH_SIZE, 1, STATE_DIM)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
for ex in policy._prepare_model_inputs(make_inference_batch()):
|
||||
assert "action" not in ex
|
||||
assert "image" in ex and "video" in ex and "lang" in ex
|
||||
inputs = policy._prepare_model_inputs(make_inference_batch())
|
||||
assert "actions" not in inputs and "action_is_pad" not in inputs
|
||||
assert {"images", "instructions", "state"} <= set(inputs)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch["task"]
|
||||
examples = policy._prepare_model_inputs(batch)
|
||||
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
|
||||
instructions = policy._prepare_model_inputs(batch)["instructions"]
|
||||
assert all(isinstance(s, str) and len(s) > 0 for s in instructions)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
batch["task"] = "open the drawer"
|
||||
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
|
||||
assert policy._prepare_model_inputs(batch)["instructions"] == ["open the drawer"] * BATCH_SIZE
|
||||
|
||||
|
||||
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
||||
@@ -253,7 +255,7 @@ def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: N
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch[OBS_STATE]
|
||||
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
|
||||
assert "state" not in policy._prepare_model_inputs(batch)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -446,14 +448,14 @@ def test_postprocessor_applied_after_predict_action_chunk(
|
||||
"""
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
|
||||
raw_actions = torch.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=torch.float32)
|
||||
|
||||
cfg = make_config()
|
||||
cfg.clip_normalized_actions = False
|
||||
cfg.binarize_gripper_action = False
|
||||
policy = VLAJEPAPolicy(cfg)
|
||||
policy.eval()
|
||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
|
||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.clone())
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
||||
@@ -564,9 +566,9 @@ def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_model
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
def __call__(self, videos: list, return_tensors: str, **kwargs) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=1))
|
||||
@@ -587,9 +589,9 @@ def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: No
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
def __call__(self, videos: list, return_tensors: str, **kwargs) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=3))
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
class MockLazyTensorStateStep(ProcessorStep):
|
||||
"""Mock step whose tensor state is not present in constructor config."""
|
||||
|
||||
def __init__(
|
||||
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
|
||||
):
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
self.tensor_state: torch.Tensor | None = None
|
||||
|
||||
if initial_value is not None:
|
||||
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Return the transition unchanged."""
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return constructor config while intentionally omitting tensor state."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"scale": self.scale,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return tensor state only after it has been initialized or loaded."""
|
||||
if self.tensor_state is None:
|
||||
return {}
|
||||
|
||||
return {"tensor_state": self.tensor_state}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensor state."""
|
||||
self.tensor_state = state["tensor_state"].clone()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Return features unchanged."""
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
|
||||
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
|
||||
"""Registered lazy tensor state step for registry-based serialization tests."""
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
def test_get_config_matches_saved_json():
|
||||
"""Test that in-memory config matches the config written by save_pretrained."""
|
||||
stateless_step = MockStep(name="stateless")
|
||||
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
|
||||
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
|
||||
|
||||
in_memory_config = pipeline.get_config()
|
||||
|
||||
assert pipeline.get_config() == in_memory_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
config_path = Path(tmp_dir) / "memory_pipeline.json"
|
||||
with open(config_path) as file_pointer:
|
||||
saved_config = json.load(file_pointer)
|
||||
|
||||
assert in_memory_config == saved_config
|
||||
assert "state_file" not in in_memory_config["steps"][0]
|
||||
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||
|
||||
|
||||
def test_state_dict_matches_saved_safetensors():
|
||||
"""Test that in-memory state matches the safetensors written by save_pretrained."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
|
||||
|
||||
in_memory_state_dict = pipeline.state_dict()
|
||||
state_filename = "stateful_pipeline_step_0.safetensors"
|
||||
state_key = "stateful_pipeline_step_0"
|
||||
|
||||
assert set(in_memory_state_dict) == {state_key}
|
||||
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
|
||||
|
||||
in_memory_state_dict[state_key]["tensor_state"].add_(1)
|
||||
assert stateful_step.tensor_state is not None
|
||||
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
|
||||
|
||||
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||
|
||||
|
||||
def test_save_pretrained_still_writes_expected_serialization_files():
|
||||
"""Test that save_pretrained keeps the existing config and state filenames."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
save_path = Path(tmp_dir)
|
||||
assert (save_path / "policy_preprocessor.json").exists()
|
||||
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
|
||||
|
||||
|
||||
def test_from_config_round_trips_stateful_pipeline():
|
||||
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
|
||||
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
|
||||
|
||||
|
||||
def test_from_config_round_trips_registered_stateful_pipeline():
|
||||
"""Test that from_config resolves registry steps and loads their named tensor state."""
|
||||
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
|
||||
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
|
||||
|
||||
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
|
||||
assert config["steps"][0]["state_file"] == state_filename
|
||||
assert set(pipeline_state_dict) == {state_key}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
|
||||
assert loaded_step.tensor_state is not None
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
|
||||
|
||||
|
||||
def test_from_config_preserves_state_metadata_for_empty_initial_state():
|
||||
"""Test in-memory loading when rebuilt steps start without tensor state."""
|
||||
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.state_dict() == {}
|
||||
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
|
||||
|
||||
loaded_pipeline.load_state_dict(pipeline_state_dict)
|
||||
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
|
||||
|
||||
|
||||
def test_from_config_applies_overrides_before_state_loading():
|
||||
"""Test that constructor overrides and tensor state loading are separate operations."""
|
||||
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(
|
||||
config,
|
||||
state_dict=pipeline_state_dict,
|
||||
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
|
||||
)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.scale == 5.0
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_missing_expected_state():
|
||||
"""Test loading raises when serialized config expects missing state."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
|
||||
|
||||
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
|
||||
loaded_pipeline.load_state_dict({})
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_unexpected_extra_state():
|
||||
"""Test loading raises on unexpected top-level state keys."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
|
||||
|
||||
with pytest.raises(KeyError, match="extra"):
|
||||
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
|
||||
|
||||
|
||||
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
|
||||
"""Test stateless in-memory serialization and loading."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
|
||||
config = pipeline.get_config()
|
||||
config_without_name = {"steps": config["steps"]}
|
||||
|
||||
assert pipeline.state_dict() == {}
|
||||
assert all("state_file" not in step_entry for step_entry in config["steps"])
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
|
||||
|
||||
assert loaded_pipeline.name == "DataProcessorPipeline"
|
||||
assert loaded_pipeline.state_dict() == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
|
||||
def test_from_config_rejects_non_dict_config(invalid_config):
|
||||
"""Test from_config reports invalid top-level config values cleanly."""
|
||||
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ def test_strategy_config_types():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
@@ -67,6 +68,7 @@ def test_strategy_config_types():
|
||||
assert SentryStrategyConfig().type == "sentry"
|
||||
assert HighlightStrategyConfig().type == "highlight"
|
||||
assert DAggerStrategyConfig().type == "dagger"
|
||||
assert EpisodicStrategyConfig().type == "episodic"
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
@@ -203,6 +205,8 @@ def test_create_strategy_dispatches():
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategy,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategy,
|
||||
EpisodicStrategyConfig,
|
||||
SentryStrategy,
|
||||
SentryStrategyConfig,
|
||||
create_strategy,
|
||||
@@ -211,6 +215,7 @@ def test_create_strategy_dispatches():
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||
assert isinstance(create_strategy(EpisodicStrategyConfig()), EpisodicStrategy)
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
|
||||
Reference in New Issue
Block a user