add openpi image transforms for training and add more flexibility to _preprocess_images similar to lerobot pi0

This commit is contained in:
Pepijn
2025-09-12 11:12:47 +02:00
parent 1785767e61
commit dbe3406a69
2 changed files with 290 additions and 96 deletions
@@ -1014,7 +1014,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
}
def _preprocess_images(
self, batch: dict[str, Tensor]
self, batch: dict[str, Tensor], *, train: bool = False
) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images`
"""Preprocess images for the model.
@@ -1027,59 +1027,156 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
# Get device from model parameters
device = next(self.parameters()).device
for key in self.config.image_keys:
if key in batch:
img = batch[key]
# from lerobot pi0: Use dynamic image configuration
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
# Ensure tensor is on the same device as the model
if img.device != device:
img = img.to(device)
# from lerobot pi0: Validation: Require at least one image to be present
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. "
f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
)
# Ensure float32 dtype for consistency
if img.dtype != torch.float32:
img = img.to(torch.float32)
# from lerobot pi0: Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key]
# Check if image is in [B, C, H, W] format (channels first)
if img.dim() == 4 and img.shape[1] in [1, 3]: # Grayscale or RGB
# Already in correct format
pass
elif img.dim() == 4 and img.shape[-1] in [1, 3]: # [B, H, W, C] format
# Convert to [B, C, H, W]
img = img.permute(0, 3, 1, 2)
else:
raise ValueError(f"Unexpected image shape {img.shape} for key {key}")
# Ensure tensor is on the same device as the model
if img.device != device:
img = img.to(device)
# Resize with padding if needed
if img.shape[-2:] != self.config.image_resolution:
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
# Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
# Ensure float32 dtype for consistency
if img.dtype != torch.float32:
img = img.to(torch.float32)
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
img = img.permute(0, 2, 3, 1)
# from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
if img.shape[1:3] != self.config.image_resolution:
img = resize_with_pad_torch(img, *self.config.image_resolution)
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
img = img.permute(0, 2, 3, 1)
# Convert back to [B, C, H, W] if we started with channels-first
if is_channels_first:
img = img.permute(0, 3, 1, 2)
# from openpi preprocess_observation_pytorch: Resize with padding if needed
if img.shape[1:3] != self.config.image_resolution:
img = resize_with_pad_torch(img, *self.config.image_resolution)
# Normalize from [0, 1] to [-1, 1] for SigLIP/PaliGemma
# Check if normalization is needed
if img.min() >= 0 and img.max() <= 1:
img = img * 2.0 - 1.0
elif img.min() >= -1 and img.max() <= 1:
# Already normalized to [-1, 1]
pass
else:
# Assume it's in [0, 255] range and normalize
img = (img / 255.0) * 2.0 - 1.0
# from openpi preprocess_observation_pytorch: Training augmentations
if train:
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
img = img / 2.0 + 0.5
images.append(img)
# Create mask (all ones for real images)
img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device))
# Apply PyTorch-based augmentations
if "wrist" not in key:
# Geometric augmentations for non-wrist cameras
height, width = img.shape[1:3]
# Random crop and resize
crop_height = int(height * 0.95)
crop_width = int(width * 0.95)
# Random crop
max_h = height - crop_height
max_w = width - crop_width
if max_h > 0 and max_w > 0:
# Use tensor operations instead of .item() for torch.compile compatibility
start_h = torch.randint(0, max_h + 1, (1,), device=img.device)
start_w = torch.randint(0, max_w + 1, (1,), device=img.device)
img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
# Resize back to original size
img = torch.nn.functional.interpolate(
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
size=(height, width),
mode="bilinear",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Random rotation (small angles)
# Use tensor operations instead of .item() for torch.compile compatibility
angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
# Convert to radians
angle_rad = angle * torch.pi / 180.0
# Create rotation matrix
cos_a = torch.cos(angle_rad)
sin_a = torch.sin(angle_rad)
# Apply rotation using grid_sample
grid_x = torch.linspace(-1, 1, width, device=img.device)
grid_y = torch.linspace(-1, 1, height, device=img.device)
# Create meshgrid
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
# Expand to batch dimension
grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1)
grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1)
# Apply rotation transformation
grid_x_rot = grid_x * cos_a - grid_y * sin_a
grid_y_rot = grid_x * sin_a + grid_y * cos_a
# Stack and reshape for grid_sample
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
img = torch.nn.functional.grid_sample(
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
grid,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Color augmentations for all cameras
# Random brightness
# Use tensor operations instead of .item() for torch.compile compatibility
brightness_factor = (
0.7 + torch.rand(1, device=img.device) * 0.6
) # Random factor between 0.7 and 1.3
img = img * brightness_factor
# Random contrast
# Use tensor operations instead of .item() for torch.compile compatibility
contrast_factor = (
0.6 + torch.rand(1, device=img.device) * 0.8
) # Random factor between 0.6 and 1.4
mean = img.mean(dim=[1, 2, 3], keepdim=True)
img = (img - mean) * contrast_factor + mean
# Random saturation (convert to HSV, modify S, convert back)
# For simplicity, we'll just apply a random scaling to the color channels
# Use tensor operations instead of .item() for torch.compile compatibility
saturation_factor = (
0.5 + torch.rand(1, device=img.device) * 1.0
) # Random factor between 0.5 and 1.5
gray = img.mean(dim=-1, keepdim=True)
img = gray + (img - gray) * saturation_factor
# Clamp values to [0, 1]
img = torch.clamp(img, 0, 1)
else:
# from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
# from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
if is_channels_first:
img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
images.append(img)
# from lerobot pi0: Create mask (all ones for real images)
bsize = img.shape[0]
mask = torch.ones(bsize, dtype=torch.bool, device=device)
img_masks.append(mask)
# from lerobot pi0: Create image features not present in the batch as fully 0 padded images
for _num_empty_cameras in range(len(missing_img_keys)):
img = torch.ones_like(img) * -1 # from lerobot pi0: padded with -1 for SigLIP
mask = torch.zeros_like(mask) # from lerobot pi0: mask is zero for empty cameras
images.append(img)
img_masks.append(mask)
return images, img_masks
@@ -1135,7 +1232,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
# Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1))
@@ -1149,7 +1246,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
images, img_masks = self._preprocess_images(batch, train=False)
lang_tokens, lang_masks = self._tokenize_language(batch)
state = self.prepare_state(batch)
@@ -1169,7 +1266,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
batch = self.normalize_targets(batch)
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
images, img_masks = self._preprocess_images(batch, train=True)
lang_tokens, lang_masks = self._tokenize_language(batch)
state = self.prepare_state(batch)
@@ -1027,7 +1027,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
}
def _preprocess_images(
self, batch: dict[str, Tensor]
self, batch: dict[str, Tensor], *, train: bool = False
) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images`
"""Preprocess images for the model.
@@ -1040,59 +1040,156 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
# Get device from model parameters
device = next(self.parameters()).device
for key in self.config.image_keys:
if key in batch:
img = batch[key]
# from lerobot pi0: Use dynamic image configuration
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
# Ensure tensor is on the same device as the model
if img.device != device:
img = img.to(device)
# from lerobot pi0: Validation: Require at least one image to be present
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. "
f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
)
# Ensure float32 dtype for consistency
if img.dtype != torch.float32:
img = img.to(torch.float32)
# from lerobot pi0: Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key]
# Check if image is in [B, C, H, W] format (channels first)
if img.dim() == 4 and img.shape[1] in [1, 3]: # Grayscale or RGB
# Already in correct format
pass
elif img.dim() == 4 and img.shape[-1] in [1, 3]: # [B, H, W, C] format
# Convert to [B, C, H, W]
img = img.permute(0, 3, 1, 2)
else:
raise ValueError(f"Unexpected image shape {img.shape} for key {key}")
# Ensure tensor is on the same device as the model
if img.device != device:
img = img.to(device)
# Resize with padding if needed
if img.shape[-2:] != self.config.image_resolution:
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
# Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
# Ensure float32 dtype for consistency
if img.dtype != torch.float32:
img = img.to(torch.float32)
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
img = img.permute(0, 2, 3, 1)
# from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
if img.shape[1:3] != self.config.image_resolution:
img = resize_with_pad_torch(img, *self.config.image_resolution)
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
img = img.permute(0, 2, 3, 1)
# Convert back to [B, C, H, W] if we started with channels-first
if is_channels_first:
img = img.permute(0, 3, 1, 2)
# from openpi preprocess_observation_pytorch: Resize with padding if needed
if img.shape[1:3] != self.config.image_resolution:
img = resize_with_pad_torch(img, *self.config.image_resolution)
# Normalize from [0, 1] to [-1, 1] for SigLIP/PaliGemma
# Check if normalization is needed
if img.min() >= 0 and img.max() <= 1:
img = img * 2.0 - 1.0
elif img.min() >= -1 and img.max() <= 1:
# Already normalized to [-1, 1]
pass
else:
# Assume it's in [0, 255] range and normalize
img = (img / 255.0) * 2.0 - 1.0
# from openpi preprocess_observation_pytorch: Training augmentations
if train:
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
img = img / 2.0 + 0.5
images.append(img)
# Create mask (all ones for real images)
img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device))
# Apply PyTorch-based augmentations
if "wrist" not in key:
# Geometric augmentations for non-wrist cameras
height, width = img.shape[1:3]
# Random crop and resize
crop_height = int(height * 0.95)
crop_width = int(width * 0.95)
# Random crop
max_h = height - crop_height
max_w = width - crop_width
if max_h > 0 and max_w > 0:
# Use tensor operations instead of .item() for torch.compile compatibility
start_h = torch.randint(0, max_h + 1, (1,), device=img.device)
start_w = torch.randint(0, max_w + 1, (1,), device=img.device)
img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
# Resize back to original size
img = torch.nn.functional.interpolate(
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
size=(height, width),
mode="bilinear",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Random rotation (small angles)
# Use tensor operations instead of .item() for torch.compile compatibility
angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
# Convert to radians
angle_rad = angle * torch.pi / 180.0
# Create rotation matrix
cos_a = torch.cos(angle_rad)
sin_a = torch.sin(angle_rad)
# Apply rotation using grid_sample
grid_x = torch.linspace(-1, 1, width, device=img.device)
grid_y = torch.linspace(-1, 1, height, device=img.device)
# Create meshgrid
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
# Expand to batch dimension
grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1)
grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1)
# Apply rotation transformation
grid_x_rot = grid_x * cos_a - grid_y * sin_a
grid_y_rot = grid_x * sin_a + grid_y * cos_a
# Stack and reshape for grid_sample
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
img = torch.nn.functional.grid_sample(
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
grid,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Color augmentations for all cameras
# Random brightness
# Use tensor operations instead of .item() for torch.compile compatibility
brightness_factor = (
0.7 + torch.rand(1, device=img.device) * 0.6
) # Random factor between 0.7 and 1.3
img = img * brightness_factor
# Random contrast
# Use tensor operations instead of .item() for torch.compile compatibility
contrast_factor = (
0.6 + torch.rand(1, device=img.device) * 0.8
) # Random factor between 0.6 and 1.4
mean = img.mean(dim=[1, 2, 3], keepdim=True)
img = (img - mean) * contrast_factor + mean
# Random saturation (convert to HSV, modify S, convert back)
# For simplicity, we'll just apply a random scaling to the color channels
# Use tensor operations instead of .item() for torch.compile compatibility
saturation_factor = (
0.5 + torch.rand(1, device=img.device) * 1.0
) # Random factor between 0.5 and 1.5
gray = img.mean(dim=-1, keepdim=True)
img = gray + (img - gray) * saturation_factor
# Clamp values to [0, 1]
img = torch.clamp(img, 0, 1)
else:
# from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
# from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
if is_channels_first:
img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
images.append(img)
# from lerobot pi0: Create mask (all ones for real images)
bsize = img.shape[0]
mask = torch.ones(bsize, dtype=torch.bool, device=device)
img_masks.append(mask)
# from lerobot pi0: Create image features not present in the batch as fully 0 padded images
for _num_empty_cameras in range(len(missing_img_keys)):
img = torch.ones_like(img) * -1 # from lerobot pi0: padded with -1 for SigLIP
mask = torch.zeros_like(mask) # from lerobot pi0: mask is zero for empty cameras
images.append(img)
img_masks.append(mask)
return images, img_masks
@@ -1148,7 +1245,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
# Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1))
@@ -1162,7 +1259,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
images, img_masks = self._preprocess_images(batch, train=False)
lang_tokens, lang_masks = self._tokenize_language(batch)
state = self.prepare_state(batch)
@@ -1182,7 +1279,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
batch = self.normalize_targets(batch)
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
images, img_masks = self._preprocess_images(batch, train=True)
lang_tokens, lang_masks = self._tokenize_language(batch)
state = self.prepare_state(batch)
actions = self.prepare_action(batch)