diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index a912deac2..39281204e 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -105,6 +105,20 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` ( return att_2d_masks & pad_2d_masks +def pad_vector(vector, new_dim): # see lerobot pi0 `pad_vector` (exact copy) + """Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] == new_dim: + return vector + shape = list(vector.shape) + current_dim = shape[-1] + shape[-1] = new_dim + new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) + new_vector[..., :current_dim] = vector + return new_vector + + def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) images: torch.Tensor, height: int, @@ -175,8 +189,6 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) # Convert back to original format if needed if channels_last: padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - if batch_size == 1 and images.shape[0] == 1: - padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added return padded_images @@ -491,8 +503,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` precision=config.dtype, ) - self.action_in_proj = nn.Linear(32, action_expert_config.width) - self.action_out_proj = nn.Linear(action_expert_config.width, 32) + self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim) self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) @@ -710,8 +722,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` device = state.device if noise is None: - # Sample noise with padded dimension (32) as expected by action_in_proj - actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing + # Sample noise with padded dimension as expected by action_in_proj + actions_shape = ( + bsize, + self.config.chunk_size, + self.config.action_dim, + ) # Use config action_dim for internal processing noise = self.sample_noise(actions_shape, device) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( @@ -748,10 +764,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` x_t = x_t + dt * v_t time += dt - # Truncate to actual action dimension before returning - if self.config.action_dim < 32: - x_t = x_t[:, :, : self.config.action_dim] - return x_t def denoise_step( @@ -1002,7 +1014,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): } def _preprocess_images( - self, batch: dict[str, Tensor] + self, batch: dict[str, Tensor], *, train: bool = False ) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images` """Preprocess images for the model. @@ -1015,51 +1027,156 @@ class PI05OpenPIPolicy(PreTrainedPolicy): # Get device from model parameters device = next(self.parameters()).device - for key in self.config.image_keys: - if key in batch: - img = batch[key] + # from lerobot pi0: Use dynamic image configuration + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] - # Ensure tensor is on the same device as the model - if img.device != device: - img = img.to(device) + # from lerobot pi0: Validation: Require at least one image to be present + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. " + f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" + ) - # Ensure float32 dtype for consistency - if img.dtype != torch.float32: - img = img.to(torch.float32) + # from lerobot pi0: Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] - # Check if image is in [B, C, H, W] format (channels first) - if img.dim() == 4 and img.shape[1] in [1, 3]: # Grayscale or RGB - # Already in correct format - pass - elif img.dim() == 4 and img.shape[-1] in [1, 3]: # [B, H, W, C] format - # Convert to [B, C, H, W] - img = img.permute(0, 3, 1, 2) - else: - raise ValueError(f"Unexpected image shape {img.shape} for key {key}") + # Ensure tensor is on the same device as the model + if img.device != device: + img = img.to(device) - # Resize with padding if needed - if img.shape[-2:] != self.config.image_resolution: - # resize_with_pad_torch handles both [B, C, H, W] and [B, H, W, C] formats - # But we need to ensure we pass it in the right format - img = resize_with_pad_torch( - img.permute(0, 2, 3, 1), # Convert to [B, H, W, C] for resize function - *self.config.image_resolution, - ).permute(0, 3, 1, 2) # Convert back to [B, C, H, W] + # Ensure float32 dtype for consistency + if img.dtype != torch.float32: + img = img.to(torch.float32) - # Normalize from [0, 1] to [-1, 1] for SigLIP/PaliGemma - # Check if normalization is needed - if img.min() >= 0 and img.max() <= 1: - img = img * 2.0 - 1.0 - elif img.min() >= -1 and img.max() <= 1: - # Already normalized to [-1, 1] - pass - else: - # Assume it's in [0, 255] range and normalize - img = (img / 255.0) * 2.0 - 1.0 + # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 - images.append(img) - # Create mask (all ones for real images) - img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device)) + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + img = img.permute(0, 2, 3, 1) + + # from openpi preprocess_observation_pytorch: Resize with padding if needed + if img.shape[1:3] != self.config.image_resolution: + img = resize_with_pad_torch(img, *self.config.image_resolution) + + # from openpi preprocess_observation_pytorch: Training augmentations + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + img = img / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = img.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=img.device) + start_w = torch.randint(0, max_w + 1, (1,), device=img.device) + img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :] + + # Resize back to original size + img = torch.nn.functional.interpolate( + img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode="bilinear", + align_corners=False, + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = angle * torch.pi / 180.0 + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=img.device) + grid_y = torch.linspace(-1, 1, height, device=img.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij") + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + img = torch.nn.functional.grid_sample( + img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = ( + 0.7 + torch.rand(1, device=img.device) * 0.6 + ) # Random factor between 0.7 and 1.3 + img = img * brightness_factor + + # Random contrast + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = ( + 0.6 + torch.rand(1, device=img.device) * 0.8 + ) # Random factor between 0.6 and 1.4 + mean = img.mean(dim=[1, 2, 3], keepdim=True) + img = (img - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = ( + 0.5 + torch.rand(1, device=img.device) * 1.0 + ) # Random factor between 0.5 and 1.5 + gray = img.mean(dim=-1, keepdim=True) + img = gray + (img - gray) * saturation_factor + + # Clamp values to [0, 1] + img = torch.clamp(img, 0, 1) + + else: + # from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + images.append(img) + # from lerobot pi0: Create mask (all ones for real images) + bsize = img.shape[0] + mask = torch.ones(bsize, dtype=torch.bool, device=device) + img_masks.append(mask) + + # from lerobot pi0: Create image features not present in the batch as fully 0 padded images + for _num_empty_cameras in range(len(missing_img_keys)): + img = torch.ones_like(img) * -1 # from lerobot pi0: padded with -1 for SigLIP + mask = torch.zeros_like(mask) # from lerobot pi0: mask is zero for empty cameras + images.append(img) + img_masks.append(mask) return images, img_masks @@ -1098,6 +1215,16 @@ class PI05OpenPIPolicy(PreTrainedPolicy): return lang_tokens, lang_masks + def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy) + """Pad state""" + state = pad_vector(batch[OBS_STATE], self.config.state_dim) + return state + + def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy) + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.action_dim) + return actions + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: # see lerobot pi0 `select_action` """Select a single action given environment observations.""" @@ -1105,7 +1232,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) @@ -1119,30 +1246,16 @@ class PI05OpenPIPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) # Prepare inputs - images, img_masks = self._preprocess_images(batch) + images, img_masks = self._preprocess_images(batch, train=False) lang_tokens, lang_masks = self._tokenize_language(batch) - state = batch[OBS_STATE] - - # Validate state dimension - if state.shape[-1] > 32: - raise ValueError( - f"State dimension {state.shape[-1]} exceeds maximum of 32. " - f"Please reduce state dimension or modify the model." - ) - - # Pad state to 32 dimensions if needed (PI05 expects fixed 32-dim); works similar to lerobot pi0 `prepare_state` - if state.shape[-1] < 32: - padding = torch.zeros( - state.shape[0], 32 - state.shape[-1], device=state.device, dtype=state.dtype - ) - state = torch.cat([state, padding], dim=-1) + state = self.prepare_state(batch) # Sample actions using the model actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state) - # Truncate to actual action dimension, works similar to lerobot pi0 `prepare_action` - if self.config.action_dim < 32: - actions = actions[:, :, : self.config.action_dim] + # Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action` + original_action_dim = self.config.output_features[ACTION].shape[0] + actions = actions[:, :, :original_action_dim] actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions @@ -1153,35 +1266,11 @@ class PI05OpenPIPolicy(PreTrainedPolicy): batch = self.normalize_targets(batch) # Prepare inputs - images, img_masks = self._preprocess_images(batch) + images, img_masks = self._preprocess_images(batch, train=True) lang_tokens, lang_masks = self._tokenize_language(batch) - state = batch[OBS_STATE] - actions = batch[ACTION] - # Validate state and action dimensions - if state.shape[-1] > 32: - raise ValueError( - f"State dimension {state.shape[-1]} exceeds maximum of 32. " - f"Please reduce state dimension or modify the model." - ) - if actions.shape[-1] > 32: - raise ValueError( - f"Action dimension {actions.shape[-1]} exceeds maximum of 32. " - f"Please reduce action dimension or modify the model." - ) - - # Pad state and actions to 32 dimensions if needed (PI05 expects fixed 32-dim) - if state.shape[-1] < 32: - padding = torch.zeros( - state.shape[0], 32 - state.shape[-1], device=state.device, dtype=state.dtype - ) - state = torch.cat([state, padding], dim=-1) - - if actions.shape[-1] < 32: - padding = torch.zeros( - *actions.shape[:-1], 32 - actions.shape[-1], device=actions.device, dtype=actions.dtype - ) - actions = torch.cat([actions, padding], dim=-1) + state = self.prepare_state(batch) + actions = self.prepare_action(batch) # Compute loss losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 44d1d6a50..35c6f7c9a 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -105,6 +105,20 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` ( return att_2d_masks & pad_2d_masks +def pad_vector(vector, new_dim): # see lerobot pi0 `pad_vector` (exact copy) + """Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] == new_dim: + return vector + shape = list(vector.shape) + current_dim = shape[-1] + shape[-1] = new_dim + new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) + new_vector[..., :current_dim] = vector + return new_vector + + def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) images: torch.Tensor, height: int, @@ -175,8 +189,6 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) # Convert back to original format if needed if channels_last: padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - if batch_size == 1 and images.shape[0] == 1: - padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added return padded_images @@ -491,10 +503,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` precision=config.dtype, ) - self.action_in_proj = nn.Linear(32, action_expert_config.width) - self.action_out_proj = nn.Linear(action_expert_config.width, 32) + self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim) - self.state_proj = nn.Linear(32, action_expert_config.width) + self.state_proj = nn.Linear(config.state_dim, action_expert_config.width) self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) @@ -727,8 +739,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` device = state.device if noise is None: - # Sample noise with padded dimension (32) as expected by action_in_proj - actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing + # Sample noise with padded dimension as expected by action_in_proj + actions_shape = ( + bsize, + self.config.chunk_size, + self.config.action_dim, + ) # Use config action_dim for internal processing noise = self.sample_noise(actions_shape, device) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( @@ -765,10 +781,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` x_t = x_t + dt * v_t time += dt - # Truncate to actual action dimension before returning - if self.config.action_dim < 32: - x_t = x_t[:, :, : self.config.action_dim] - return x_t def denoise_step( @@ -1015,7 +1027,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): } def _preprocess_images( - self, batch: dict[str, Tensor] + self, batch: dict[str, Tensor], *, train: bool = False ) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images` """Preprocess images for the model. @@ -1028,51 +1040,156 @@ class PI0OpenPIPolicy(PreTrainedPolicy): # Get device from model parameters device = next(self.parameters()).device - for key in self.config.image_keys: - if key in batch: - img = batch[key] + # from lerobot pi0: Use dynamic image configuration + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] - # Ensure tensor is on the same device as the model - if img.device != device: - img = img.to(device) + # from lerobot pi0: Validation: Require at least one image to be present + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. " + f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" + ) - # Ensure float32 dtype for consistency - if img.dtype != torch.float32: - img = img.to(torch.float32) + # from lerobot pi0: Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] - # Check if image is in [B, C, H, W] format (channels first) - if img.dim() == 4 and img.shape[1] in [1, 3]: # Grayscale or RGB - # Already in correct format - pass - elif img.dim() == 4 and img.shape[-1] in [1, 3]: # [B, H, W, C] format - # Convert to [B, C, H, W] - img = img.permute(0, 3, 1, 2) - else: - raise ValueError(f"Unexpected image shape {img.shape} for key {key}") + # Ensure tensor is on the same device as the model + if img.device != device: + img = img.to(device) - # Resize with padding if needed - if img.shape[-2:] != self.config.image_resolution: - # resize_with_pad_torch handles both [B, C, H, W] and [B, H, W, C] formats - # But we need to ensure we pass it in the right format - img = resize_with_pad_torch( - img.permute(0, 2, 3, 1), # Convert to [B, H, W, C] for resize function - *self.config.image_resolution, - ).permute(0, 3, 1, 2) # Convert back to [B, C, H, W] + # Ensure float32 dtype for consistency + if img.dtype != torch.float32: + img = img.to(torch.float32) - # Normalize from [0, 1] to [-1, 1] for SigLIP/PaliGemma - # Check if normalization is needed - if img.min() >= 0 and img.max() <= 1: - img = img * 2.0 - 1.0 - elif img.min() >= -1 and img.max() <= 1: - # Already normalized to [-1, 1] - pass - else: - # Assume it's in [0, 255] range and normalize - img = (img / 255.0) * 2.0 - 1.0 + # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 - images.append(img) - # Create mask (all ones for real images) - img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device)) + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + img = img.permute(0, 2, 3, 1) + + # from openpi preprocess_observation_pytorch: Resize with padding if needed + if img.shape[1:3] != self.config.image_resolution: + img = resize_with_pad_torch(img, *self.config.image_resolution) + + # from openpi preprocess_observation_pytorch: Training augmentations + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + img = img / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = img.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=img.device) + start_w = torch.randint(0, max_w + 1, (1,), device=img.device) + img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :] + + # Resize back to original size + img = torch.nn.functional.interpolate( + img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode="bilinear", + align_corners=False, + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = angle * torch.pi / 180.0 + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=img.device) + grid_y = torch.linspace(-1, 1, height, device=img.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij") + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + img = torch.nn.functional.grid_sample( + img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = ( + 0.7 + torch.rand(1, device=img.device) * 0.6 + ) # Random factor between 0.7 and 1.3 + img = img * brightness_factor + + # Random contrast + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = ( + 0.6 + torch.rand(1, device=img.device) * 0.8 + ) # Random factor between 0.6 and 1.4 + mean = img.mean(dim=[1, 2, 3], keepdim=True) + img = (img - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = ( + 0.5 + torch.rand(1, device=img.device) * 1.0 + ) # Random factor between 0.5 and 1.5 + gray = img.mean(dim=-1, keepdim=True) + img = gray + (img - gray) * saturation_factor + + # Clamp values to [0, 1] + img = torch.clamp(img, 0, 1) + + else: + # from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + images.append(img) + # from lerobot pi0: Create mask (all ones for real images) + bsize = img.shape[0] + mask = torch.ones(bsize, dtype=torch.bool, device=device) + img_masks.append(mask) + + # from lerobot pi0: Create image features not present in the batch as fully 0 padded images + for _num_empty_cameras in range(len(missing_img_keys)): + img = torch.ones_like(img) * -1 # from lerobot pi0: padded with -1 for SigLIP + mask = torch.zeros_like(mask) # from lerobot pi0: mask is zero for empty cameras + images.append(img) + img_masks.append(mask) return images, img_masks @@ -1111,6 +1228,16 @@ class PI0OpenPIPolicy(PreTrainedPolicy): return lang_tokens, lang_masks + def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy) + """Pad state""" + state = pad_vector(batch[OBS_STATE], self.config.state_dim) + return state + + def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy) + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.action_dim) + return actions + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: # see lerobot pi0 `select_action` """Select a single action given environment observations.""" @@ -1118,7 +1245,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) @@ -1132,30 +1259,16 @@ class PI0OpenPIPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) # Prepare inputs - images, img_masks = self._preprocess_images(batch) + images, img_masks = self._preprocess_images(batch, train=False) lang_tokens, lang_masks = self._tokenize_language(batch) - state = batch[OBS_STATE] - - # Validate state dimension - if state.shape[-1] > 32: - raise ValueError( - f"State dimension {state.shape[-1]} exceeds maximum of 32. " - f"Please reduce state dimension or modify the model." - ) - - # Pad state to 32 dimensions if needed (PI0 expects fixed 32-dim); works similar to lerobot pi0 `prepare_state` - if state.shape[-1] < 32: - padding = torch.zeros( - state.shape[0], 32 - state.shape[-1], device=state.device, dtype=state.dtype - ) - state = torch.cat([state, padding], dim=-1) + state = self.prepare_state(batch) # Sample actions using the model actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state) - # Truncate to actual action dimension, works similar to lerobot pi0 `prepare_action` - if self.config.action_dim < 32: - actions = actions[:, :, : self.config.action_dim] + # Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action` + original_action_dim = self.config.output_features[ACTION].shape[0] + actions = actions[:, :, :original_action_dim] actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions @@ -1166,42 +1279,17 @@ class PI0OpenPIPolicy(PreTrainedPolicy): batch = self.normalize_targets(batch) # Prepare inputs - images, img_masks = self._preprocess_images(batch) + images, img_masks = self._preprocess_images(batch, train=True) lang_tokens, lang_masks = self._tokenize_language(batch) - state = batch[OBS_STATE] - actions = batch[ACTION] - - # Validate state and action dimensions - if state.shape[-1] > 32: - raise ValueError( - f"State dimension {state.shape[-1]} exceeds maximum of 32. " - f"Please reduce state dimension or modify the model." - ) - if actions.shape[-1] > 32: - raise ValueError( - f"Action dimension {actions.shape[-1]} exceeds maximum of 32. " - f"Please reduce action dimension or modify the model." - ) - - # Pad state and actions to 32 dimensions if needed (PI0 expects fixed 32-dim) - if state.shape[-1] < 32: - padding = torch.zeros( - state.shape[0], 32 - state.shape[-1], device=state.device, dtype=state.dtype - ) - state = torch.cat([state, padding], dim=-1) - - if actions.shape[-1] < 32: - padding = torch.zeros( - *actions.shape[:-1], 32 - actions.shape[-1], device=actions.device, dtype=actions.dtype - ) - actions = torch.cat([actions, padding], dim=-1) + state = self.prepare_state(batch) + actions = self.prepare_action(batch) # Compute loss losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) # Truncate losses to actual action dimensions - if self.config.action_dim < 32: - losses = losses[:, :, : self.config.action_dim] + original_action_dim = self.config.output_features[ACTION].shape[0] + losses = losses[:, :, :original_action_dim] loss = losses.mean()