mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c1332ac37e | |||
| 31ddb8f493 | |||
| 877847c90e | |||
| 49755a3d9e | |||
| 09808183ca |
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
|
|||||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||||
|
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||||
|
|
||||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||||
|
|||||||
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
|||||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||||
| `--teleop.type` | **Required.** Teleoperator type |
|
| `--teleop.type` | **Required.** Teleoperator type |
|
||||||
|
|
||||||
|
### Episodic (`--strategy.type=episodic`)
|
||||||
|
|
||||||
|
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=episodic \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--teleop.type=so100_leader \
|
||||||
|
--teleop.port=/dev/ttyACM1 \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||||
|
--dataset.num_episodes=20 \
|
||||||
|
--dataset.episode_time_s=30 \
|
||||||
|
--dataset.reset_time_s=10 \
|
||||||
|
--dataset.single_task="Pick up the red cube"
|
||||||
|
```
|
||||||
|
|
||||||
|
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||||
|
|
||||||
|
**Keyboard controls:**
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| ----------- | -------------------------------- |
|
||||||
|
| `→` (right) | End the current episode early |
|
||||||
|
| `←` (left) | Discard episode and re-record it |
|
||||||
|
| `ESC` | Stop the recording session |
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||||
|
| `--dataset.num_episodes` | Number of episodes to record |
|
||||||
|
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||||
|
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||||
|
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||||
|
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||||
|
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Inference Backends
|
## Inference Backends
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
|||||||
# Utilities
|
# Utilities
|
||||||
########################################################################################
|
########################################################################################
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import copy
|
from copy import copy
|
||||||
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################
|
||||||
|
# Teleoperator smooth handover helpers
|
||||||
|
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||||
|
# being a root module.
|
||||||
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def teleop_supports_feedback(teleop) -> bool:
|
||||||
|
"""Return True when the teleop can receive position feedback (is actuated).
|
||||||
|
|
||||||
|
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||||
|
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||||
|
|
||||||
|
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
bool(teleop.feedback_features)
|
||||||
|
and hasattr(teleop, "disable_torque")
|
||||||
|
and hasattr(teleop, "enable_torque")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||||
|
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||||
|
|
||||||
|
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||||
|
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||||
|
|
||||||
|
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||||
|
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||||
|
the robot action key space directly.
|
||||||
|
|
||||||
|
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||||
|
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||||
|
"""
|
||||||
|
teleop.enable_torque()
|
||||||
|
current = teleop.get_action()
|
||||||
|
steps = max(int(duration_s * fps), 1)
|
||||||
|
|
||||||
|
for step in range(steps + 1):
|
||||||
|
t = step / steps
|
||||||
|
interp = {
|
||||||
|
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||||
|
}
|
||||||
|
teleop.send_feedback(interp)
|
||||||
|
time.sleep(1 / fps)
|
||||||
|
|
||||||
|
|
||||||
|
def follower_smooth_move_to(
|
||||||
|
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||||
|
) -> None:
|
||||||
|
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||||
|
|
||||||
|
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||||
|
the follower, the follower is brought to the teleop's current pose so the
|
||||||
|
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||||
|
|
||||||
|
Both ``current`` and ``target`` must be in the robot action key space
|
||||||
|
(i.e. the output of ``robot_action_processor``).
|
||||||
|
"""
|
||||||
|
steps = max(int(duration_s * fps), 1)
|
||||||
|
|
||||||
|
for step in range(steps + 1):
|
||||||
|
t = step / steps
|
||||||
|
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||||
|
robot.send_action(interp)
|
||||||
|
time.sleep(1 / fps)
|
||||||
|
|||||||
@@ -17,12 +17,10 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from PIL import Image
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
@@ -55,12 +53,13 @@ class VLAJEPAModel(nn.Module):
|
|||||||
- DiT-B: flow-matching action head for future action prediction
|
- DiT-B: flow-matching action head for future action prediction
|
||||||
- V-JEPA: world model for video frame prediction
|
- V-JEPA: world model for video frame prediction
|
||||||
|
|
||||||
Input: List[dict] native format (same as original starVLA)
|
Inputs are batched tensors kept on the model device
|
||||||
- "image": List[PIL.Image] (multi-view images)
|
- images: List[List[Tensor [C, H, W]]] (float [0,1]) — per sample, per view (Qwen messages)
|
||||||
- "video": np.ndarray [V, T, H, W, 3]
|
- instructions: List[str]
|
||||||
- "lang": str (task instruction)
|
- videos: Tensor [B, V, T, C, H, W] (float [0,1], world model only)
|
||||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
- actions: Tensor [B, T, action_dim] (optional, training only)
|
||||||
- "state": np.ndarray [1, state_dim] (optional)
|
- state: Tensor [B, 1, state_dim] (optional)
|
||||||
|
- action_is_pad: Tensor [B, T] (optional)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||||
@@ -161,166 +160,123 @@ class VLAJEPAModel(nn.Module):
|
|||||||
|
|
||||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||||
|
|
||||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
def _encode_qwen(
|
||||||
"""
|
self, images: list[list[Tensor]], instructions: list[str], *, need_action_tokens: bool
|
||||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
) -> tuple[Tensor, Tensor, Tensor | None]:
|
||||||
|
"""Run Qwen and gather the embodied-action (and optionally action) token hidden states."""
|
||||||
Args:
|
|
||||||
examples: List of per-sample dicts with keys:
|
|
||||||
"image" : List[PIL.Image] — multi-view images
|
|
||||||
"video" : np.ndarray [V, T, H, W, 3]
|
|
||||||
"lang" : str — task instruction
|
|
||||||
"action" : np.ndarray [T, action_dim] (optional)
|
|
||||||
"state" : np.ndarray [1, state_dim] (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
|
||||||
"""
|
|
||||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
|
||||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
|
||||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
|
||||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
|
||||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
|
||||||
actions = [ex["action"] for ex in examples] if has_action else None
|
|
||||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
|
||||||
state = [ex["state"] for ex in examples] if has_state else None
|
|
||||||
action_is_pad = (
|
|
||||||
[ex["action_is_pad"] for ex in examples]
|
|
||||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
|
||||||
batch_videos = np.stack(batch_videos)
|
|
||||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
|
||||||
|
|
||||||
# Adjust number of views for the world model:
|
|
||||||
# - fewer views than expected: duplicate the first view to fill up
|
|
||||||
# - more views than expected: keep only the first num_views_world_model views
|
|
||||||
num_views_world_model = self.config.jepa_tubelet_size
|
|
||||||
if batch_videos.shape[1] < num_views_world_model:
|
|
||||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
|
||||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
|
||||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
|
||||||
elif batch_videos.shape[1] > num_views_world_model:
|
|
||||||
batch_videos = batch_videos[:, :num_views_world_model]
|
|
||||||
|
|
||||||
# ---- Step 1: QwenVL encode (same as original) ----
|
|
||||||
qwen_inputs = self.qwen.build_inputs(
|
qwen_inputs = self.qwen.build_inputs(
|
||||||
images=batch_images,
|
images=images,
|
||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
action_prompt=self.replace_prompt,
|
action_prompt=self.replace_prompt,
|
||||||
embodied_prompt=self.embodied_replace_prompt,
|
embodied_prompt=self.embodied_replace_prompt,
|
||||||
)
|
)
|
||||||
|
input_ids = qwen_inputs["input_ids"]
|
||||||
# Locate embodied-action tokens (always needed for action head)
|
embodied_idx = (input_ids == self.embodied_action_token_id).nonzero(as_tuple=True)
|
||||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
action_idx = None
|
||||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
if need_action_tokens:
|
||||||
|
action_mask = torch.isin(input_ids, torch.tensor(self.action_token_ids, device=input_ids.device))
|
||||||
# Locate action tokens (only needed for world model predictor)
|
action_idx = action_mask.nonzero(as_tuple=True)
|
||||||
if self.config.enable_world_model:
|
|
||||||
action_mask = torch.isin(
|
|
||||||
qwen_inputs["input_ids"],
|
|
||||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
|
||||||
)
|
|
||||||
action_indices = action_mask.nonzero(as_tuple=True)
|
|
||||||
|
|
||||||
device_type = next(self.parameters()).device.type
|
device_type = next(self.parameters()).device.type
|
||||||
|
|
||||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||||
b, _, h = last_hidden.shape
|
b, _, h = last_hidden.shape
|
||||||
|
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
|
||||||
|
action_tokens = (
|
||||||
|
last_hidden[action_idx[0], action_idx[1], :].view(b, -1, h)
|
||||||
|
if action_idx is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return last_hidden, embodied_action_tokens, action_tokens
|
||||||
|
|
||||||
if self.config.enable_world_model:
|
def _world_model_loss(self, videos: Tensor, action_tokens: Tensor) -> Tensor:
|
||||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
"""JEPA encode + predictor L1 loss. `videos` is [B, V, T, C, H, W] float in [0, 1]."""
|
||||||
|
# Match the world model's expected view count: pad with the first view, or trim extras.
|
||||||
|
num_views = self.config.jepa_tubelet_size
|
||||||
|
if videos.shape[1] < num_views:
|
||||||
|
missing = num_views - videos.shape[1]
|
||||||
|
videos = torch.cat([videos, videos[:, :1].repeat(1, missing, 1, 1, 1, 1)], dim=1)
|
||||||
|
elif videos.shape[1] > num_views:
|
||||||
|
videos = videos[:, :num_views]
|
||||||
|
|
||||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
b, v, t_frames, c, h_img, w_img = videos.shape
|
||||||
|
flat = videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||||
|
# Fast (torchvision) video processor on-device, do_rescale=False (frames already in [0, 1]).
|
||||||
|
video_pixels = self.video_processor(
|
||||||
|
videos=list(flat),
|
||||||
|
return_tensors="pt",
|
||||||
|
device=self.video_encoder.device,
|
||||||
|
do_rescale=False,
|
||||||
|
)["pixel_values_videos"] # [B*V, T, C, H, W]
|
||||||
|
|
||||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
with torch.no_grad():
|
||||||
device_wm = last_hidden.device
|
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||||
if not self.config.enable_world_model:
|
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||||
|
|
||||||
|
tubelet_size = self.video_encoder.config.tubelet_size
|
||||||
|
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||||
|
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||||
|
if t_enc_total < 2:
|
||||||
|
return torch.zeros((), device=video_embeddings.device)
|
||||||
|
|
||||||
|
# Shift-by-one JEPA split: input_states = positions 0..T-2, gt_states = positions 1..T-1
|
||||||
|
t_enc_ctx = t_enc_total - 1
|
||||||
|
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||||
|
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||||
|
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||||
|
|
||||||
|
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||||
|
if action_tokens.shape[1] < expected_actions:
|
||||||
|
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||||
|
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||||
|
|
||||||
|
predicted_states = self.video_predictor(
|
||||||
|
input_states.float(), action_tokens[:, :expected_actions].float()
|
||||||
|
)
|
||||||
|
return F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||||
|
|
||||||
|
def _action_loss(
|
||||||
|
self,
|
||||||
|
embodied_action_tokens: Tensor,
|
||||||
|
actions: Tensor,
|
||||||
|
state: Tensor | None,
|
||||||
|
action_is_pad: Tensor | None,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
|
||||||
|
device_type = next(self.parameters()).device.type
|
||||||
|
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||||
|
r = self.config.repeated_diffusion_steps
|
||||||
|
horizon = self.config.chunk_size
|
||||||
|
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
|
||||||
|
embodied = embodied_action_tokens.repeat(r, 1, 1)
|
||||||
|
state_rep = state.to(embodied_action_tokens.dtype).repeat(r, 1, 1) if state is not None else None
|
||||||
|
pad_rep = action_is_pad[:, -horizon:].repeat(r, 1) if action_is_pad is not None else None
|
||||||
|
return self.action_model(embodied, actions_target, state_rep, pad_rep)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
images: list[list[Tensor]],
|
||||||
|
instructions: list[str],
|
||||||
|
videos: Tensor | None = None,
|
||||||
|
actions: Tensor | None = None,
|
||||||
|
state: Tensor | None = None,
|
||||||
|
action_is_pad: Tensor | None = None,
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
|
"""Native forward: Qwen encode → optional world-model loss → optional action-head loss."""
|
||||||
|
last_hidden, embodied_action_tokens, action_tokens = self._encode_qwen(
|
||||||
|
images, instructions, need_action_tokens=self.config.enable_world_model
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.config.enable_world_model:
|
||||||
|
wm_loss = self._world_model_loss(videos, action_tokens)
|
||||||
else:
|
else:
|
||||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
wm_loss = torch.zeros((), device=last_hidden.device)
|
||||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
|
||||||
|
|
||||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
if actions is None:
|
||||||
"pixel_values_videos"
|
|
||||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
|
||||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
|
||||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
|
||||||
|
|
||||||
tubelet_size = self.video_encoder.config.tubelet_size
|
|
||||||
device_wm = video_embeddings.device
|
|
||||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
|
||||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
|
||||||
|
|
||||||
if t_enc_total < 2:
|
|
||||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
|
||||||
else:
|
|
||||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
|
||||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
|
||||||
t_enc_ctx = t_enc_total - 1
|
|
||||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
|
||||||
|
|
||||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
|
||||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
|
||||||
|
|
||||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
|
||||||
if action_tokens.shape[1] < expected_actions:
|
|
||||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
|
||||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
|
||||||
|
|
||||||
predicted_states = self.video_predictor(
|
|
||||||
input_states.float(),
|
|
||||||
action_tokens[:, :expected_actions].float(),
|
|
||||||
)
|
|
||||||
|
|
||||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
|
||||||
|
|
||||||
if not has_action:
|
|
||||||
return {"wm_loss": wm_loss}
|
return {"wm_loss": wm_loss}
|
||||||
|
|
||||||
# ---- Step 4: Action Head ----
|
action_loss = self._action_loss(embodied_action_tokens, actions, state, action_is_pad)
|
||||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
|
||||||
actions_tensor = torch.tensor(
|
|
||||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
|
||||||
) # [B, T_full, action_dim]
|
|
||||||
action_horizon = self.config.chunk_size
|
|
||||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
|
||||||
|
|
||||||
state_tensor = None
|
|
||||||
if state is not None:
|
|
||||||
state_tensor = torch.tensor(
|
|
||||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
|
||||||
) # [B, 1, state_dim]
|
|
||||||
|
|
||||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
|
||||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
|
||||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
|
||||||
if state_tensor is not None:
|
|
||||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
|
||||||
|
|
||||||
action_is_pad_rep = None
|
|
||||||
if action_is_pad is not None:
|
|
||||||
pad_tensor = torch.stack(
|
|
||||||
[
|
|
||||||
p.to(actions_target.device)
|
|
||||||
if isinstance(p, Tensor)
|
|
||||||
else torch.tensor(p, device=actions_target.device)
|
|
||||||
for p in action_is_pad
|
|
||||||
]
|
|
||||||
) # [B, T_full]
|
|
||||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
|
||||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
|
||||||
|
|
||||||
action_loss = self.action_model(
|
|
||||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||||
|
|
||||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||||
@@ -328,58 +284,24 @@ class VLAJEPAModel(nn.Module):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action(
|
def predict_action(
|
||||||
self,
|
self,
|
||||||
batch_images: list[list[Image.Image]],
|
images: list[list[Tensor]],
|
||||||
instructions: list[str],
|
instructions: list[str],
|
||||||
state: np.ndarray | None = None,
|
state: Tensor | None = None,
|
||||||
) -> np.ndarray:
|
) -> Tensor:
|
||||||
"""
|
"""Predict an action chunk. `images` is per-sample, per-view float [0,1] [C, H, W] tensors."""
|
||||||
Native action prediction following original VLA_JEPA.predict_action.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
|
||||||
instructions: Task instructions, one per sample.
|
|
||||||
state: Optional [B, state_dim] numpy array.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
|
||||||
"""
|
|
||||||
if self.config.resize_images_to is not None:
|
if self.config.resize_images_to is not None:
|
||||||
height, width = self.config.resize_images_to
|
height, width = self.config.resize_images_to
|
||||||
resampling = getattr(Image, "Resampling", Image).BOX
|
images = [
|
||||||
batch_images = [
|
[F.interpolate(img[None], size=(height, width), mode="area")[0] for img in views]
|
||||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
for views in images
|
||||||
for sample_images in batch_images
|
|
||||||
]
|
]
|
||||||
|
|
||||||
qwen_inputs = self.qwen.build_inputs(
|
_, embodied_action_tokens, _ = self._encode_qwen(images, instructions, need_action_tokens=False)
|
||||||
images=batch_images,
|
state = state.to(embodied_action_tokens.dtype) if state is not None else None
|
||||||
instructions=instructions,
|
return self.action_model.predict_action(
|
||||||
action_prompt=self.replace_prompt,
|
embodied_action_tokens.float(), state.float() if state is not None else None
|
||||||
embodied_prompt=self.embodied_replace_prompt,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
|
||||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
|
||||||
|
|
||||||
device_type = next(self.parameters()).device.type
|
|
||||||
|
|
||||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
|
||||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
|
||||||
b, _, h = last_hidden.shape
|
|
||||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
|
||||||
|
|
||||||
state_tensor = None
|
|
||||||
if state is not None:
|
|
||||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
|
||||||
device=last_hidden.device, dtype=last_hidden.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
pred_actions = self.action_model.predict_action(
|
|
||||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
|
||||||
) # [B, action_horizon, action_dim]
|
|
||||||
|
|
||||||
return pred_actions.detach().cpu().numpy()
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||||
@@ -390,9 +312,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
"""
|
"""
|
||||||
LeRobot adapter for VLA-JEPA.
|
LeRobot adapter for VLA-JEPA.
|
||||||
|
|
||||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
Converts LeRobot's standard batch format (dict[str, Tensor]) to the batched tensors
|
||||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
the native model expects (keeping everything on-device), calls the native model, and
|
||||||
back to LeRobot format.
|
converts outputs back to LeRobot format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = VLAJEPAConfig
|
config_class = VLAJEPAConfig
|
||||||
@@ -419,9 +341,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# ---- Format Conversion: LeRobot → Native ----
|
# ---- Format Conversion: LeRobot → Native ----
|
||||||
|
|
||||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Any]:
|
||||||
"""
|
"""Convert a LeRobot batch to the model's batched, on-device inputs.
|
||||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
|
||||||
|
|
||||||
LeRobot format:
|
LeRobot format:
|
||||||
batch = {
|
batch = {
|
||||||
@@ -431,65 +352,25 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
"task": str | List[str], (optional instruction)
|
"task": str | List[str], (optional instruction)
|
||||||
}
|
}
|
||||||
|
|
||||||
Native format (List[dict]):
|
Returns the kwargs for `VLAJEPAModel.forward` / `.predict_action` (everything stays
|
||||||
{
|
on the batch device; no per-sample shredding): `images` (per-sample, per-view list for
|
||||||
"image": List[PIL.Image], # multi-view images per sample
|
Qwen messages), `instructions`, and the batched `videos` / `actions` / `state` /
|
||||||
"video": np.ndarray [V, T, H, W, 3],
|
`action_is_pad` when present.
|
||||||
"lang": str, # task instruction
|
|
||||||
"action": np.ndarray [T, action_dim], # optional
|
|
||||||
"state": np.ndarray [1, state_dim], # optional
|
|
||||||
}
|
|
||||||
"""
|
"""
|
||||||
# Determine batch size from the first image feature
|
|
||||||
image_keys = list(self.config.image_features.keys())
|
image_keys = list(self.config.image_features.keys())
|
||||||
if not image_keys:
|
if not image_keys:
|
||||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||||
first_key = image_keys[0]
|
batch_size = batch[image_keys[0]].shape[0]
|
||||||
first_tensor = batch[first_key]
|
|
||||||
batch_size = first_tensor.shape[0]
|
|
||||||
|
|
||||||
# ---- Collect images per sample ----
|
# Current-frame image per view ([B, C, H, W]); regroup per sample for Qwen messages.
|
||||||
# images_per_sample[b][v] = PIL.Image for view v
|
frames = []
|
||||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
t = batch[key]
|
||||||
if tensor.ndim == 5:
|
if t.ndim == 5: # [B, T, C, H, W] -> current observation (delta=0)
|
||||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
t = t[:, 0]
|
||||||
# index 0 is the current observation (delta=0)
|
frames.append(self.model.qwen.to_pixel_values(t))
|
||||||
tensor = tensor[:, 0]
|
images = [[frame[b] for frame in frames] for b in range(batch_size)]
|
||||||
for b in range(batch_size):
|
|
||||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
|
||||||
|
|
||||||
# ---- Collect videos per sample ----
|
|
||||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
|
||||||
# Check whether any image feature has a time dimension
|
|
||||||
video_source = None
|
|
||||||
for k in image_keys:
|
|
||||||
if k in batch:
|
|
||||||
video_source = batch[k] # Use first available for shape inspection
|
|
||||||
break
|
|
||||||
|
|
||||||
if video_source is None:
|
|
||||||
raise ValueError("No image data found in batch for video construction.")
|
|
||||||
|
|
||||||
videos_per_sample = []
|
|
||||||
for b in range(batch_size):
|
|
||||||
sample_views = []
|
|
||||||
for k in image_keys:
|
|
||||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
|
||||||
if t.ndim == 3:
|
|
||||||
t = t.unsqueeze(0) # [1, C, H, W]
|
|
||||||
# Convert to [T, H, W, 3] numpy
|
|
||||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
|
||||||
# Clamp to [0, 255]
|
|
||||||
if t_np.max() <= 1.0:
|
|
||||||
t_np = t_np * 255.0
|
|
||||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
|
||||||
sample_views.append(t_np)
|
|
||||||
# Stack views: [V, T, H, W, 3]
|
|
||||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
|
||||||
|
|
||||||
# ---- Collect instructions ----
|
|
||||||
tasks = batch.get("task")
|
tasks = batch.get("task")
|
||||||
if tasks is None:
|
if tasks is None:
|
||||||
instructions = ["Execute the robot action."] * batch_size
|
instructions = ["Execute the robot action."] * batch_size
|
||||||
@@ -498,52 +379,32 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
else:
|
else:
|
||||||
instructions = list(tasks)
|
instructions = list(tasks)
|
||||||
|
|
||||||
# ---- Collect actions (training only) ----
|
inputs: dict[str, Any] = {"images": images, "instructions": instructions}
|
||||||
actions_list = None
|
|
||||||
action_is_pad_list = None
|
|
||||||
actions_tensor = batch.get(ACTION)
|
|
||||||
if actions_tensor is not None:
|
|
||||||
if actions_tensor.ndim == 2:
|
|
||||||
actions_tensor = actions_tensor.unsqueeze(1)
|
|
||||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
|
||||||
action_is_pad_tensor = batch.get("action_is_pad")
|
|
||||||
if action_is_pad_tensor is not None:
|
|
||||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
|
||||||
|
|
||||||
# ---- Collect state ----
|
# Videos [B, V, T, C, H, W] - only assembled when the world model consumes them.
|
||||||
state_list = None
|
if self.model.config.enable_world_model:
|
||||||
state_tensor = batch.get(OBS_STATE)
|
views = [batch[k].unsqueeze(1) if batch[k].ndim == 4 else batch[k] for k in image_keys]
|
||||||
if state_tensor is not None:
|
inputs["videos"] = self.model.qwen.to_pixel_values(torch.stack(views, dim=1))
|
||||||
if state_tensor.ndim > 2:
|
|
||||||
state_tensor = state_tensor[:, -1, :]
|
|
||||||
if state_tensor.ndim == 2:
|
|
||||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
|
||||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
|
||||||
|
|
||||||
# ---- Assemble native examples ----
|
actions = batch.get(ACTION)
|
||||||
examples = []
|
if actions is not None:
|
||||||
for b in range(batch_size):
|
inputs["actions"] = (actions.unsqueeze(1) if actions.ndim == 2 else actions).float()
|
||||||
example = {
|
if (pad := batch.get("action_is_pad")) is not None:
|
||||||
"image": images_per_sample[b],
|
inputs["action_is_pad"] = pad
|
||||||
"video": videos_per_sample[b],
|
|
||||||
"lang": instructions[b],
|
|
||||||
}
|
|
||||||
if actions_list is not None:
|
|
||||||
example["action"] = actions_list[b]
|
|
||||||
if action_is_pad_list is not None:
|
|
||||||
example["action_is_pad"] = action_is_pad_list[b]
|
|
||||||
if state_list is not None:
|
|
||||||
example["state"] = state_list[b]
|
|
||||||
examples.append(example)
|
|
||||||
|
|
||||||
return examples
|
state = batch.get(OBS_STATE)
|
||||||
|
if state is not None:
|
||||||
|
if state.ndim > 2:
|
||||||
|
state = state[:, -1, :]
|
||||||
|
inputs["state"] = (state.unsqueeze(1) if state.ndim == 2 else state).float() # [B, 1, dim]
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
# ---- LeRobot Policy Interface ----
|
# ---- LeRobot Policy Interface ----
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||||
examples = self._prepare_model_inputs(batch)
|
native_output = self.model.forward(**self._prepare_model_inputs(batch))
|
||||||
native_output = self.model.forward(examples)
|
|
||||||
|
|
||||||
ref = next(iter(native_output.values()))
|
ref = next(iter(native_output.values()))
|
||||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||||
@@ -561,16 +422,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
self.eval()
|
self.eval()
|
||||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||||
|
|
||||||
examples = self._prepare_model_inputs(batch)
|
inputs = self._prepare_model_inputs(batch)
|
||||||
batch_images = [ex["image"] for ex in examples]
|
actions = self.model.predict_action(inputs["images"], inputs["instructions"], inputs.get("state"))
|
||||||
instructions = [ex["lang"] for ex in examples]
|
return actions.to(device=self.config.device, dtype=torch.float32)
|
||||||
|
|
||||||
state_np = None
|
|
||||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
|
||||||
state_np = np.stack([ex["state"] for ex in examples])
|
|
||||||
|
|
||||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
|
||||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
|
|||||||
@@ -17,9 +17,7 @@ from __future__ import annotations
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
@@ -78,7 +76,7 @@ class Qwen3VLInterface(torch.nn.Module):
|
|||||||
|
|
||||||
def build_inputs(
|
def build_inputs(
|
||||||
self,
|
self,
|
||||||
images: Sequence[Sequence[Image.Image]],
|
images: Sequence[Sequence[torch.Tensor]],
|
||||||
instructions: Sequence[str],
|
instructions: Sequence[str],
|
||||||
action_prompt: str,
|
action_prompt: str,
|
||||||
embodied_prompt: str,
|
embodied_prompt: str,
|
||||||
@@ -94,24 +92,42 @@ class Qwen3VLInterface(torch.nn.Module):
|
|||||||
content.append({"type": "text", "text": prompt})
|
content.append({"type": "text", "text": prompt})
|
||||||
messages.append([{"role": "user", "content": content}])
|
messages.append([{"role": "user", "content": content}])
|
||||||
|
|
||||||
|
# The Qwen image processor is a torchvision-backed fast processor: passing the
|
||||||
|
# images as GPU tensors (with `device`) keeps the whole vision pipeline on-device
|
||||||
|
# and avoids a GPU->CPU->GPU roundtrip. The image tensors are forwarded through
|
||||||
|
# apply_chat_template untouched into Qwen3VLProcessor.__call__.
|
||||||
|
# do_rescale=False: images already arrive as float in [0, 1] (the dataset decoder
|
||||||
|
# yields float32/255 and VISUAL normalization is IDENTITY), so we skip the
|
||||||
|
# processor's /255 rescale instead of round-tripping through uint8.
|
||||||
batch_inputs = self.processor.apply_chat_template(
|
batch_inputs = self.processor.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
processor_kwargs={
|
||||||
|
"padding": True,
|
||||||
|
"return_tensors": "pt",
|
||||||
|
"device": self.model.device,
|
||||||
|
"do_rescale": False,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return batch_inputs.to(self.model.device)
|
return batch_inputs.to(self.model.device)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
def to_pixel_values(image_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
image = image_tensor.detach().cpu()
|
"""Prepare an image/video tensor for the fast processors (used with do_rescale=False).
|
||||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
|
||||||
image = image.permute(1, 2, 0)
|
The dataset decoder yields float32 in [0, 1] (channels-first) and VISUAL
|
||||||
image = image.float()
|
normalization is IDENTITY, so the tensor already arrives in [0, 1]; we pass it
|
||||||
if image.max() <= 1.0:
|
through as float and let the processors normalize (no rescale, no uint8
|
||||||
image = image * 255.0
|
quantization). A single channel is expanded to 3 to match the RGB processors.
|
||||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
|
||||||
if image.shape[-1] == 1:
|
Works for any channels-first layout (channel dim is -3): [C, H, W], [B, C, H, W],
|
||||||
image = np.repeat(image, 3, axis=-1)
|
[T, C, H, W], [B, V, T, C, H, W], ...
|
||||||
return Image.fromarray(image)
|
"""
|
||||||
|
image = image_tensor.detach().float()
|
||||||
|
if image.shape[-3] == 1:
|
||||||
|
repeats = [1] * image.ndim
|
||||||
|
repeats[-3] = 3
|
||||||
|
image = image.repeat(*repeats)
|
||||||
|
return image
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
|
|
||||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||||
|
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||||
|
default=None,
|
||||||
|
init=False,
|
||||||
|
repr=False,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, data: TInput) -> TOutput:
|
def __call__(self, data: TInput) -> TOutput:
|
||||||
"""Processes input data through the full pipeline.
|
"""Processes input data through the full pipeline.
|
||||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
transition = processor_step(transition)
|
transition = processor_step(transition)
|
||||||
yield transition
|
yield transition
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
def _get_sanitized_name(self) -> str:
|
||||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
"""Return a filename-safe version of the pipeline name.
|
||||||
|
|
||||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
Returns:
|
||||||
|
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||||
"""
|
"""
|
||||||
config_filename = kwargs.pop("config_filename", None)
|
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||||
|
|
||||||
# Sanitize the pipeline name to create a valid filename prefix.
|
@staticmethod
|
||||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
def _get_state_filename(
|
||||||
|
*,
|
||||||
|
step_index: int,
|
||||||
|
registry_name: str | None,
|
||||||
|
sanitized_name: str,
|
||||||
|
) -> str:
|
||||||
|
"""Return the safetensors filename for one stateful processor step.
|
||||||
|
|
||||||
if config_filename is None:
|
Args:
|
||||||
config_filename = f"{sanitized_name}.json"
|
step_index: The index of the processor step in this pipeline.
|
||||||
|
registry_name: The registered processor step name, if available.
|
||||||
|
sanitized_name: The filename-safe pipeline name.
|
||||||
|
|
||||||
config: dict[str, Any] = {
|
Returns:
|
||||||
|
The state filename used by the existing disk serialization format.
|
||||||
|
"""
|
||||||
|
if registry_name:
|
||||||
|
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||||
|
|
||||||
|
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_state_key(state_filename: str) -> str:
|
||||||
|
"""Return the in-memory state key for a serialized state filename.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_filename: The `.safetensors` filename from the serialized config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state key used by the in-memory pipeline state dictionary.
|
||||||
|
"""
|
||||||
|
return state_filename.removesuffix(".safetensors")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||||
|
"""Return serialized state filenames in step order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loaded_config: A validated processor pipeline config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||||
|
"""
|
||||||
|
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||||
|
|
||||||
|
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||||
|
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||||
|
current non-empty step state.
|
||||||
|
"""
|
||||||
|
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||||
|
self.steps
|
||||||
|
):
|
||||||
|
return self._serialized_state_filenames
|
||||||
|
|
||||||
|
sanitized_name = self._get_sanitized_name()
|
||||||
|
state_filenames: list[str | None] = []
|
||||||
|
|
||||||
|
for step_index, processor_step in enumerate(self.steps):
|
||||||
|
step_state_dict = processor_step.state_dict()
|
||||||
|
if not step_state_dict:
|
||||||
|
state_filenames.append(None)
|
||||||
|
continue
|
||||||
|
|
||||||
|
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||||
|
state_filenames.append(
|
||||||
|
self._get_state_filename(
|
||||||
|
step_index=step_index,
|
||||||
|
registry_name=registry_name,
|
||||||
|
sanitized_name=sanitized_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return tuple(state_filenames)
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""Return the JSON-serializable pipeline configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||||
|
"""
|
||||||
|
sanitized_name = self._get_sanitized_name()
|
||||||
|
pipeline_config: dict[str, Any] = {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"steps": [],
|
"steps": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Iterate through each step to build its configuration entry.
|
|
||||||
for step_index, processor_step in enumerate(self.steps):
|
for step_index, processor_step in enumerate(self.steps):
|
||||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||||
|
|
||||||
step_entry: dict[str, Any] = {}
|
step_entry: dict[str, Any] = {}
|
||||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
|
||||||
if registry_name:
|
if registry_name:
|
||||||
step_entry["registry_name"] = registry_name
|
step_entry["registry_name"] = registry_name
|
||||||
else:
|
else:
|
||||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save step configuration if `get_config` is implemented.
|
step_entry["config"] = processor_step.get_config()
|
||||||
if hasattr(processor_step, "get_config"):
|
|
||||||
step_entry["config"] = processor_step.get_config()
|
|
||||||
|
|
||||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
step_state_dict = processor_step.state_dict()
|
||||||
if hasattr(processor_step, "state_dict"):
|
if step_state_dict:
|
||||||
state = processor_step.state_dict()
|
step_entry["state_file"] = self._get_state_filename(
|
||||||
if state:
|
step_index=step_index,
|
||||||
# Clone tensors to avoid modifying the original state.
|
registry_name=registry_name,
|
||||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
sanitized_name=sanitized_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Create a unique filename for the state file.
|
pipeline_config["steps"].append(step_entry)
|
||||||
if registry_name:
|
|
||||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
|
||||||
else:
|
|
||||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
|
||||||
|
|
||||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
return pipeline_config
|
||||||
step_entry["state_file"] = state_filename
|
|
||||||
|
|
||||||
config["steps"].append(step_entry)
|
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||||
|
"""Return pipeline state tensors grouped by state key.
|
||||||
|
|
||||||
# Write the main configuration JSON file.
|
Returns:
|
||||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||||
json.dump(config, file_pointer, indent=2)
|
"""
|
||||||
|
sanitized_name = self._get_sanitized_name()
|
||||||
|
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||||
|
|
||||||
|
for step_index, processor_step in enumerate(self.steps):
|
||||||
|
step_state_dict = processor_step.state_dict()
|
||||||
|
if not step_state_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||||
|
state_filename = self._get_state_filename(
|
||||||
|
step_index=step_index,
|
||||||
|
registry_name=registry_name,
|
||||||
|
sanitized_name=sanitized_name,
|
||||||
|
)
|
||||||
|
state_key = self._get_state_key(state_filename)
|
||||||
|
pipeline_state_dict[state_key] = {
|
||||||
|
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return pipeline_state_dict
|
||||||
|
|
||||||
|
def load_state_dict(
|
||||||
|
self,
|
||||||
|
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||||
|
) -> None:
|
||||||
|
"""Load pipeline state tensors into the existing steps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||||
|
"""
|
||||||
|
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||||
|
used_state_keys: set[str] = set()
|
||||||
|
|
||||||
|
for step_index, (processor_step, state_filename) in enumerate(
|
||||||
|
zip(self.steps, expected_state_filenames, strict=True)
|
||||||
|
):
|
||||||
|
if state_filename is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
state_key = self._get_state_key(state_filename)
|
||||||
|
if state_key not in state_dict:
|
||||||
|
raise KeyError(
|
||||||
|
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||||
|
f"Available state keys: {sorted(state_dict.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
processor_step.load_state_dict(state_dict[state_key])
|
||||||
|
used_state_keys.add(state_key)
|
||||||
|
|
||||||
|
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||||
|
if unexpected_state_keys:
|
||||||
|
expected_state_key_set = {
|
||||||
|
self._get_state_key(state_filename)
|
||||||
|
for state_filename in expected_state_filenames
|
||||||
|
if state_filename is not None
|
||||||
|
}
|
||||||
|
raise KeyError(
|
||||||
|
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||||
|
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||||
|
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||||
|
|
||||||
|
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||||
|
"""
|
||||||
|
config_filename = kwargs.pop("config_filename", None)
|
||||||
|
sanitized_name = self._get_sanitized_name()
|
||||||
|
|
||||||
|
if config_filename is None:
|
||||||
|
config_filename = f"{sanitized_name}.json"
|
||||||
|
|
||||||
|
pipeline_config = self.get_config()
|
||||||
|
pipeline_state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||||
|
state_filename = f"{state_key}.safetensors"
|
||||||
|
save_file(step_state_dict, save_directory / state_filename)
|
||||||
|
|
||||||
|
with open(save_directory / config_filename, "w") as file_pointer:
|
||||||
|
json.dump(pipeline_config, file_pointer, indent=2)
|
||||||
|
|
||||||
def save_pretrained(
|
def save_pretrained(
|
||||||
self,
|
self,
|
||||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||||
|
|
||||||
# 5. Construct and return the final pipeline instance
|
# 5. Construct and return the final pipeline instance
|
||||||
return cls(
|
pipeline = cls(
|
||||||
steps=steps,
|
steps=steps,
|
||||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||||
)
|
)
|
||||||
|
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: dict[str, Any],
|
||||||
|
*,
|
||||||
|
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
|
overrides: dict[str, Any] | None = None,
|
||||||
|
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||||
|
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||||
|
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||||
|
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: A config dictionary with the same structure as the saved processor JSON.
|
||||||
|
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||||
|
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||||
|
to_transition: Optional converter from input data to `EnvTransition`.
|
||||||
|
to_output: Optional converter from `EnvTransition` to output data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A processor pipeline built from the config and optional state.
|
||||||
|
"""
|
||||||
|
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||||
|
|
||||||
|
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||||
|
cls._validate_overrides_used(remaining_override_keys, config)
|
||||||
|
|
||||||
|
pipeline = cls(
|
||||||
|
steps=steps,
|
||||||
|
name=config.get("name", "DataProcessorPipeline"),
|
||||||
|
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||||
|
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||||
|
)
|
||||||
|
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||||
|
|
||||||
|
if state_dict is not None:
|
||||||
|
pipeline.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_config(
|
def _load_config(
|
||||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_loaded_config(
|
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
|
||||||
) -> None:
|
|
||||||
"""Validate that a config was loaded and is a valid processor config.
|
"""Validate that a config was loaded and is a valid processor config.
|
||||||
|
|
||||||
This method validates processor config format with intelligent migration detection:
|
This method validates processor config format with intelligent migration detection:
|
||||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id: The model identifier (used for migration detection)
|
model_id: The model identifier (used for migration detection)
|
||||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
loaded_config: The loaded config value to validate (may be non-dict)
|
||||||
config_filename: The config filename that was loaded (for error messages)
|
config_filename: The config filename that was loaded (for error messages)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
model_id,
|
model_id,
|
||||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||||
)
|
)
|
||||||
|
loaded_config_description = (
|
||||||
|
list(loaded_config.keys())
|
||||||
|
if isinstance(loaded_config, dict)
|
||||||
|
else type(loaded_config).__name__
|
||||||
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
ImportError: If a step class cannot be imported or found in registry
|
ImportError: If a step class cannot be imported or found in registry
|
||||||
ValueError: If a step cannot be instantiated with its configuration
|
ValueError: If a step cannot be instantiated with its configuration
|
||||||
"""
|
"""
|
||||||
steps: list[ProcessorStep] = []
|
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||||
override_keys = set(overrides.keys())
|
|
||||||
|
|
||||||
for step_entry in loaded_config["steps"]:
|
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||||
# 1. Get step class and key
|
|
||||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
|
||||||
|
|
||||||
# 2. Instantiate step with overrides
|
|
||||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
|
||||||
|
|
||||||
# 3. Load step state if available
|
|
||||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||||
|
|
||||||
# 4. Track used overrides
|
return steps, remaining_override_keys
|
||||||
if step_key in override_keys:
|
|
||||||
override_keys.discard(step_key)
|
|
||||||
|
|
||||||
steps.append(step_instance)
|
@classmethod
|
||||||
|
def _build_steps_from_config(
|
||||||
|
cls,
|
||||||
|
loaded_config: dict[str, Any],
|
||||||
|
overrides: dict[str, Any],
|
||||||
|
) -> tuple[list[ProcessorStep], set[str]]:
|
||||||
|
"""Build processor steps from config without loading tensor state.
|
||||||
|
|
||||||
return steps, override_keys
|
Args:
|
||||||
|
loaded_config: The loaded processor configuration.
|
||||||
|
overrides: User-provided constructor overrides keyed by step key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing instantiated steps and override keys that did not match a step.
|
||||||
|
"""
|
||||||
|
processor_steps: list[ProcessorStep] = []
|
||||||
|
remaining_override_keys = set(overrides.keys())
|
||||||
|
|
||||||
|
for step_entry in loaded_config["steps"]:
|
||||||
|
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||||
|
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||||
|
|
||||||
|
if step_key in remaining_override_keys:
|
||||||
|
remaining_override_keys.discard(step_key)
|
||||||
|
|
||||||
|
processor_steps.append(processor_step)
|
||||||
|
|
||||||
|
return processor_steps, remaining_override_keys
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _is_processor_config(cls, config: dict) -> bool:
|
def _is_processor_config(cls, config: Any) -> bool:
|
||||||
"""Check if config follows DataProcessorPipeline format.
|
"""Check if config follows DataProcessorPipeline format.
|
||||||
|
|
||||||
This method validates the processor configuration structure:
|
This method validates the processor configuration structure:
|
||||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
Returns:
|
Returns:
|
||||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
# Must have a "steps" field with a list of step configurations
|
# Must have a "steps" field with a list of step configurations
|
||||||
if not isinstance(config.get("steps"), list):
|
if not isinstance(config.get("steps"), list):
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from .configs import (
|
|||||||
DAggerKeyboardConfig,
|
DAggerKeyboardConfig,
|
||||||
DAggerPedalConfig,
|
DAggerPedalConfig,
|
||||||
DAggerStrategyConfig,
|
DAggerStrategyConfig,
|
||||||
|
EpisodicStrategyConfig,
|
||||||
HighlightStrategyConfig,
|
HighlightStrategyConfig,
|
||||||
RolloutConfig,
|
RolloutConfig,
|
||||||
RolloutStrategyConfig,
|
RolloutStrategyConfig,
|
||||||
@@ -49,6 +50,7 @@ from .inference import (
|
|||||||
from .strategies import (
|
from .strategies import (
|
||||||
BaseStrategy,
|
BaseStrategy,
|
||||||
DAggerStrategy,
|
DAggerStrategy,
|
||||||
|
EpisodicStrategy,
|
||||||
HighlightStrategy,
|
HighlightStrategy,
|
||||||
RolloutStrategy,
|
RolloutStrategy,
|
||||||
SentryStrategy,
|
SentryStrategy,
|
||||||
@@ -66,6 +68,8 @@ __all__ = [
|
|||||||
"HardwareContext",
|
"HardwareContext",
|
||||||
"HighlightStrategy",
|
"HighlightStrategy",
|
||||||
"HighlightStrategyConfig",
|
"HighlightStrategyConfig",
|
||||||
|
"EpisodicStrategy",
|
||||||
|
"EpisodicStrategyConfig",
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"InferenceEngineConfig",
|
"InferenceEngineConfig",
|
||||||
"PolicyContext",
|
"PolicyContext",
|
||||||
|
|||||||
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
|
|||||||
upload: str = "KEY_C"
|
upload: str = "KEY_C"
|
||||||
|
|
||||||
|
|
||||||
|
@RolloutStrategyConfig.register_subclass("episodic")
|
||||||
|
@dataclass
|
||||||
|
class EpisodicStrategyConfig(RolloutStrategyConfig):
|
||||||
|
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
|
||||||
|
|
||||||
|
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
|
||||||
|
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
|
||||||
|
|
||||||
|
Keyboard controls:
|
||||||
|
Right arrow — end current episode or reset phase early
|
||||||
|
Left arrow — discard current episode and re-record
|
||||||
|
Escape — stop recording session
|
||||||
|
|
||||||
|
In between episodes:
|
||||||
|
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
|
||||||
|
- else, the robot is moved smoothly to the position of the teleop leader.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This only applies if there are no teleop leaders specified.
|
||||||
|
# When True (default), moves the robot back to the joint positions captured at startup.
|
||||||
|
# Otherwise, leave the robot in its current position.
|
||||||
|
reset_to_initial_position: bool = True
|
||||||
|
|
||||||
|
# Whether to turn on or off the leader -> follower smooth handover behavior.
|
||||||
|
# When False, fallback to follower -> leader handover.
|
||||||
|
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
|
||||||
|
smooth_leader_to_follower_handover: bool = True
|
||||||
|
|
||||||
|
|
||||||
@RolloutStrategyConfig.register_subclass("dagger")
|
@RolloutStrategyConfig.register_subclass("dagger")
|
||||||
@dataclass
|
@dataclass
|
||||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||||
@@ -229,7 +258,13 @@ class RolloutConfig:
|
|||||||
|
|
||||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||||
needs_dataset = isinstance(
|
needs_dataset = isinstance(
|
||||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
self.strategy,
|
||||||
|
(
|
||||||
|
SentryStrategyConfig,
|
||||||
|
HighlightStrategyConfig,
|
||||||
|
DAggerStrategyConfig,
|
||||||
|
EpisodicStrategyConfig,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
from .base import BaseStrategy
|
from .base import BaseStrategy
|
||||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||||
|
from .episodic import EpisodicStrategy
|
||||||
from .factory import create_strategy
|
from .factory import create_strategy
|
||||||
from .highlight import HighlightStrategy
|
from .highlight import HighlightStrategy
|
||||||
from .sentry import SentryStrategy
|
from .sentry import SentryStrategy
|
||||||
@@ -27,6 +28,7 @@ __all__ = [
|
|||||||
"DAggerPhase",
|
"DAggerPhase",
|
||||||
"DAggerStrategy",
|
"DAggerStrategy",
|
||||||
"HighlightStrategy",
|
"HighlightStrategy",
|
||||||
|
"EpisodicStrategy",
|
||||||
"RolloutStrategy",
|
"RolloutStrategy",
|
||||||
"SentryStrategy",
|
"SentryStrategy",
|
||||||
"create_strategy",
|
"create_strategy",
|
||||||
|
|||||||
@@ -56,10 +56,14 @@ from typing import Any
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.common.control_utils import is_headless
|
from lerobot.common.control_utils import (
|
||||||
|
follower_smooth_move_to,
|
||||||
|
is_headless,
|
||||||
|
teleop_smooth_move_to,
|
||||||
|
teleop_supports_feedback,
|
||||||
|
)
|
||||||
from lerobot.datasets import VideoEncodingManager
|
from lerobot.datasets import VideoEncodingManager
|
||||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||||
from lerobot.teleoperators import Teleoperator
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
from lerobot.utils.feature_utils import build_dataset_frame
|
from lerobot.utils.feature_utils import build_dataset_frame
|
||||||
from lerobot.utils.import_utils import _pynput_available
|
from lerobot.utils.import_utils import _pynput_available
|
||||||
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
|
|||||||
|
|
||||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||||
from ..context import RolloutContext
|
from ..context import RolloutContext
|
||||||
from ..robot_wrapper import ThreadSafeRobot
|
|
||||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||||
|
|
||||||
PYNPUT_AVAILABLE = _pynput_available
|
PYNPUT_AVAILABLE = _pynput_available
|
||||||
@@ -171,64 +174,6 @@ class DAggerEvents:
|
|||||||
self.upload_requested.clear()
|
self.upload_requested.clear()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Teleoperator helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
|
||||||
"""Return True when the teleop can receive position feedback (is actuated).
|
|
||||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
bool(teleop.feedback_features)
|
|
||||||
and hasattr(teleop, "disable_torque")
|
|
||||||
and hasattr(teleop, "enable_torque")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _teleop_smooth_move_to(
|
|
||||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
|
||||||
) -> None:
|
|
||||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
|
||||||
|
|
||||||
Requires the teleoperator to support feedback
|
|
||||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
|
||||||
|
|
||||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
|
||||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
|
||||||
"""
|
|
||||||
teleop.enable_torque()
|
|
||||||
current = teleop.get_action()
|
|
||||||
steps = max(int(duration_s * fps), 1)
|
|
||||||
|
|
||||||
for step in range(steps + 1):
|
|
||||||
t = step / steps
|
|
||||||
interp = {
|
|
||||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
|
||||||
}
|
|
||||||
teleop.send_feedback(interp)
|
|
||||||
time.sleep(1 / fps)
|
|
||||||
|
|
||||||
|
|
||||||
def _follower_smooth_move_to(
|
|
||||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
|
||||||
) -> None:
|
|
||||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
|
||||||
|
|
||||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
|
||||||
to the follower, we bring the follower to the teleop's current pose.
|
|
||||||
Both ``current`` and ``target`` must be in robot-action key space.
|
|
||||||
"""
|
|
||||||
steps = max(int(duration_s * fps), 1)
|
|
||||||
|
|
||||||
for step in range(steps + 1):
|
|
||||||
t = step / steps
|
|
||||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
|
||||||
robot.send_action(interp)
|
|
||||||
time.sleep(1 / fps)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Input device handlers
|
# Input device handlers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
logger.info("Pausing engine - robot holds position")
|
logger.info("Pausing engine - robot holds position")
|
||||||
engine.pause()
|
engine.pause()
|
||||||
|
|
||||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||||
logger.info("Smooth handover: moving leader arm to follower position")
|
logger.info("Smooth handover: moving leader arm to follower position")
|
||||||
_teleop_smooth_move_to(teleop, prev_action)
|
teleop_smooth_move_to(teleop, prev_action)
|
||||||
|
|
||||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||||
logger.info("Entering correction mode - human teleop control")
|
logger.info("Entering correction mode - human teleop control")
|
||||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||||
logger.info("Smooth handover: sliding follower to teleop position")
|
logger.info("Smooth handover: sliding follower to teleop position")
|
||||||
obs = robot.get_observation()
|
obs = robot.get_observation()
|
||||||
teleop_action = teleop.get_action()
|
teleop_action = teleop.get_action()
|
||||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||||
target = ctx.processors.robot_action_processor((processed, obs))
|
target = ctx.processors.robot_action_processor((processed, obs))
|
||||||
_follower_smooth_move_to(robot, prev_action, target)
|
follower_smooth_move_to(robot, prev_action, target)
|
||||||
|
|
||||||
# unlock the teleop for human control
|
# unlock the teleop for human control
|
||||||
if _teleop_supports_feedback(teleop):
|
if teleop_supports_feedback(teleop):
|
||||||
teleop.disable_torque()
|
teleop.disable_torque()
|
||||||
|
|
||||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||||
if _teleop_supports_feedback(teleop):
|
if teleop_supports_feedback(teleop):
|
||||||
teleop.enable_torque()
|
teleop.enable_torque()
|
||||||
|
|
||||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||||
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
engine.resume()
|
engine.resume()
|
||||||
|
|
||||||
# release teleop before resuming the policy
|
# release teleop before resuming the policy
|
||||||
if _teleop_supports_feedback(teleop):
|
if teleop_supports_feedback(teleop):
|
||||||
teleop.disable_torque()
|
teleop.disable_torque()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
@@ -0,0 +1,335 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||||
|
|
||||||
|
- Policy drives the robot during each recording episode.
|
||||||
|
- An optional teleoperator can drive the robot during reset phases so the
|
||||||
|
operator can bring the environment back to its starting configuration.
|
||||||
|
If no teleop is connected the robot stays in its current position.
|
||||||
|
- Keyboard controls:
|
||||||
|
|
||||||
|
Right arrow — end the current episode or reset phase early
|
||||||
|
Left arrow — discard the current episode and re-record it
|
||||||
|
Escape — stop the recording session
|
||||||
|
|
||||||
|
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from lerobot.common.control_utils import (
|
||||||
|
follower_smooth_move_to,
|
||||||
|
init_keyboard_listener,
|
||||||
|
is_headless,
|
||||||
|
teleop_smooth_move_to,
|
||||||
|
teleop_supports_feedback,
|
||||||
|
)
|
||||||
|
from lerobot.datasets import VideoEncodingManager
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
from lerobot.utils.utils import log_say
|
||||||
|
from lerobot.utils.visualization_utils import log_rerun_data
|
||||||
|
|
||||||
|
from ..configs import EpisodicStrategyConfig
|
||||||
|
from ..context import RolloutContext
|
||||||
|
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodicStrategy(RolloutStrategy):
|
||||||
|
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
|
||||||
|
|
||||||
|
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
|
||||||
|
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
|
||||||
|
follows every episode (except the last) so the operator can manually
|
||||||
|
reset the environment. During the reset phase, an optional teleoperator
|
||||||
|
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
|
||||||
|
|
||||||
|
The policy state (hidden state, RTC queue, interpolator) is reset at
|
||||||
|
the start of each recording episode.
|
||||||
|
|
||||||
|
Keyboard events:
|
||||||
|
right arrow → end current episode or reset phase early
|
||||||
|
left arrow → discard & re-record current episode
|
||||||
|
ESC → stop the session
|
||||||
|
"""
|
||||||
|
|
||||||
|
config: EpisodicStrategyConfig
|
||||||
|
|
||||||
|
def __init__(self, config: EpisodicStrategyConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self._listener = None
|
||||||
|
self._events: dict | None = None
|
||||||
|
|
||||||
|
def setup(self, ctx: RolloutContext) -> None:
|
||||||
|
"""Start the inference engine and attach the keyboard listener."""
|
||||||
|
self._init_engine(ctx)
|
||||||
|
self._listener, self._events = init_keyboard_listener()
|
||||||
|
logger.info("Episodic strategy ready")
|
||||||
|
|
||||||
|
def run(self, ctx: RolloutContext) -> None:
|
||||||
|
"""Main multi-episode recording loop."""
|
||||||
|
cfg = ctx.runtime.cfg
|
||||||
|
dataset_cfg = cfg.dataset
|
||||||
|
robot = ctx.hardware.robot_wrapper
|
||||||
|
teleop = ctx.hardware.teleop
|
||||||
|
dataset = ctx.data.dataset
|
||||||
|
events = self._events
|
||||||
|
features = ctx.data.dataset_features
|
||||||
|
|
||||||
|
fps = cfg.fps
|
||||||
|
episode_time_s = dataset_cfg.episode_time_s
|
||||||
|
reset_time_s = dataset_cfg.reset_time_s
|
||||||
|
num_episodes = dataset_cfg.num_episodes
|
||||||
|
single_task = dataset_cfg.single_task or cfg.task
|
||||||
|
play_sounds = cfg.play_sounds
|
||||||
|
|
||||||
|
display_compressed = (
|
||||||
|
True
|
||||||
|
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
|
||||||
|
else cfg.display_compressed_images
|
||||||
|
)
|
||||||
|
|
||||||
|
with VideoEncodingManager(dataset):
|
||||||
|
try:
|
||||||
|
recorded_episodes = 0
|
||||||
|
while recorded_episodes < num_episodes and not events["stop_recording"]:
|
||||||
|
if ctx.runtime.shutdown_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
# Reset policy state at episode start (discard leftover hidden state / queue)
|
||||||
|
self._engine.reset()
|
||||||
|
self._interpolator.reset()
|
||||||
|
self._engine.resume()
|
||||||
|
|
||||||
|
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||||
|
self._policy_loop(
|
||||||
|
ctx=ctx,
|
||||||
|
robot=robot,
|
||||||
|
events=events,
|
||||||
|
features=features,
|
||||||
|
fps=fps,
|
||||||
|
control_time_s=episode_time_s,
|
||||||
|
dataset=dataset,
|
||||||
|
single_task=single_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset phase, skip after the last episode (but run when re-recording)
|
||||||
|
if not events["stop_recording"] and (
|
||||||
|
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
|
||||||
|
):
|
||||||
|
log_say("Reset the environment", play_sounds)
|
||||||
|
|
||||||
|
if teleop:
|
||||||
|
# Smooth handover so the transition to teleop control is jerk-free.
|
||||||
|
# For actuated teleops: drive the leader arm to the follower's current
|
||||||
|
# position so the operator takes over without fighting the arm.
|
||||||
|
# For non-actuated teleops: slide the follower to the teleop's current
|
||||||
|
# pose instead, since the leader cannot be driven.
|
||||||
|
obs = robot.get_observation()
|
||||||
|
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||||
|
if (
|
||||||
|
teleop_supports_feedback(teleop)
|
||||||
|
and self.config.smooth_leader_to_follower_handover
|
||||||
|
):
|
||||||
|
logger.info("Smooth handover: moving leader arm to follower position")
|
||||||
|
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
|
||||||
|
teleop.disable_torque()
|
||||||
|
else:
|
||||||
|
logger.info("Smooth handover: sliding follower to teleop position")
|
||||||
|
teleop_action = teleop.get_action()
|
||||||
|
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||||
|
target = ctx.processors.robot_action_processor((processed, obs))
|
||||||
|
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
|
||||||
|
|
||||||
|
elif self.config.reset_to_initial_position:
|
||||||
|
# No teleop: return the robot to its startup position.
|
||||||
|
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||||
|
|
||||||
|
self._reset_loop(
|
||||||
|
ctx=ctx,
|
||||||
|
robot=robot,
|
||||||
|
teleop=teleop,
|
||||||
|
events=events,
|
||||||
|
fps=fps,
|
||||||
|
control_time_s=reset_time_s,
|
||||||
|
display_data=cfg.display_data,
|
||||||
|
display_compressed=display_compressed,
|
||||||
|
)
|
||||||
|
|
||||||
|
if events["rerecord_episode"]:
|
||||||
|
log_say("Re-record episode", play_sounds)
|
||||||
|
events["rerecord_episode"] = False
|
||||||
|
events["exit_early"] = False
|
||||||
|
dataset.clear_episode_buffer()
|
||||||
|
|
||||||
|
# returns to its initial joint positions captured at startup
|
||||||
|
if not teleop and self.config.reset_to_initial_position:
|
||||||
|
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
recorded_episodes += 1
|
||||||
|
finally:
|
||||||
|
# Save any frames buffered in the current episode so an unexpected
|
||||||
|
# exception or KeyboardInterrupt does not silently drop recorded data.
|
||||||
|
# suppress: save_episode raises if the buffer is empty (nothing to lose).
|
||||||
|
logger.info("Episodic control loop ended — saving any in-progress episode")
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
def _policy_loop(
|
||||||
|
self,
|
||||||
|
ctx: RolloutContext,
|
||||||
|
robot,
|
||||||
|
events: dict,
|
||||||
|
features: dict,
|
||||||
|
fps: float,
|
||||||
|
control_time_s: float,
|
||||||
|
dataset,
|
||||||
|
single_task: str,
|
||||||
|
) -> None:
|
||||||
|
"""Policy-driven recording loop for a single episode."""
|
||||||
|
interpolator = self._interpolator
|
||||||
|
control_interval = interpolator.get_control_interval(fps)
|
||||||
|
|
||||||
|
timestamp = 0.0
|
||||||
|
start_t = time.perf_counter()
|
||||||
|
|
||||||
|
while timestamp < control_time_s:
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
if events["exit_early"]:
|
||||||
|
events["exit_early"] = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if ctx.runtime.shutdown_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
obs = robot.get_observation()
|
||||||
|
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||||
|
|
||||||
|
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
|
||||||
|
continue
|
||||||
|
|
||||||
|
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||||
|
|
||||||
|
if action_dict is not None:
|
||||||
|
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||||
|
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||||
|
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
|
||||||
|
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||||
|
|
||||||
|
dt = time.perf_counter() - loop_start
|
||||||
|
sleep_t = control_interval - dt
|
||||||
|
if sleep_t < 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
|
||||||
|
"Dataset frames might be dropped and robot control might be unstable. "
|
||||||
|
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
|
||||||
|
"3) CPU starvation"
|
||||||
|
)
|
||||||
|
precise_sleep(max(sleep_t, 0.0))
|
||||||
|
timestamp = time.perf_counter() - start_t
|
||||||
|
|
||||||
|
def _reset_loop(
|
||||||
|
self,
|
||||||
|
ctx: RolloutContext,
|
||||||
|
robot,
|
||||||
|
teleop,
|
||||||
|
events: dict,
|
||||||
|
fps: float,
|
||||||
|
control_time_s: float,
|
||||||
|
display_data: bool,
|
||||||
|
display_compressed: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Reset-phase loop: teleop drives the robot if available, no recording."""
|
||||||
|
processors = ctx.processors
|
||||||
|
control_interval = 1.0 / fps
|
||||||
|
|
||||||
|
timestamp = 0.0
|
||||||
|
start_t = time.perf_counter()
|
||||||
|
|
||||||
|
while timestamp < control_time_s:
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
if events["exit_early"]:
|
||||||
|
events["exit_early"] = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if ctx.runtime.shutdown_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
obs = robot.get_observation()
|
||||||
|
|
||||||
|
if teleop is not None:
|
||||||
|
act = teleop.get_action()
|
||||||
|
act_teleop = processors.teleop_action_processor((act, obs))
|
||||||
|
robot_action = processors.robot_action_processor((act_teleop, obs))
|
||||||
|
robot.send_action(robot_action)
|
||||||
|
|
||||||
|
if display_data:
|
||||||
|
obs_processed = processors.robot_observation_processor(obs)
|
||||||
|
log_rerun_data(
|
||||||
|
observation=obs_processed,
|
||||||
|
action=act_teleop,
|
||||||
|
compress_images=display_compressed,
|
||||||
|
)
|
||||||
|
|
||||||
|
dt = time.perf_counter() - loop_start
|
||||||
|
sleep_t = control_interval - dt
|
||||||
|
precise_sleep(max(sleep_t, 0.0))
|
||||||
|
timestamp = time.perf_counter() - start_t
|
||||||
|
|
||||||
|
def teardown(self, ctx: RolloutContext) -> None:
|
||||||
|
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
|
||||||
|
cfg = ctx.runtime.cfg
|
||||||
|
play_sounds = cfg.play_sounds
|
||||||
|
|
||||||
|
log_say("Stop recording", play_sounds, blocking=True)
|
||||||
|
|
||||||
|
if not is_headless() and self._listener is not None:
|
||||||
|
self._listener.stop()
|
||||||
|
|
||||||
|
if ctx.data.dataset is not None:
|
||||||
|
logger.info("Finalizing dataset...")
|
||||||
|
ctx.data.dataset.finalize()
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.dataset is not None
|
||||||
|
and cfg.dataset.push_to_hub
|
||||||
|
and ctx.data.dataset is not None
|
||||||
|
and safe_push_to_hub(
|
||||||
|
ctx.data.dataset,
|
||||||
|
tags=cfg.dataset.tags,
|
||||||
|
private=cfg.dataset.private,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
logger.info("Dataset uploaded to hub")
|
||||||
|
log_say("Dataset uploaded to hub", play_sounds)
|
||||||
|
|
||||||
|
self._teardown_hardware(
|
||||||
|
ctx.hardware,
|
||||||
|
return_to_initial_position=cfg.return_to_initial_position,
|
||||||
|
)
|
||||||
|
log_say("Exiting", play_sounds)
|
||||||
|
logger.info("Episodic strategy teardown complete")
|
||||||
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
|||||||
from .base import BaseStrategy
|
from .base import BaseStrategy
|
||||||
from .core import RolloutStrategy
|
from .core import RolloutStrategy
|
||||||
from .dagger import DAggerStrategy
|
from .dagger import DAggerStrategy
|
||||||
|
from .episodic import EpisodicStrategy
|
||||||
from .highlight import HighlightStrategy
|
from .highlight import HighlightStrategy
|
||||||
from .sentry import SentryStrategy
|
from .sentry import SentryStrategy
|
||||||
|
|
||||||
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
|||||||
return HighlightStrategy(config)
|
return HighlightStrategy(config)
|
||||||
if config.type == "dagger":
|
if config.type == "dagger":
|
||||||
return DAggerStrategy(config)
|
return DAggerStrategy(config)
|
||||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
if config.type == "episodic":
|
||||||
|
return EpisodicStrategy(config)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ Strategies
|
|||||||
--strategy.type=sentry Continuous recording with auto-upload
|
--strategy.type=sentry Continuous recording with auto-upload
|
||||||
--strategy.type=highlight Ring buffer + keystroke save
|
--strategy.type=highlight Ring buffer + keystroke save
|
||||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||||
|
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||||
|
|
||||||
Inference backends
|
Inference backends
|
||||||
------------------
|
------------------
|
||||||
@@ -111,6 +112,18 @@ Usage examples
|
|||||||
--display_data=true \\
|
--display_data=true \\
|
||||||
--use_torch_compile=true
|
--use_torch_compile=true
|
||||||
|
|
||||||
|
# Episodic mode — episode-oriented recording with reset phases
|
||||||
|
lerobot-rollout \\
|
||||||
|
--strategy.type=episodic \\
|
||||||
|
--policy.path=user/my_policy \\
|
||||||
|
--robot.type=so100_follower \\
|
||||||
|
--robot.port=/dev/ttyACM0 \\
|
||||||
|
--teleop.type=so100_leader \\
|
||||||
|
--teleop.port=/dev/ttyACM1 \\
|
||||||
|
--dataset.repo_id=user/rollout_episodic_data \\
|
||||||
|
--dataset.num_episodes=20 \\
|
||||||
|
--dataset.single_task="Grab the cube"
|
||||||
|
|
||||||
# Resume a previous sentry recording session
|
# Resume a previous sentry recording session
|
||||||
lerobot-rollout \\
|
lerobot-rollout \\
|
||||||
--strategy.type=sentry \\
|
--strategy.type=sentry \\
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from types import SimpleNamespace
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
@@ -191,7 +190,7 @@ class _FakeQwenInterface(nn.Module):
|
|||||||
|
|
||||||
def build_inputs(
|
def build_inputs(
|
||||||
self,
|
self,
|
||||||
images: list[list[Image.Image]],
|
images: list[list[Tensor]],
|
||||||
instructions: list[str],
|
instructions: list[str],
|
||||||
action_prompt: str,
|
action_prompt: str,
|
||||||
embodied_prompt: str,
|
embodied_prompt: str,
|
||||||
@@ -214,12 +213,13 @@ class _FakeQwenInterface(nn.Module):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
def to_pixel_values(image_tensor: Tensor) -> Tensor:
|
||||||
image = image_tensor.detach().cpu()
|
image = image_tensor.detach().float()
|
||||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
if image.shape[-3] == 1:
|
||||||
image = image.permute(1, 2, 0)
|
repeats = [1] * image.ndim
|
||||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
repeats[-3] = 3
|
||||||
return Image.fromarray(image)
|
image = image.repeat(*repeats)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
class _FakeVideoEncoder(nn.Module):
|
class _FakeVideoEncoder(nn.Module):
|
||||||
@@ -242,12 +242,14 @@ class _FakeVideoEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class _FakeVideoProcessor:
|
class _FakeVideoProcessor:
|
||||||
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
|
def __call__(self, videos, return_tensors: str, device=None, **kwargs) -> dict[str, Tensor]:
|
||||||
assert return_tensors == "pt"
|
assert return_tensors == "pt"
|
||||||
if isinstance(videos, list):
|
if isinstance(videos, list):
|
||||||
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
|
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
|
||||||
else:
|
else:
|
||||||
pixel_values = torch.as_tensor(videos).unsqueeze(0)
|
pixel_values = torch.as_tensor(videos).unsqueeze(0)
|
||||||
|
if device is not None:
|
||||||
|
pixel_values = pixel_values.to(device)
|
||||||
return {"pixel_values_videos": pixel_values}
|
return {"pixel_values_videos": pixel_values}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -211,40 +211,42 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None
|
|||||||
|
|
||||||
|
|
||||||
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
policy = VLAJEPAPolicy(make_config())
|
policy = VLAJEPAPolicy(make_config())
|
||||||
examples = policy._prepare_model_inputs(make_train_batch())
|
inputs = policy._prepare_model_inputs(make_train_batch())
|
||||||
|
|
||||||
assert len(examples) == BATCH_SIZE
|
assert set(inputs) >= {"images", "instructions", "videos", "actions", "state"}
|
||||||
for ex in examples:
|
# images: per-sample, per-view [C, H, W] float tensors (kept as a list for Qwen messages)
|
||||||
assert set(ex) >= {"image", "video", "lang", "action", "state"}
|
assert len(inputs["images"]) == BATCH_SIZE and len(inputs["images"][0]) == 1
|
||||||
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
|
img = inputs["images"][0][0]
|
||||||
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
|
assert isinstance(img, torch.Tensor) and img.dtype == torch.float32 and img.ndim == 3
|
||||||
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
|
assert len(inputs["instructions"]) == BATCH_SIZE
|
||||||
assert ex["state"].shape == (1, STATE_DIM)
|
# videos: batched [B, V, T, C, H, W] float
|
||||||
|
assert inputs["videos"].ndim == 6 and inputs["videos"].shape[0] == BATCH_SIZE
|
||||||
|
assert inputs["videos"].dtype == torch.float32
|
||||||
|
assert inputs["actions"].shape == (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||||
|
assert inputs["state"].shape == (BATCH_SIZE, 1, STATE_DIM)
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
||||||
policy = VLAJEPAPolicy(make_config())
|
policy = VLAJEPAPolicy(make_config())
|
||||||
for ex in policy._prepare_model_inputs(make_inference_batch()):
|
inputs = policy._prepare_model_inputs(make_inference_batch())
|
||||||
assert "action" not in ex
|
assert "actions" not in inputs and "action_is_pad" not in inputs
|
||||||
assert "image" in ex and "video" in ex and "lang" in ex
|
assert {"images", "instructions", "state"} <= set(inputs)
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
||||||
policy = VLAJEPAPolicy(make_config())
|
policy = VLAJEPAPolicy(make_config())
|
||||||
batch = make_inference_batch()
|
batch = make_inference_batch()
|
||||||
del batch["task"]
|
del batch["task"]
|
||||||
examples = policy._prepare_model_inputs(batch)
|
instructions = policy._prepare_model_inputs(batch)["instructions"]
|
||||||
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
|
assert all(isinstance(s, str) and len(s) > 0 for s in instructions)
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
||||||
policy = VLAJEPAPolicy(make_config())
|
policy = VLAJEPAPolicy(make_config())
|
||||||
batch = make_inference_batch()
|
batch = make_inference_batch()
|
||||||
batch["task"] = "open the drawer"
|
batch["task"] = "open the drawer"
|
||||||
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
|
assert policy._prepare_model_inputs(batch)["instructions"] == ["open the drawer"] * BATCH_SIZE
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
||||||
@@ -253,7 +255,7 @@ def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: N
|
|||||||
policy = VLAJEPAPolicy(make_config())
|
policy = VLAJEPAPolicy(make_config())
|
||||||
batch = make_inference_batch()
|
batch = make_inference_batch()
|
||||||
del batch[OBS_STATE]
|
del batch[OBS_STATE]
|
||||||
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
|
assert "state" not in policy._prepare_model_inputs(batch)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -446,14 +448,14 @@ def test_postprocessor_applied_after_predict_action_chunk(
|
|||||||
"""
|
"""
|
||||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||||
|
|
||||||
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
|
raw_actions = torch.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=torch.float32)
|
||||||
|
|
||||||
cfg = make_config()
|
cfg = make_config()
|
||||||
cfg.clip_normalized_actions = False
|
cfg.clip_normalized_actions = False
|
||||||
cfg.binarize_gripper_action = False
|
cfg.binarize_gripper_action = False
|
||||||
policy = VLAJEPAPolicy(cfg)
|
policy = VLAJEPAPolicy(cfg)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
|
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.clone())
|
||||||
|
|
||||||
dataset_stats = _make_dataset_stats()
|
dataset_stats = _make_dataset_stats()
|
||||||
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
||||||
@@ -564,9 +566,9 @@ def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_model
|
|||||||
original_processor = policy.model.video_processor
|
original_processor = policy.model.video_processor
|
||||||
|
|
||||||
class _CapturingProcessor:
|
class _CapturingProcessor:
|
||||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
def __call__(self, videos: list, return_tensors: str, **kwargs) -> dict:
|
||||||
captured_videos.extend(videos)
|
captured_videos.extend(videos)
|
||||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
return original_processor(videos=videos, return_tensors=return_tensors, **kwargs)
|
||||||
|
|
||||||
policy.model.video_processor = _CapturingProcessor()
|
policy.model.video_processor = _CapturingProcessor()
|
||||||
policy.forward(_make_multiview_train_batch(num_views=1))
|
policy.forward(_make_multiview_train_batch(num_views=1))
|
||||||
@@ -587,9 +589,9 @@ def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: No
|
|||||||
original_processor = policy.model.video_processor
|
original_processor = policy.model.video_processor
|
||||||
|
|
||||||
class _CapturingProcessor:
|
class _CapturingProcessor:
|
||||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
def __call__(self, videos: list, return_tensors: str, **kwargs) -> dict:
|
||||||
captured_videos.extend(videos)
|
captured_videos.extend(videos)
|
||||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
return original_processor(videos=videos, return_tensors=return_tensors, **kwargs)
|
||||||
|
|
||||||
policy.model.video_processor = _CapturingProcessor()
|
policy.model.video_processor = _CapturingProcessor()
|
||||||
policy.forward(_make_multiview_train_batch(num_views=3))
|
policy.forward(_make_multiview_train_batch(num_views=3))
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from typing import Any
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class MockLazyTensorStateStep(ProcessorStep):
|
||||||
|
"""Mock step whose tensor state is not present in constructor config."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.scale = scale
|
||||||
|
self.tensor_state: torch.Tensor | None = None
|
||||||
|
|
||||||
|
if initial_value is not None:
|
||||||
|
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
"""Return the transition unchanged."""
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""Return constructor config while intentionally omitting tensor state."""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"scale": self.scale,
|
||||||
|
}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
"""Return tensor state only after it has been initialized or loaded."""
|
||||||
|
if self.tensor_state is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return {"tensor_state": self.tensor_state}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
"""Load tensor state."""
|
||||||
|
self.tensor_state = state["tensor_state"].clone()
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""Return features unchanged."""
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
|
||||||
|
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
|
||||||
|
"""Registered lazy tensor state step for registry-based serialization tests."""
|
||||||
|
|
||||||
|
|
||||||
def test_empty_pipeline():
|
def test_empty_pipeline():
|
||||||
"""Test pipeline with no steps."""
|
"""Test pipeline with no steps."""
|
||||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
|||||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_config_matches_saved_json():
|
||||||
|
"""Test that in-memory config matches the config written by save_pretrained."""
|
||||||
|
stateless_step = MockStep(name="stateless")
|
||||||
|
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
|
||||||
|
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
|
||||||
|
|
||||||
|
in_memory_config = pipeline.get_config()
|
||||||
|
|
||||||
|
assert pipeline.get_config() == in_memory_config
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
config_path = Path(tmp_dir) / "memory_pipeline.json"
|
||||||
|
with open(config_path) as file_pointer:
|
||||||
|
saved_config = json.load(file_pointer)
|
||||||
|
|
||||||
|
assert in_memory_config == saved_config
|
||||||
|
assert "state_file" not in in_memory_config["steps"][0]
|
||||||
|
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_dict_matches_saved_safetensors():
|
||||||
|
"""Test that in-memory state matches the safetensors written by save_pretrained."""
|
||||||
|
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
|
||||||
|
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
|
||||||
|
|
||||||
|
in_memory_state_dict = pipeline.state_dict()
|
||||||
|
state_filename = "stateful_pipeline_step_0.safetensors"
|
||||||
|
state_key = "stateful_pipeline_step_0"
|
||||||
|
|
||||||
|
assert set(in_memory_state_dict) == {state_key}
|
||||||
|
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
|
||||||
|
|
||||||
|
in_memory_state_dict[state_key]["tensor_state"].add_(1)
|
||||||
|
assert stateful_step.tensor_state is not None
|
||||||
|
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
|
||||||
|
|
||||||
|
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_pretrained_still_writes_expected_serialization_files():
|
||||||
|
"""Test that save_pretrained keeps the existing config and state filenames."""
|
||||||
|
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
|
||||||
|
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
save_path = Path(tmp_dir)
|
||||||
|
assert (save_path / "policy_preprocessor.json").exists()
|
||||||
|
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_config_round_trips_stateful_pipeline():
|
||||||
|
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
|
||||||
|
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
|
||||||
|
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
|
||||||
|
config = pipeline.get_config()
|
||||||
|
pipeline_state_dict = pipeline.state_dict()
|
||||||
|
|
||||||
|
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||||
|
loaded_step = loaded_pipeline.steps[0]
|
||||||
|
|
||||||
|
assert len(loaded_pipeline) == 1
|
||||||
|
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||||
|
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_config_round_trips_registered_stateful_pipeline():
|
||||||
|
"""Test that from_config resolves registry steps and loads their named tensor state."""
|
||||||
|
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
|
||||||
|
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
|
||||||
|
config = pipeline.get_config()
|
||||||
|
pipeline_state_dict = pipeline.state_dict()
|
||||||
|
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
|
||||||
|
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
|
||||||
|
|
||||||
|
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
|
||||||
|
assert config["steps"][0]["state_file"] == state_filename
|
||||||
|
assert set(pipeline_state_dict) == {state_key}
|
||||||
|
|
||||||
|
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||||
|
loaded_step = loaded_pipeline.steps[0]
|
||||||
|
|
||||||
|
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
|
||||||
|
assert loaded_step.tensor_state is not None
|
||||||
|
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_config_preserves_state_metadata_for_empty_initial_state():
|
||||||
|
"""Test in-memory loading when rebuilt steps start without tensor state."""
|
||||||
|
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
|
||||||
|
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
|
||||||
|
config = pipeline.get_config()
|
||||||
|
pipeline_state_dict = pipeline.state_dict()
|
||||||
|
|
||||||
|
loaded_pipeline = DataProcessorPipeline.from_config(config)
|
||||||
|
loaded_step = loaded_pipeline.steps[0]
|
||||||
|
|
||||||
|
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||||
|
assert loaded_step.state_dict() == {}
|
||||||
|
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
|
||||||
|
|
||||||
|
loaded_pipeline.load_state_dict(pipeline_state_dict)
|
||||||
|
|
||||||
|
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_config_applies_overrides_before_state_loading():
|
||||||
|
"""Test that constructor overrides and tensor state loading are separate operations."""
|
||||||
|
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
|
||||||
|
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
|
||||||
|
config = pipeline.get_config()
|
||||||
|
pipeline_state_dict = pipeline.state_dict()
|
||||||
|
|
||||||
|
loaded_pipeline = DataProcessorPipeline.from_config(
|
||||||
|
config,
|
||||||
|
state_dict=pipeline_state_dict,
|
||||||
|
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
|
||||||
|
)
|
||||||
|
loaded_step = loaded_pipeline.steps[0]
|
||||||
|
|
||||||
|
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||||
|
assert loaded_step.scale == 5.0
|
||||||
|
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_state_dict_raises_on_missing_expected_state():
|
||||||
|
"""Test loading raises when serialized config expects missing state."""
|
||||||
|
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
|
||||||
|
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
|
||||||
|
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
|
||||||
|
loaded_pipeline.load_state_dict({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_state_dict_raises_on_unexpected_extra_state():
|
||||||
|
"""Test loading raises on unexpected top-level state keys."""
|
||||||
|
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="extra"):
|
||||||
|
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
|
||||||
|
|
||||||
|
|
||||||
|
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
|
||||||
|
"""Test stateless in-memory serialization and loading."""
|
||||||
|
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
|
||||||
|
config = pipeline.get_config()
|
||||||
|
config_without_name = {"steps": config["steps"]}
|
||||||
|
|
||||||
|
assert pipeline.state_dict() == {}
|
||||||
|
assert all("state_file" not in step_entry for step_entry in config["steps"])
|
||||||
|
|
||||||
|
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
|
||||||
|
|
||||||
|
assert loaded_pipeline.name == "DataProcessorPipeline"
|
||||||
|
assert loaded_pipeline.state_dict() == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
|
||||||
|
def test_from_config_rejects_non_dict_config(invalid_config):
|
||||||
|
"""Test from_config reports invalid top-level config values cleanly."""
|
||||||
|
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||||
|
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
class MockModuleStep(ProcessorStep, nn.Module):
|
class MockModuleStep(ProcessorStep, nn.Module):
|
||||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ def test_strategy_config_types():
|
|||||||
from lerobot.rollout import (
|
from lerobot.rollout import (
|
||||||
BaseStrategyConfig,
|
BaseStrategyConfig,
|
||||||
DAggerStrategyConfig,
|
DAggerStrategyConfig,
|
||||||
|
EpisodicStrategyConfig,
|
||||||
HighlightStrategyConfig,
|
HighlightStrategyConfig,
|
||||||
SentryStrategyConfig,
|
SentryStrategyConfig,
|
||||||
)
|
)
|
||||||
@@ -67,6 +68,7 @@ def test_strategy_config_types():
|
|||||||
assert SentryStrategyConfig().type == "sentry"
|
assert SentryStrategyConfig().type == "sentry"
|
||||||
assert HighlightStrategyConfig().type == "highlight"
|
assert HighlightStrategyConfig().type == "highlight"
|
||||||
assert DAggerStrategyConfig().type == "dagger"
|
assert DAggerStrategyConfig().type == "dagger"
|
||||||
|
assert EpisodicStrategyConfig().type == "episodic"
|
||||||
|
|
||||||
|
|
||||||
def test_dagger_config_invalid_input_device():
|
def test_dagger_config_invalid_input_device():
|
||||||
@@ -203,6 +205,8 @@ def test_create_strategy_dispatches():
|
|||||||
BaseStrategyConfig,
|
BaseStrategyConfig,
|
||||||
DAggerStrategy,
|
DAggerStrategy,
|
||||||
DAggerStrategyConfig,
|
DAggerStrategyConfig,
|
||||||
|
EpisodicStrategy,
|
||||||
|
EpisodicStrategyConfig,
|
||||||
SentryStrategy,
|
SentryStrategy,
|
||||||
SentryStrategyConfig,
|
SentryStrategyConfig,
|
||||||
create_strategy,
|
create_strategy,
|
||||||
@@ -211,6 +215,7 @@ def test_create_strategy_dispatches():
|
|||||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||||
|
assert isinstance(create_strategy(EpisodicStrategyConfig()), EpisodicStrategy)
|
||||||
|
|
||||||
|
|
||||||
def test_create_strategy_unknown_raises():
|
def test_create_strategy_unknown_raises():
|
||||||
|
|||||||
Reference in New Issue
Block a user