diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 7f326b70b..c04376a70 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -105,6 +105,16 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation raise NotImplementedError + @property + def image_observation_delta_indices(self) -> list | None: # type: ignore[type-arg] + """Return indices for delta image observations only. + + Unlike observation_delta_indices which applies to ALL observations, + this only applies to image observations (keys starting with observation.images). + Default returns None. Override in subclass to enable. + """ + return None + @property @abc.abstractmethod def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 31e939809..0be585720 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import ( ) from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms -from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, REWARD IMAGENET_STATS = { "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) @@ -59,7 +59,12 @@ def resolve_delta_timestamps( delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] if key == ACTION and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] - if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: + + # Check for image-specific delta indices first (e.g., for video encoding) + if key.startswith(OBS_IMAGES) and cfg.image_observation_delta_indices is not None: + delta_timestamps[key] = [i / ds_meta.fps for i in cfg.image_observation_delta_indices] + # Fall back to generic observation delta indices for all observations + elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] if len(delta_timestamps) == 0: diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index a593e5bcb..18e8e2865 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -35,6 +35,7 @@ from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sarm.configuration_sarm import SARMConfig @@ -67,7 +68,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". + "vqbet", "pi0", "pi05", "pi05_video", "sac", "reward_classifier", "smolvla", "wall_x". Returns: The policy class corresponding to the given name. @@ -103,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.pi05.modeling_pi05 import PI05Policy return PI05Policy + elif name == "pi05_video": + from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy + + return PI05VideoPolicy elif name == "sac": from lerobot.policies.sac.modeling_sac import SACPolicy @@ -147,7 +152,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", - "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla", + "diffusion", "act", "vqbet", "pi0", "pi05", "pi05_video", "sac", "smolvla", "reward_classifier", "wall_x". **kwargs: Keyword arguments to be passed to the configuration class constructor. @@ -169,6 +174,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi05": return PI05Config(**kwargs) + elif policy_type == "pi05_video": + return PI05VideoConfig(**kwargs) elif policy_type == "sac": return SACConfig(**kwargs) elif policy_type == "smolvla": @@ -333,6 +340,14 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, PI05VideoConfig): + from lerobot.policies.videovla.processor_pi05 import make_pi05_video_pre_post_processors + + processors = make_pi05_video_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + elif isinstance(policy_cfg, SACConfig): from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors diff --git a/src/lerobot/policies/videovla/__init__.py b/src/lerobot/policies/videovla/__init__.py index a8580913c..87be157e5 100644 --- a/src/lerobot/policies/videovla/__init__.py +++ b/src/lerobot/policies/videovla/__init__.py @@ -17,15 +17,15 @@ # Lazy imports to avoid conflicts with lerobot.policies.pi05.PI05Config # when only importing subpackages like videoprism def __getattr__(name): - if name == "PI05Config": - from .configuration_pi05 import PI05Config - return PI05Config - elif name == "PI05Policy": - from .modeling_pi05 import PI05Policy - return PI05Policy - elif name == "make_pi05_pre_post_processors": - from .processor_pi05 import make_pi05_pre_post_processors - return make_pi05_pre_post_processors + if name == "PI05VideoConfig": + from .configuration_pi05 import PI05VideoConfig + return PI05VideoConfig + elif name == "PI05VideoPolicy": + from .modeling_pi05 import PI05VideoPolicy + return PI05VideoPolicy + elif name == "make_pi05_video_pre_post_processors": + from .processor_pi05 import make_pi05_video_pre_post_processors + return make_pi05_video_pre_post_processors raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"] +__all__ = ["PI05VideoConfig", "PI05VideoPolicy", "make_pi05_video_pre_post_processors"] diff --git a/src/lerobot/policies/videovla/configuration_pi05.py b/src/lerobot/policies/videovla/configuration_pi05.py index b96e6d196..541db40e6 100644 --- a/src/lerobot/policies/videovla/configuration_pi05.py +++ b/src/lerobot/policies/videovla/configuration_pi05.py @@ -26,9 +26,9 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 -@PreTrainedConfig.register_subclass("pi05") +@PreTrainedConfig.register_subclass("pi05_video") @dataclass -class PI05Config(PreTrainedConfig): +class PI05VideoConfig(PreTrainedConfig): paligemma_variant: str = "gemma_2b" action_expert_variant: str = "gemma_300m" dtype: str = "float32" # Options: "bfloat16", "float32" @@ -37,6 +37,19 @@ class PI05Config(PreTrainedConfig): chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" n_action_steps: int = 50 # Number of action steps to execute + # Video encoder settings (VideoPrism) + use_video_encoder: bool = False # Enable video encoding with VideoPrism + video_num_frames: int = 16 # Number of frames for video encoding (VideoPrism default is 16) + videoprism_model_name: str = "MHRDYN7/videoprism-base-f16r288" # VideoPrism model to use + videoprism_image_size: int = 288 # VideoPrism expects 288x288 images + freeze_video_encoder: bool = True # Whether to freeze the video encoder weights + video_padding_mode: str = "repeat" # How to pad frames at episode start: "repeat" or "zero" + # Which camera to use for video encoding (None = first camera, or specify key like "observation.images.top") + video_encoder_camera_key: str | None = None + # Perceiver Resampler settings to reduce video tokens (4096 -> video_num_latents) + video_num_latents: int = 128 # Number of latent tokens for video resampler + video_resampler_num_heads: int = 8 # Number of attention heads in resampler + # Shorter state and action vectors will be padded to these dimensions max_state_dim: int = 32 max_action_dim: int = 32 @@ -115,6 +128,17 @@ class PI05Config(PreTrainedConfig): if self.dtype not in ["bfloat16", "float32"]: raise ValueError(f"Invalid dtype: {self.dtype}") + # Validate video encoder settings + if self.use_video_encoder: + if self.video_num_frames < 1: + raise ValueError(f"video_num_frames must be >= 1, got {self.video_num_frames}") + if self.videoprism_image_size < 1: + raise ValueError(f"videoprism_image_size must be >= 1, got {self.videoprism_image_size}") + if self.video_padding_mode not in ["repeat", "zero"]: + raise ValueError( + f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}" + ) + def validate_features(self) -> None: """Validate and set up input/output features.""" for i in range(self.empty_cameras): @@ -157,7 +181,26 @@ class PI05Config(PreTrainedConfig): ) @property - def observation_delta_indices(self) -> None: + def observation_delta_indices(self) -> list[int] | None: + """Return indices for delta observations. + + For PI05, we don't use generic observation_delta_indices because it would + apply to both images AND state. Instead, we use image_observation_delta_indices + which only applies to image observations. + """ + return None + + @property + def image_observation_delta_indices(self) -> list[int] | None: + """Return indices for delta image observations only. + + When video encoding is enabled, returns indices for the past frames + needed by VideoPrism (e.g., -15, -14, ..., -1, 0 for 16 frames). + This only applies to image observations, not state. + """ + if self.use_video_encoder: + # Return indices for past frames: [-15, -14, ..., -1, 0] for 16 frames + return list(range(-(self.video_num_frames - 1), 1)) return None @property diff --git a/src/lerobot/policies/videovla/modeling_pi05.py b/src/lerobot/policies/videovla/modeling_pi05.py index 08bfbb98c..d3d992411 100644 --- a/src/lerobot/policies/videovla/modeling_pi05.py +++ b/src/lerobot/policies/videovla/modeling_pi05.py @@ -40,8 +40,17 @@ else: GemmaForCausalLM = None PaliGemmaForConditionalGeneration = None +# VideoPrism imports for video encoding +try: + from lerobot.policies.videovla.videoprism import VideoPrismVideoProcessor, VideoPrismVisionModel + _videoprism_available = True +except ImportError: + _videoprism_available = False + VideoPrismVideoProcessor = None + VideoPrismVisionModel = None + from lerobot.configs.policies import PreTrainedConfig -from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config +from lerobot.policies.videovla.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05VideoConfig from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( @@ -289,6 +298,60 @@ def compute_layer_complete( return outputs_embeds +class PerceiverResampler(nn.Module): + """Perceiver Resampler to reduce video tokens via cross-attention. + + This module uses learnable query tokens that cross-attend to the video tokens, + effectively reducing the sequence length while preserving important information. + + Args: + dim: Hidden dimension of the input/output features + num_latents: Number of learnable query tokens (output sequence length) + num_heads: Number of attention heads + """ + + def __init__(self, dim: int = 768, num_latents: int = 128, num_heads: int = 8): + super().__init__() + self.num_latents = num_latents + self.dim = dim + + # Learnable query tokens + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + + # Cross-attention layer + self.attn = nn.MultiheadAttention( + embed_dim=dim, + num_heads=num_heads, + batch_first=True, + ) + + # Layer norms for queries and key-values + self.ln_q = nn.LayerNorm(dim) + self.ln_kv = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Input video tokens of shape (B, N, D) where N can be large (e.g., 4096) + + Returns: + Resampled tokens of shape (B, num_latents, D) + """ + B, N, D = x.shape + + # Expand learnable latents to batch size + latents = self.latents.unsqueeze(0).expand(B, -1, -1) # (B, num_latents, D) + + # Apply layer norms + q = self.ln_q(latents) + kv = self.ln_kv(x) + + # Cross-attention: queries attend to video tokens + out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D) + + return out + + class GemmaConfig: # see openpi `gemma.py: Config` """Configuration for Gemma model variants.""" @@ -534,7 +597,7 @@ class PaliGemmaWithExpertModel( class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI05 PyTorch model.""" - def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None): + def __init__(self, config: PI05VideoConfig, rtc_processor: RTCProcessor | None = None): super().__init__() self.config = config self.rtc_processor = rtc_processor @@ -566,6 +629,47 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False + # Initialize VideoPrism video encoder if enabled + self.video_encoder = None + self.video_processor = None + self.video_proj = None + self.video_resampler = None + if config.use_video_encoder: + if not _videoprism_available: + raise ImportError( + "VideoPrism is not available. Please install the required dependencies." + ) + logging.info(f"Initializing VideoPrism video encoder: {config.videoprism_model_name}") + self.video_processor = VideoPrismVideoProcessor.from_pretrained(config.videoprism_model_name) + self.video_encoder = VideoPrismVisionModel.from_pretrained( + config.videoprism_model_name, + torch_dtype=torch.bfloat16 if config.dtype == "bfloat16" else torch.float32, + attn_implementation="sdpa", + ) + # Get the hidden size from VideoPrism config (default is 768 for base model) + video_hidden_size = self.video_encoder.config.hidden_size + + # Initialize Perceiver Resampler to reduce video tokens (e.g., 4096 -> 128) + self.video_resampler = PerceiverResampler( + dim=video_hidden_size, + num_latents=config.video_num_latents, + num_heads=config.video_resampler_num_heads, + ) + logging.info( + f"Initialized video resampler: {video_hidden_size}D, " + f"{config.video_num_latents} latents, {config.video_resampler_num_heads} heads" + ) + + # Project video embeddings to PaliGemma's hidden size + self.video_proj = nn.Linear(video_hidden_size, paligemma_config.width) + + # Freeze video encoder if requested + if config.freeze_video_encoder: + self.video_encoder.eval() + for param in self.video_encoder.parameters(): + param.requires_grad = False + logging.info("Video encoder weights are frozen") + # Compile model if requested if config.compile_model: torch.set_float32_matmul_precision("high") @@ -632,13 +736,33 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return time.to(dtype=torch.float32, device=device) def embed_prefix( - self, images, img_masks, tokens, masks + self, images, img_masks, tokens, masks, video_emb: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Embed images with SigLIP and language tokens with embedding layer.""" + """Embed images with SigLIP, optional video with VideoPrism, and language tokens with embedding layer. + + Args: + images: List of image tensors [B, C, H, W] + img_masks: List of image masks [B] + tokens: Language tokens [B, seq_len] + masks: Language attention masks [B, seq_len] + video_emb: Optional video embeddings from VideoPrism [B, num_video_tokens, hidden_dim] + + Returns: + Tuple of (embeddings, pad_masks, att_masks) + """ embs = [] pad_masks = [] att_masks = [] + # Process video embeddings first (if available) + if video_emb is not None: + bsize, num_video_tokens, _ = video_emb.shape + embs.append(video_emb) + # Video tokens are always valid + video_mask = torch.ones(bsize, num_video_tokens, dtype=torch.bool, device=video_emb.device) + pad_masks.append(video_mask) + att_masks += [0] * num_video_tokens + # Process images for img, img_mask in zip(images, img_masks, strict=True): @@ -674,6 +798,69 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return embs, pad_masks, att_masks + def embed_video(self, video_frames: torch.Tensor) -> torch.Tensor: + """Embed video frames using VideoPrism encoder. + + Args: + video_frames: Tensor of shape [B, T, C, H, W] where T is the number of frames. + Expected to be normalized to [0, 1]. + + Returns: + Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to + PaliGemma's hidden dimension. + """ + if self.video_encoder is None: + raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.") + + device = video_frames.device + dtype = video_frames.dtype + + # Move video encoder to the same device if needed + if next(self.video_encoder.parameters()).device != device: + self.video_encoder = self.video_encoder.to(device) + + # VideoPrism expects pixel values in [0, 1] range and shape [B, T, C, H, W] + # Resize frames to VideoPrism expected size if needed + B, T, C, H, W = video_frames.shape + target_size = self.config.videoprism_image_size + + if H != target_size or W != target_size: + # Resize each frame + video_frames = video_frames.view(B * T, C, H, W) + video_frames = F.interpolate( + video_frames, + size=(target_size, target_size), + mode="bilinear", + align_corners=False, + ) + video_frames = video_frames.view(B, T, C, target_size, target_size) + + # Convert to the expected dtype for the video encoder + video_encoder_dtype = next(self.video_encoder.parameters()).dtype + video_frames = video_frames.to(dtype=video_encoder_dtype) + + # Run through VideoPrism + with torch.set_grad_enabled(not self.config.freeze_video_encoder): + if self.config.freeze_video_encoder: + self.video_encoder.eval() + + video_outputs = self.video_encoder(pixel_values_videos=video_frames) + # Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768]) + video_embeddings = video_outputs.last_hidden_state + + # Convert to working dtype + video_embeddings = video_embeddings.to(dtype=dtype) + + # Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128) + # This uses cross-attention from learnable queries to the video tokens + video_embeddings = self.video_resampler(video_embeddings) + # Shape: [B, num_latents, hidden_size] (e.g., [B, 128, 768]) + + # Project to PaliGemma's hidden dimension + video_embeddings = self.video_proj(video_embeddings) + + return video_embeddings + def embed_suffix(self, noisy_actions, timestep): """Embed noisy_actions, timestep to prepare for Expert Gemma processing.""" embs = [] @@ -721,8 +908,21 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return embs, pad_masks, att_masks, adarms_cond - def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor: - """Do a full training forward pass and compute the loss.""" + def forward( + self, images, img_masks, tokens, masks, actions, noise=None, time=None, video_frames=None + ) -> Tensor: + """Do a full training forward pass and compute the loss. + + Args: + images: List of image tensors [B, C, H, W] + img_masks: List of image masks [B] + tokens: Language tokens [B, seq_len] + masks: Language attention masks [B, seq_len] + actions: Ground truth actions [B, chunk_size, action_dim] + noise: Optional noise tensor for flow matching + time: Optional time tensor for flow matching + video_frames: Optional video frames [B, T, C, H, W] for video encoding + """ if noise is None: noise = self.sample_noise(actions.shape, actions.device) @@ -733,7 +933,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + # Embed video if provided and video encoder is available + video_emb = None + if video_frames is not None and self.video_encoder is not None: + video_emb = self.embed_video(video_frames) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, tokens, masks, video_emb=video_emb + ) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) if ( @@ -785,9 +992,20 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` masks, noise=None, num_steps=None, + video_frames=None, **kwargs: Unpack[ActionSelectKwargs], ) -> Tensor: - """Do a full inference forward and compute the action.""" + """Do a full inference forward and compute the action. + + Args: + images: List of image tensors [B, C, H, W] + img_masks: List of image masks [B] + tokens: Language tokens [B, seq_len] + masks: Language attention masks [B, seq_len] + noise: Optional noise tensor + num_steps: Number of denoising steps + video_frames: Optional video frames [B, T, C, H, W] for video encoding + """ if num_steps is None: num_steps = self.config.num_inference_steps @@ -803,7 +1021,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) # Use config max_action_dim for internal processing noise = self.sample_noise(actions_shape, device) - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + # Embed video if provided and video encoder is available + video_emb = None + if video_frames is not None and self.video_encoder is not None: + video_emb = self.embed_video(video_frames) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, tokens, masks, video_emb=video_emb + ) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 @@ -895,15 +1120,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return self.action_out_proj(suffix_out) -class PI05Policy(PreTrainedPolicy): - """PI05 Policy for LeRobot.""" +class PI05VideoPolicy(PreTrainedPolicy): + """PI05 Video Policy for LeRobot with optional video encoding support.""" - config_class = PI05Config - name = "pi05" + config_class = PI05VideoConfig + name = "pi05_video" def __init__( self, - config: PI05Config, + config: PI05VideoConfig, **kwargs, ): """ @@ -1128,11 +1353,33 @@ class PI05Policy(PreTrainedPolicy): def _rtc_enabled(self) -> bool: return self.config.rtc_config is not None and self.config.rtc_config.enabled + def _get_video_camera_key(self) -> str | None: + """Get the camera key to use for video encoding. + + Returns the configured video_encoder_camera_key if set, + otherwise returns the first image feature key. + """ + if not self.config.use_video_encoder: + return None + + if self.config.video_encoder_camera_key is not None: + return self.config.video_encoder_camera_key + + # Default to first image feature (image_features is a dict) + if self.config.image_features: + return next(iter(self.config.image_features.keys())) + + return None + def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: """Preprocess images for the model. Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. + + When video encoding is enabled: + - The video camera is skipped (processed separately by video encoder) + - Other cameras with temporal dimension have only the current frame extracted """ images = [] img_masks = [] @@ -1140,10 +1387,17 @@ class PI05Policy(PreTrainedPolicy): # Get device from model parameters device = next(self.parameters()).device + # Determine which camera is used for video encoding (to skip it) + video_camera_key = self._get_video_camera_key() + present_img_keys = [key for key in self.config.image_features if key in batch] missing_img_keys = [key for key in self.config.image_features if key not in batch] - if len(present_img_keys) == 0: + # Filter out the video camera key if video encoding is enabled + if video_camera_key is not None and video_camera_key in present_img_keys: + present_img_keys = [k for k in present_img_keys if k != video_camera_key] + + if len(present_img_keys) == 0 and video_camera_key is None: raise ValueError( f"All image features are missing from the batch. At least one expected. " f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" @@ -1161,6 +1415,11 @@ class PI05Policy(PreTrainedPolicy): if img.dtype != torch.float32: img = img.to(torch.float32) + # Handle temporal dimension: if [B, T, C, H, W], extract current frame (last one) + if img.ndim == 5: + # Extract the last frame (current observation at index -1) + img = img[:, -1] # [B, T, C, H, W] -> [B, C, H, W] + # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 @@ -1187,13 +1446,99 @@ class PI05Policy(PreTrainedPolicy): # Create image features not present in the batch as fully 0 padded images for _num_empty_cameras in range(len(missing_img_keys)): - img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP - mask = torch.zeros_like(mask) # Mask is zero for empty cameras + if len(images) > 0: + img = torch.ones_like(images[-1]) * -1 # Padded with -1 for SigLIP + mask = torch.zeros_like(img_masks[-1]) # Mask is zero for empty cameras + else: + # No images processed yet, create placeholder + bsize = next(iter(batch.values())).shape[0] + img = torch.ones( + bsize, 3, *self.config.image_resolution, dtype=torch.float32, device=device + ) * -1 + mask = torch.zeros(bsize, dtype=torch.bool, device=device) images.append(img) img_masks.append(mask) return images, img_masks + def _preprocess_video(self, batch: dict[str, Tensor]) -> Tensor | None: + """Preprocess video frames for the video encoder. + + When image_observation_delta_indices is set (for video encoding), the batch will contain + images with shape [B, T, C, H, W] where T is the number of frames. + This method extracts and preprocesses these frames for VideoPrism. + + Handles frame padding at episode start when fewer than video_num_frames are available: + - "repeat": Repeat the first available frame to fill missing frames + - "zero": Use zero-padded frames for missing frames + + Args: + batch: Training batch potentially containing multi-frame observations. + + Returns: + Video frames tensor of shape [B, T, C, H, W] normalized to [0, 1], + or None if video encoding is not enabled. + """ + if not self.config.use_video_encoder: + return None + + device = next(self.parameters()).device + + # Get the video camera key + video_camera_key = self._get_video_camera_key() + if video_camera_key is None or video_camera_key not in batch: + return None + + img = batch[video_camera_key] + + # Check if we have temporal dimension (video frames) + if img.ndim == 4: + # Single frame [B, C, H, W] - expand to video by repeating + B, C, H, W = img.shape + if self.config.video_padding_mode == "repeat": + video_frames = img.unsqueeze(1).expand(B, self.config.video_num_frames, C, H, W) + else: # zero padding + video_frames = torch.zeros( + B, self.config.video_num_frames, C, H, W, dtype=img.dtype, device=img.device + ) + video_frames[:, -1] = img # Put current frame at the end + elif img.ndim == 5: + # Multiple frames [B, T, C, H, W] + video_frames = img + B, T, C, H, W = video_frames.shape + + # Handle case where we have fewer frames than expected (episode start) + if T < self.config.video_num_frames: + num_missing = self.config.video_num_frames - T + + if self.config.video_padding_mode == "repeat": + # Repeat the first frame to fill missing frames at the beginning + first_frame = video_frames[:, 0:1] # [B, 1, C, H, W] + padding = first_frame.expand(B, num_missing, C, H, W) + video_frames = torch.cat([padding, video_frames], dim=1) + else: # zero padding + # Zero-pad at the beginning + padding = torch.zeros( + B, num_missing, C, H, W, dtype=video_frames.dtype, device=video_frames.device + ) + video_frames = torch.cat([padding, video_frames], dim=1) + else: + logging.warning(f"Unexpected image shape for video camera: {img.shape}") + return None + + # Ensure tensor is on the same device + if video_frames.device != device: + video_frames = video_frames.to(device) + + # Ensure float32 dtype + if video_frames.dtype != torch.float32: + video_frames = video_frames.to(torch.float32) + + # Video frames should be in [0, 1] range for VideoPrism + # LeRobot images are already in [0, 1] range + + return video_frames + def prepare_action(self, batch): """Pad action""" actions = pad_vector(batch[ACTION], self.config.max_action_dim) @@ -1225,8 +1570,13 @@ class PI05Policy(PreTrainedPolicy): images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + # Preprocess video frames if video encoding is enabled + video_frames = self._preprocess_video(batch) + # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05) - actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs) + actions = self.model.sample_actions( + images, img_masks, tokens, masks, video_frames=video_frames, **kwargs + ) # Unpad actions to actual action dimension original_action_dim = self.config.output_features[ACTION].shape[0] @@ -1249,8 +1599,11 @@ class PI05Policy(PreTrainedPolicy): actions = self.prepare_action(batch) + # Preprocess video frames if video encoding is enabled + video_frames = self._preprocess_video(batch) + # Compute loss (no separate state needed for PI05) - losses = self.model.forward(images, img_masks, tokens, masks, actions) + losses = self.model.forward(images, img_masks, tokens, masks, actions, video_frames=video_frames) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0] diff --git a/src/lerobot/policies/videovla/processor_pi05.py b/src/lerobot/policies/videovla/processor_pi05.py index e29bc4c23..0ef7ad6a2 100644 --- a/src/lerobot/policies/videovla/processor_pi05.py +++ b/src/lerobot/policies/videovla/processor_pi05.py @@ -22,7 +22,7 @@ import numpy as np import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig from lerobot.policies.pi05.modeling_pi05 import pad_vector from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -97,15 +97,15 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): return features -def make_pi05_pre_post_processors( - config: PI05Config, +def make_pi05_video_pre_post_processors( + config: PI05VideoConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, ) -> tuple[ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], ]: """ - Constructs pre-processor and post-processor pipelines for the PI0 policy. + Constructs pre-processor and post-processor pipelines for the PI05Video policy. The pre-processing pipeline prepares input data for the model by: 1. Renaming features to match pretrained configurations. diff --git a/src/lerobot/policies/videovla/test_video_encoder.py b/src/lerobot/policies/videovla/test_video_encoder.py new file mode 100644 index 000000000..16f9a123e --- /dev/null +++ b/src/lerobot/policies/videovla/test_video_encoder.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python +""" +Test script for PI05 with video encoder (VideoPrism). + +This script creates a dummy example to test the model with video encoding enabled. +""" + +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig +from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + + +def create_dummy_batch( + batch_size: int = 2, + num_frames: int = 16, + image_size: int = 224, + num_cameras: int = 2, + state_dim: int = 14, + action_dim: int = 14, + chunk_size: int = 50, + seq_len: int = 10, + device: str = "cuda", +) -> dict[str, torch.Tensor]: + """Create a dummy batch for testing.""" + batch = {} + + # Create image observations with temporal dimension [B, T, C, H, W] + for i in range(num_cameras): + key = f"{OBS_IMAGES}.camera_{i}" + # Images in [0, 1] range + batch[key] = torch.rand(batch_size, num_frames, 3, image_size, image_size, device=device) + + # Create state observation [B, state_dim] + batch[OBS_STATE] = torch.rand(batch_size, state_dim, device=device) + + # Create language tokens and attention mask [B, seq_len] + batch["observation.language.tokens"] = torch.randint(0, 1000, (batch_size, seq_len), device=device) + batch["observation.language.attention_mask"] = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device) + + # Create action targets [B, chunk_size, action_dim] + batch[ACTION] = torch.rand(batch_size, chunk_size, action_dim, device=device) + + return batch + + +def test_video_encoder(): + """Test the PI05 model with video encoding enabled.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + # Configuration + batch_size = 2 + num_frames = 16 + image_size = 224 + num_cameras = 2 + state_dim = 14 + action_dim = 14 + chunk_size = 50 + + # Create config with video encoder enabled + print("Creating PI05VideoConfig with video encoder...") + config = PI05VideoConfig( + use_video_encoder=True, + video_num_frames=num_frames, + videoprism_model_name="MHRDYN7/videoprism-base-f16r288", + videoprism_image_size=288, + freeze_video_encoder=True, + video_padding_mode="repeat", + video_encoder_camera_key=f"{OBS_IMAGES}.camera_0", # Use first camera for video + chunk_size=chunk_size, + max_action_dim=32, + max_state_dim=32, + dtype="float32", # Use float32 for testing + device=device, + ) + + # Set up input/output features + for i in range(num_cameras): + key = f"{OBS_IMAGES}.camera_{i}" + config.input_features[key] = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, image_size, image_size), + ) + + config.input_features[OBS_STATE] = PolicyFeature( + type=FeatureType.STATE, + shape=(state_dim,), + ) + + config.output_features[ACTION] = PolicyFeature( + type=FeatureType.ACTION, + shape=(action_dim,), + ) + + print(f"use_video_encoder: {config.use_video_encoder}") + print(f"video_num_frames: {config.video_num_frames}") + print(f"video_padding_mode: {config.video_padding_mode}") + print(f"video_encoder_camera_key: {config.video_encoder_camera_key}") + print(f"image_observation_delta_indices: {config.image_observation_delta_indices}") + + # Create model + model = PI05VideoPolicy(config) + model.to(device) + + # Create dummy batch + batch = create_dummy_batch( + batch_size=batch_size, + num_frames=num_frames, + image_size=image_size, + num_cameras=num_cameras, + state_dim=state_dim, + action_dim=action_dim, + chunk_size=chunk_size, + device=device, + ) + + print(f"Batch keys: {list(batch.keys())}" ) + for key, value in batch.items(): + print(f"{key}: {value.shape}") + + # Test forward pass + model.train() + try: + loss, loss_dict = model.forward(batch) + print(f"Forward pass successful!") + print(f"Loss: {loss.item():.4f}") + print(f"Loss dict: {loss_dict}") + except Exception as e: + print(f"Forward pass failed: {e}") + raise + + # Test inference + model.eval() + with torch.no_grad(): + try: + actions = model.predict_action_chunk(batch) + print(f"Test pass, inference pass!") + print(f"Predicted actions shape: {actions.shape}") + except Exception as e: + print(f"Inference failed: {e}") + raise + + print("All tests passed!") + + +def test_frame_padding(): + """Test frame padding at episode start.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Create config + config = PI05VideoConfig( + use_video_encoder=True, + video_num_frames=16, + videoprism_model_name="MHRDYN7/videoprism-base-f16r288", + freeze_video_encoder=True, + video_padding_mode="repeat", + chunk_size=50, + dtype="float32", + device=device, + ) + + # Set up minimal features + config.input_features[f"{OBS_IMAGES}.camera_0"] = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ) + config.output_features[ACTION] = PolicyFeature( + type=FeatureType.ACTION, + shape=(14,), + ) + + # Create model + model = PI05VideoPolicy(config) + model.to(device) + + # Test with fewer frames than expected (simulating episode start) + batch = { + f"{OBS_IMAGES}.camera_0": torch.rand(2, 5, 3, 224, 224, device=device), + "observation.language.tokens": torch.randint(0, 1000, (2, 10), device=device), + "observation.language.attention_mask": torch.ones(2, 10, dtype=torch.bool, device=device), + ACTION: torch.rand(2, 50, 14, device=device), + } + + video_frames = model._preprocess_video(batch) + if video_frames is not None: + print(f"Input frames: 5") + print(f"Output video_frames shape: {video_frames.shape}") + print(f"Expected: [2, 16, 3, 224, 224]") + assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}" + print("Frame padding test PASSED!") + else: + print("video_frames is None (unexpected)") + + # Test with single frame + batch[f"{OBS_IMAGES}.camera_0"] = torch.rand(2, 3, 224, 224, device=device) # [B, C, H, W] + + video_frames = model._preprocess_video(batch) + if video_frames is not None: + print(f"Input: single frame [B, C, H, W]") + print(f"Output video_frames shape: {video_frames.shape}") + print(f"Expected: [2, 16, 3, 224, 224]") + assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}" + print("Single frame expansion test PASSED!") + else: + print("video_frames is None (unexpected)") + + print("All tests passed!") +if __name__ == "__main__": + # Run tests + test_frame_padding() + test_video_encoder()