mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 00:57:06 +00:00
add video backbone to pi05
This commit is contained in:
@@ -105,6 +105,16 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def image_observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
"""Return indices for delta image observations only.
|
||||
|
||||
Unlike observation_delta_indices which applies to ALL observations,
|
||||
this only applies to image observations (keys starting with observation.images).
|
||||
Default returns None. Override in subclass to enable.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
|
||||
@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, REWARD
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
@@ -59,7 +59,12 @@ def resolve_delta_timestamps(
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == ACTION and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
|
||||
# Check for image-specific delta indices first (e.g., for video encoding)
|
||||
if key.startswith(OBS_IMAGES) and cfg.image_observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.image_observation_delta_indices]
|
||||
# Fall back to generic observation delta indices for all observations
|
||||
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
|
||||
@@ -35,6 +35,7 @@ from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
@@ -67,7 +68,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
"vqbet", "pi0", "pi05", "pi05_video", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
@@ -103,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "pi05_video":
|
||||
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
|
||||
|
||||
return PI05VideoPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -147,7 +152,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "pi05_video", "sac", "smolvla",
|
||||
"reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
@@ -169,6 +174,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "pi05_video":
|
||||
return PI05VideoConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
@@ -333,6 +340,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05VideoConfig):
|
||||
from lerobot.policies.videovla.processor_pi05 import make_pi05_video_pre_post_processors
|
||||
|
||||
processors = make_pi05_video_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
|
||||
|
||||
@@ -17,15 +17,15 @@
|
||||
# Lazy imports to avoid conflicts with lerobot.policies.pi05.PI05Config
|
||||
# when only importing subpackages like videoprism
|
||||
def __getattr__(name):
|
||||
if name == "PI05Config":
|
||||
from .configuration_pi05 import PI05Config
|
||||
return PI05Config
|
||||
elif name == "PI05Policy":
|
||||
from .modeling_pi05 import PI05Policy
|
||||
return PI05Policy
|
||||
elif name == "make_pi05_pre_post_processors":
|
||||
from .processor_pi05 import make_pi05_pre_post_processors
|
||||
return make_pi05_pre_post_processors
|
||||
if name == "PI05VideoConfig":
|
||||
from .configuration_pi05 import PI05VideoConfig
|
||||
return PI05VideoConfig
|
||||
elif name == "PI05VideoPolicy":
|
||||
from .modeling_pi05 import PI05VideoPolicy
|
||||
return PI05VideoPolicy
|
||||
elif name == "make_pi05_video_pre_post_processors":
|
||||
from .processor_pi05 import make_pi05_video_pre_post_processors
|
||||
return make_pi05_video_pre_post_processors
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
|
||||
__all__ = ["PI05VideoConfig", "PI05VideoPolicy", "make_pi05_video_pre_post_processors"]
|
||||
|
||||
@@ -26,9 +26,9 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@PreTrainedConfig.register_subclass("pi05_video")
|
||||
@dataclass
|
||||
class PI05Config(PreTrainedConfig):
|
||||
class PI05VideoConfig(PreTrainedConfig):
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
@@ -37,6 +37,19 @@ class PI05Config(PreTrainedConfig):
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Video encoder settings (VideoPrism)
|
||||
use_video_encoder: bool = False # Enable video encoding with VideoPrism
|
||||
video_num_frames: int = 16 # Number of frames for video encoding (VideoPrism default is 16)
|
||||
videoprism_model_name: str = "MHRDYN7/videoprism-base-f16r288" # VideoPrism model to use
|
||||
videoprism_image_size: int = 288 # VideoPrism expects 288x288 images
|
||||
freeze_video_encoder: bool = True # Whether to freeze the video encoder weights
|
||||
video_padding_mode: str = "repeat" # How to pad frames at episode start: "repeat" or "zero"
|
||||
# Which camera to use for video encoding (None = first camera, or specify key like "observation.images.top")
|
||||
video_encoder_camera_key: str | None = None
|
||||
# Perceiver Resampler settings to reduce video tokens (4096 -> video_num_latents)
|
||||
video_num_latents: int = 128 # Number of latent tokens for video resampler
|
||||
video_resampler_num_heads: int = 8 # Number of attention heads in resampler
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
@@ -115,6 +128,17 @@ class PI05Config(PreTrainedConfig):
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
# Validate video encoder settings
|
||||
if self.use_video_encoder:
|
||||
if self.video_num_frames < 1:
|
||||
raise ValueError(f"video_num_frames must be >= 1, got {self.video_num_frames}")
|
||||
if self.videoprism_image_size < 1:
|
||||
raise ValueError(f"videoprism_image_size must be >= 1, got {self.videoprism_image_size}")
|
||||
if self.video_padding_mode not in ["repeat", "zero"]:
|
||||
raise ValueError(
|
||||
f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
@@ -157,7 +181,26 @@ class PI05Config(PreTrainedConfig):
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
"""Return indices for delta observations.
|
||||
|
||||
For PI05, we don't use generic observation_delta_indices because it would
|
||||
apply to both images AND state. Instead, we use image_observation_delta_indices
|
||||
which only applies to image observations.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def image_observation_delta_indices(self) -> list[int] | None:
|
||||
"""Return indices for delta image observations only.
|
||||
|
||||
When video encoding is enabled, returns indices for the past frames
|
||||
needed by VideoPrism (e.g., -15, -14, ..., -1, 0 for 16 frames).
|
||||
This only applies to image observations, not state.
|
||||
"""
|
||||
if self.use_video_encoder:
|
||||
# Return indices for past frames: [-15, -14, ..., -1, 0] for 16 frames
|
||||
return list(range(-(self.video_num_frames - 1), 1))
|
||||
return None
|
||||
|
||||
@property
|
||||
|
||||
@@ -40,8 +40,17 @@ else:
|
||||
GemmaForCausalLM = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
# VideoPrism imports for video encoding
|
||||
try:
|
||||
from lerobot.policies.videovla.videoprism import VideoPrismVideoProcessor, VideoPrismVisionModel
|
||||
_videoprism_available = True
|
||||
except ImportError:
|
||||
_videoprism_available = False
|
||||
VideoPrismVideoProcessor = None
|
||||
VideoPrismVisionModel = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.videovla.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05VideoConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
@@ -289,6 +298,60 @@ def compute_layer_complete(
|
||||
return outputs_embeds
|
||||
|
||||
|
||||
class PerceiverResampler(nn.Module):
|
||||
"""Perceiver Resampler to reduce video tokens via cross-attention.
|
||||
|
||||
This module uses learnable query tokens that cross-attend to the video tokens,
|
||||
effectively reducing the sequence length while preserving important information.
|
||||
|
||||
Args:
|
||||
dim: Hidden dimension of the input/output features
|
||||
num_latents: Number of learnable query tokens (output sequence length)
|
||||
num_heads: Number of attention heads
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 768, num_latents: int = 128, num_heads: int = 8):
|
||||
super().__init__()
|
||||
self.num_latents = num_latents
|
||||
self.dim = dim
|
||||
|
||||
# Learnable query tokens
|
||||
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
||||
|
||||
# Cross-attention layer
|
||||
self.attn = nn.MultiheadAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
# Layer norms for queries and key-values
|
||||
self.ln_q = nn.LayerNorm(dim)
|
||||
self.ln_kv = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: Input video tokens of shape (B, N, D) where N can be large (e.g., 4096)
|
||||
|
||||
Returns:
|
||||
Resampled tokens of shape (B, num_latents, D)
|
||||
"""
|
||||
B, N, D = x.shape
|
||||
|
||||
# Expand learnable latents to batch size
|
||||
latents = self.latents.unsqueeze(0).expand(B, -1, -1) # (B, num_latents, D)
|
||||
|
||||
# Apply layer norms
|
||||
q = self.ln_q(latents)
|
||||
kv = self.ln_kv(x)
|
||||
|
||||
# Cross-attention: queries attend to video tokens
|
||||
out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class GemmaConfig: # see openpi `gemma.py: Config`
|
||||
"""Configuration for Gemma model variants."""
|
||||
|
||||
@@ -534,7 +597,7 @@ class PaliGemmaWithExpertModel(
|
||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI05 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
||||
def __init__(self, config: PI05VideoConfig, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
@@ -566,6 +629,47 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Initialize gradient checkpointing flag
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
# Initialize VideoPrism video encoder if enabled
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_proj = None
|
||||
self.video_resampler = None
|
||||
if config.use_video_encoder:
|
||||
if not _videoprism_available:
|
||||
raise ImportError(
|
||||
"VideoPrism is not available. Please install the required dependencies."
|
||||
)
|
||||
logging.info(f"Initializing VideoPrism video encoder: {config.videoprism_model_name}")
|
||||
self.video_processor = VideoPrismVideoProcessor.from_pretrained(config.videoprism_model_name)
|
||||
self.video_encoder = VideoPrismVisionModel.from_pretrained(
|
||||
config.videoprism_model_name,
|
||||
torch_dtype=torch.bfloat16 if config.dtype == "bfloat16" else torch.float32,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
# Get the hidden size from VideoPrism config (default is 768 for base model)
|
||||
video_hidden_size = self.video_encoder.config.hidden_size
|
||||
|
||||
# Initialize Perceiver Resampler to reduce video tokens (e.g., 4096 -> 128)
|
||||
self.video_resampler = PerceiverResampler(
|
||||
dim=video_hidden_size,
|
||||
num_latents=config.video_num_latents,
|
||||
num_heads=config.video_resampler_num_heads,
|
||||
)
|
||||
logging.info(
|
||||
f"Initialized video resampler: {video_hidden_size}D, "
|
||||
f"{config.video_num_latents} latents, {config.video_resampler_num_heads} heads"
|
||||
)
|
||||
|
||||
# Project video embeddings to PaliGemma's hidden size
|
||||
self.video_proj = nn.Linear(video_hidden_size, paligemma_config.width)
|
||||
|
||||
# Freeze video encoder if requested
|
||||
if config.freeze_video_encoder:
|
||||
self.video_encoder.eval()
|
||||
for param in self.video_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
logging.info("Video encoder weights are frozen")
|
||||
|
||||
# Compile model if requested
|
||||
if config.compile_model:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
@@ -632,13 +736,33 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, tokens, masks
|
||||
self, images, img_masks, tokens, masks, video_emb: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer."""
|
||||
"""Embed images with SigLIP, optional video with VideoPrism, and language tokens with embedding layer.
|
||||
|
||||
Args:
|
||||
images: List of image tensors [B, C, H, W]
|
||||
img_masks: List of image masks [B]
|
||||
tokens: Language tokens [B, seq_len]
|
||||
masks: Language attention masks [B, seq_len]
|
||||
video_emb: Optional video embeddings from VideoPrism [B, num_video_tokens, hidden_dim]
|
||||
|
||||
Returns:
|
||||
Tuple of (embeddings, pad_masks, att_masks)
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Process video embeddings first (if available)
|
||||
if video_emb is not None:
|
||||
bsize, num_video_tokens, _ = video_emb.shape
|
||||
embs.append(video_emb)
|
||||
# Video tokens are always valid
|
||||
video_mask = torch.ones(bsize, num_video_tokens, dtype=torch.bool, device=video_emb.device)
|
||||
pad_masks.append(video_mask)
|
||||
att_masks += [0] * num_video_tokens
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
|
||||
@@ -674,6 +798,69 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_video(self, video_frames: torch.Tensor) -> torch.Tensor:
|
||||
"""Embed video frames using VideoPrism encoder.
|
||||
|
||||
Args:
|
||||
video_frames: Tensor of shape [B, T, C, H, W] where T is the number of frames.
|
||||
Expected to be normalized to [0, 1].
|
||||
|
||||
Returns:
|
||||
Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
|
||||
PaliGemma's hidden dimension.
|
||||
"""
|
||||
if self.video_encoder is None:
|
||||
raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")
|
||||
|
||||
device = video_frames.device
|
||||
dtype = video_frames.dtype
|
||||
|
||||
# Move video encoder to the same device if needed
|
||||
if next(self.video_encoder.parameters()).device != device:
|
||||
self.video_encoder = self.video_encoder.to(device)
|
||||
|
||||
# VideoPrism expects pixel values in [0, 1] range and shape [B, T, C, H, W]
|
||||
# Resize frames to VideoPrism expected size if needed
|
||||
B, T, C, H, W = video_frames.shape
|
||||
target_size = self.config.videoprism_image_size
|
||||
|
||||
if H != target_size or W != target_size:
|
||||
# Resize each frame
|
||||
video_frames = video_frames.view(B * T, C, H, W)
|
||||
video_frames = F.interpolate(
|
||||
video_frames,
|
||||
size=(target_size, target_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
video_frames = video_frames.view(B, T, C, target_size, target_size)
|
||||
|
||||
# Convert to the expected dtype for the video encoder
|
||||
video_encoder_dtype = next(self.video_encoder.parameters()).dtype
|
||||
video_frames = video_frames.to(dtype=video_encoder_dtype)
|
||||
|
||||
# Run through VideoPrism
|
||||
with torch.set_grad_enabled(not self.config.freeze_video_encoder):
|
||||
if self.config.freeze_video_encoder:
|
||||
self.video_encoder.eval()
|
||||
|
||||
video_outputs = self.video_encoder(pixel_values_videos=video_frames)
|
||||
# Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
|
||||
video_embeddings = video_outputs.last_hidden_state
|
||||
|
||||
# Convert to working dtype
|
||||
video_embeddings = video_embeddings.to(dtype=dtype)
|
||||
|
||||
# Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
|
||||
# This uses cross-attention from learnable queries to the video tokens
|
||||
video_embeddings = self.video_resampler(video_embeddings)
|
||||
# Shape: [B, num_latents, hidden_size] (e.g., [B, 128, 768])
|
||||
|
||||
# Project to PaliGemma's hidden dimension
|
||||
video_embeddings = self.video_proj(video_embeddings)
|
||||
|
||||
return video_embeddings
|
||||
|
||||
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
@@ -721,8 +908,21 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
def forward(
|
||||
self, images, img_masks, tokens, masks, actions, noise=None, time=None, video_frames=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss.
|
||||
|
||||
Args:
|
||||
images: List of image tensors [B, C, H, W]
|
||||
img_masks: List of image masks [B]
|
||||
tokens: Language tokens [B, seq_len]
|
||||
masks: Language attention masks [B, seq_len]
|
||||
actions: Ground truth actions [B, chunk_size, action_dim]
|
||||
noise: Optional noise tensor for flow matching
|
||||
time: Optional time tensor for flow matching
|
||||
video_frames: Optional video frames [B, T, C, H, W] for video encoding
|
||||
"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
@@ -733,7 +933,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
# Embed video if provided and video encoder is available
|
||||
video_emb = None
|
||||
if video_frames is not None and self.video_encoder is not None:
|
||||
video_emb = self.embed_video(video_frames)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, tokens, masks, video_emb=video_emb
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
if (
|
||||
@@ -785,9 +992,20 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
video_frames=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
"""Do a full inference forward and compute the action.
|
||||
|
||||
Args:
|
||||
images: List of image tensors [B, C, H, W]
|
||||
img_masks: List of image masks [B]
|
||||
tokens: Language tokens [B, seq_len]
|
||||
masks: Language attention masks [B, seq_len]
|
||||
noise: Optional noise tensor
|
||||
num_steps: Number of denoising steps
|
||||
video_frames: Optional video frames [B, T, C, H, W] for video encoding
|
||||
"""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
|
||||
@@ -803,7 +1021,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
) # Use config max_action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
# Embed video if provided and video encoder is available
|
||||
video_emb = None
|
||||
if video_frames is not None and self.video_encoder is not None:
|
||||
video_emb = self.embed_video(video_frames)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, tokens, masks, video_emb=video_emb
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
@@ -895,15 +1120,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
|
||||
class PI05Policy(PreTrainedPolicy):
|
||||
"""PI05 Policy for LeRobot."""
|
||||
class PI05VideoPolicy(PreTrainedPolicy):
|
||||
"""PI05 Video Policy for LeRobot with optional video encoding support."""
|
||||
|
||||
config_class = PI05Config
|
||||
name = "pi05"
|
||||
config_class = PI05VideoConfig
|
||||
name = "pi05_video"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI05Config,
|
||||
config: PI05VideoConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -1128,11 +1353,33 @@ class PI05Policy(PreTrainedPolicy):
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _get_video_camera_key(self) -> str | None:
|
||||
"""Get the camera key to use for video encoding.
|
||||
|
||||
Returns the configured video_encoder_camera_key if set,
|
||||
otherwise returns the first image feature key.
|
||||
"""
|
||||
if not self.config.use_video_encoder:
|
||||
return None
|
||||
|
||||
if self.config.video_encoder_camera_key is not None:
|
||||
return self.config.video_encoder_camera_key
|
||||
|
||||
# Default to first image feature (image_features is a dict)
|
||||
if self.config.image_features:
|
||||
return next(iter(self.config.image_features.keys()))
|
||||
|
||||
return None
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
|
||||
PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
|
||||
|
||||
When video encoding is enabled:
|
||||
- The video camera is skipped (processed separately by video encoder)
|
||||
- Other cameras with temporal dimension have only the current frame extracted
|
||||
"""
|
||||
images = []
|
||||
img_masks = []
|
||||
@@ -1140,10 +1387,17 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Determine which camera is used for video encoding (to skip it)
|
||||
video_camera_key = self._get_video_camera_key()
|
||||
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
# Filter out the video camera key if video encoding is enabled
|
||||
if video_camera_key is not None and video_camera_key in present_img_keys:
|
||||
present_img_keys = [k for k in present_img_keys if k != video_camera_key]
|
||||
|
||||
if len(present_img_keys) == 0 and video_camera_key is None:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. "
|
||||
f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
|
||||
@@ -1161,6 +1415,11 @@ class PI05Policy(PreTrainedPolicy):
|
||||
if img.dtype != torch.float32:
|
||||
img = img.to(torch.float32)
|
||||
|
||||
# Handle temporal dimension: if [B, T, C, H, W], extract current frame (last one)
|
||||
if img.ndim == 5:
|
||||
# Extract the last frame (current observation at index -1)
|
||||
img = img[:, -1] # [B, T, C, H, W] -> [B, C, H, W]
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
|
||||
|
||||
@@ -1187,13 +1446,99 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Create image features not present in the batch as fully 0 padded images
|
||||
for _num_empty_cameras in range(len(missing_img_keys)):
|
||||
img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP
|
||||
mask = torch.zeros_like(mask) # Mask is zero for empty cameras
|
||||
if len(images) > 0:
|
||||
img = torch.ones_like(images[-1]) * -1 # Padded with -1 for SigLIP
|
||||
mask = torch.zeros_like(img_masks[-1]) # Mask is zero for empty cameras
|
||||
else:
|
||||
# No images processed yet, create placeholder
|
||||
bsize = next(iter(batch.values())).shape[0]
|
||||
img = torch.ones(
|
||||
bsize, 3, *self.config.image_resolution, dtype=torch.float32, device=device
|
||||
) * -1
|
||||
mask = torch.zeros(bsize, dtype=torch.bool, device=device)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def _preprocess_video(self, batch: dict[str, Tensor]) -> Tensor | None:
|
||||
"""Preprocess video frames for the video encoder.
|
||||
|
||||
When image_observation_delta_indices is set (for video encoding), the batch will contain
|
||||
images with shape [B, T, C, H, W] where T is the number of frames.
|
||||
This method extracts and preprocesses these frames for VideoPrism.
|
||||
|
||||
Handles frame padding at episode start when fewer than video_num_frames are available:
|
||||
- "repeat": Repeat the first available frame to fill missing frames
|
||||
- "zero": Use zero-padded frames for missing frames
|
||||
|
||||
Args:
|
||||
batch: Training batch potentially containing multi-frame observations.
|
||||
|
||||
Returns:
|
||||
Video frames tensor of shape [B, T, C, H, W] normalized to [0, 1],
|
||||
or None if video encoding is not enabled.
|
||||
"""
|
||||
if not self.config.use_video_encoder:
|
||||
return None
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Get the video camera key
|
||||
video_camera_key = self._get_video_camera_key()
|
||||
if video_camera_key is None or video_camera_key not in batch:
|
||||
return None
|
||||
|
||||
img = batch[video_camera_key]
|
||||
|
||||
# Check if we have temporal dimension (video frames)
|
||||
if img.ndim == 4:
|
||||
# Single frame [B, C, H, W] - expand to video by repeating
|
||||
B, C, H, W = img.shape
|
||||
if self.config.video_padding_mode == "repeat":
|
||||
video_frames = img.unsqueeze(1).expand(B, self.config.video_num_frames, C, H, W)
|
||||
else: # zero padding
|
||||
video_frames = torch.zeros(
|
||||
B, self.config.video_num_frames, C, H, W, dtype=img.dtype, device=img.device
|
||||
)
|
||||
video_frames[:, -1] = img # Put current frame at the end
|
||||
elif img.ndim == 5:
|
||||
# Multiple frames [B, T, C, H, W]
|
||||
video_frames = img
|
||||
B, T, C, H, W = video_frames.shape
|
||||
|
||||
# Handle case where we have fewer frames than expected (episode start)
|
||||
if T < self.config.video_num_frames:
|
||||
num_missing = self.config.video_num_frames - T
|
||||
|
||||
if self.config.video_padding_mode == "repeat":
|
||||
# Repeat the first frame to fill missing frames at the beginning
|
||||
first_frame = video_frames[:, 0:1] # [B, 1, C, H, W]
|
||||
padding = first_frame.expand(B, num_missing, C, H, W)
|
||||
video_frames = torch.cat([padding, video_frames], dim=1)
|
||||
else: # zero padding
|
||||
# Zero-pad at the beginning
|
||||
padding = torch.zeros(
|
||||
B, num_missing, C, H, W, dtype=video_frames.dtype, device=video_frames.device
|
||||
)
|
||||
video_frames = torch.cat([padding, video_frames], dim=1)
|
||||
else:
|
||||
logging.warning(f"Unexpected image shape for video camera: {img.shape}")
|
||||
return None
|
||||
|
||||
# Ensure tensor is on the same device
|
||||
if video_frames.device != device:
|
||||
video_frames = video_frames.to(device)
|
||||
|
||||
# Ensure float32 dtype
|
||||
if video_frames.dtype != torch.float32:
|
||||
video_frames = video_frames.to(torch.float32)
|
||||
|
||||
# Video frames should be in [0, 1] range for VideoPrism
|
||||
# LeRobot images are already in [0, 1] range
|
||||
|
||||
return video_frames
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
@@ -1225,8 +1570,13 @@ class PI05Policy(PreTrainedPolicy):
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Preprocess video frames if video encoding is enabled
|
||||
video_frames = self._preprocess_video(batch)
|
||||
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, tokens, masks, video_frames=video_frames, **kwargs
|
||||
)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1249,8 +1599,11 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Preprocess video frames if video encoding is enabled
|
||||
video_frames = self._preprocess_video(batch)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, video_frames=video_frames)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -22,7 +22,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -97,15 +97,15 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_pre_post_processors(
|
||||
config: PI05Config,
|
||||
def make_pi05_video_pre_post_processors(
|
||||
config: PI05VideoConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
Constructs pre-processor and post-processor pipelines for the PI05Video policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script for PI05 with video encoder (VideoPrism).
|
||||
|
||||
This script creates a dummy example to test the model with video encoding enabled.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def create_dummy_batch(
|
||||
batch_size: int = 2,
|
||||
num_frames: int = 16,
|
||||
image_size: int = 224,
|
||||
num_cameras: int = 2,
|
||||
state_dim: int = 14,
|
||||
action_dim: int = 14,
|
||||
chunk_size: int = 50,
|
||||
seq_len: int = 10,
|
||||
device: str = "cuda",
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Create a dummy batch for testing."""
|
||||
batch = {}
|
||||
|
||||
# Create image observations with temporal dimension [B, T, C, H, W]
|
||||
for i in range(num_cameras):
|
||||
key = f"{OBS_IMAGES}.camera_{i}"
|
||||
# Images in [0, 1] range
|
||||
batch[key] = torch.rand(batch_size, num_frames, 3, image_size, image_size, device=device)
|
||||
|
||||
# Create state observation [B, state_dim]
|
||||
batch[OBS_STATE] = torch.rand(batch_size, state_dim, device=device)
|
||||
|
||||
# Create language tokens and attention mask [B, seq_len]
|
||||
batch["observation.language.tokens"] = torch.randint(0, 1000, (batch_size, seq_len), device=device)
|
||||
batch["observation.language.attention_mask"] = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||
|
||||
# Create action targets [B, chunk_size, action_dim]
|
||||
batch[ACTION] = torch.rand(batch_size, chunk_size, action_dim, device=device)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def test_video_encoder():
|
||||
"""Test the PI05 model with video encoding enabled."""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Configuration
|
||||
batch_size = 2
|
||||
num_frames = 16
|
||||
image_size = 224
|
||||
num_cameras = 2
|
||||
state_dim = 14
|
||||
action_dim = 14
|
||||
chunk_size = 50
|
||||
|
||||
# Create config with video encoder enabled
|
||||
print("Creating PI05VideoConfig with video encoder...")
|
||||
config = PI05VideoConfig(
|
||||
use_video_encoder=True,
|
||||
video_num_frames=num_frames,
|
||||
videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
|
||||
videoprism_image_size=288,
|
||||
freeze_video_encoder=True,
|
||||
video_padding_mode="repeat",
|
||||
video_encoder_camera_key=f"{OBS_IMAGES}.camera_0", # Use first camera for video
|
||||
chunk_size=chunk_size,
|
||||
max_action_dim=32,
|
||||
max_state_dim=32,
|
||||
dtype="float32", # Use float32 for testing
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set up input/output features
|
||||
for i in range(num_cameras):
|
||||
key = f"{OBS_IMAGES}.camera_{i}"
|
||||
config.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, image_size, image_size),
|
||||
)
|
||||
|
||||
config.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(state_dim,),
|
||||
)
|
||||
|
||||
config.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(action_dim,),
|
||||
)
|
||||
|
||||
print(f"use_video_encoder: {config.use_video_encoder}")
|
||||
print(f"video_num_frames: {config.video_num_frames}")
|
||||
print(f"video_padding_mode: {config.video_padding_mode}")
|
||||
print(f"video_encoder_camera_key: {config.video_encoder_camera_key}")
|
||||
print(f"image_observation_delta_indices: {config.image_observation_delta_indices}")
|
||||
|
||||
# Create model
|
||||
model = PI05VideoPolicy(config)
|
||||
model.to(device)
|
||||
|
||||
# Create dummy batch
|
||||
batch = create_dummy_batch(
|
||||
batch_size=batch_size,
|
||||
num_frames=num_frames,
|
||||
image_size=image_size,
|
||||
num_cameras=num_cameras,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
chunk_size=chunk_size,
|
||||
device=device,
|
||||
)
|
||||
|
||||
print(f"Batch keys: {list(batch.keys())}" )
|
||||
for key, value in batch.items():
|
||||
print(f"{key}: {value.shape}")
|
||||
|
||||
# Test forward pass
|
||||
model.train()
|
||||
try:
|
||||
loss, loss_dict = model.forward(batch)
|
||||
print(f"Forward pass successful!")
|
||||
print(f"Loss: {loss.item():.4f}")
|
||||
print(f"Loss dict: {loss_dict}")
|
||||
except Exception as e:
|
||||
print(f"Forward pass failed: {e}")
|
||||
raise
|
||||
|
||||
# Test inference
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
try:
|
||||
actions = model.predict_action_chunk(batch)
|
||||
print(f"Test pass, inference pass!")
|
||||
print(f"Predicted actions shape: {actions.shape}")
|
||||
except Exception as e:
|
||||
print(f"Inference failed: {e}")
|
||||
raise
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
|
||||
def test_frame_padding():
|
||||
"""Test frame padding at episode start."""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# Create config
|
||||
config = PI05VideoConfig(
|
||||
use_video_encoder=True,
|
||||
video_num_frames=16,
|
||||
videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
|
||||
freeze_video_encoder=True,
|
||||
video_padding_mode="repeat",
|
||||
chunk_size=50,
|
||||
dtype="float32",
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set up minimal features
|
||||
config.input_features[f"{OBS_IMAGES}.camera_0"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
)
|
||||
config.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(14,),
|
||||
)
|
||||
|
||||
# Create model
|
||||
model = PI05VideoPolicy(config)
|
||||
model.to(device)
|
||||
|
||||
# Test with fewer frames than expected (simulating episode start)
|
||||
batch = {
|
||||
f"{OBS_IMAGES}.camera_0": torch.rand(2, 5, 3, 224, 224, device=device),
|
||||
"observation.language.tokens": torch.randint(0, 1000, (2, 10), device=device),
|
||||
"observation.language.attention_mask": torch.ones(2, 10, dtype=torch.bool, device=device),
|
||||
ACTION: torch.rand(2, 50, 14, device=device),
|
||||
}
|
||||
|
||||
video_frames = model._preprocess_video(batch)
|
||||
if video_frames is not None:
|
||||
print(f"Input frames: 5")
|
||||
print(f"Output video_frames shape: {video_frames.shape}")
|
||||
print(f"Expected: [2, 16, 3, 224, 224]")
|
||||
assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}"
|
||||
print("Frame padding test PASSED!")
|
||||
else:
|
||||
print("video_frames is None (unexpected)")
|
||||
|
||||
# Test with single frame
|
||||
batch[f"{OBS_IMAGES}.camera_0"] = torch.rand(2, 3, 224, 224, device=device) # [B, C, H, W]
|
||||
|
||||
video_frames = model._preprocess_video(batch)
|
||||
if video_frames is not None:
|
||||
print(f"Input: single frame [B, C, H, W]")
|
||||
print(f"Output video_frames shape: {video_frames.shape}")
|
||||
print(f"Expected: [2, 16, 3, 224, 224]")
|
||||
assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}"
|
||||
print("Single frame expansion test PASSED!")
|
||||
else:
|
||||
print("video_frames is None (unexpected)")
|
||||
|
||||
print("All tests passed!")
|
||||
if __name__ == "__main__":
|
||||
# Run tests
|
||||
test_frame_padding()
|
||||
test_video_encoder()
|
||||
Reference in New Issue
Block a user