mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-01 15:17:05 +00:00
Merge branch 'feat/add_pi' into feat/validate_pi_libero
This commit is contained in:
@@ -105,6 +105,20 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):  # adapted from lerobot pi0 `pad_vector`
    """Zero-pad the last dimension of ``vector`` up to ``new_dim``.

    Args:
        vector: Tensor whose last axis is the feature axis. It can be
            (batch_size x sequence_length x features_dimension)
            or (batch_size x features_dimension).
        new_dim: Target size of the last dimension.

    Returns:
        ``vector`` itself (no copy) when its last dimension already equals
        ``new_dim``. Otherwise a new tensor with the same dtype and device,
        holding the original values first and zeros in the trailing entries.

    Raises:
        ValueError: If the last dimension of ``vector`` is larger than
            ``new_dim``; padding cannot shrink a tensor.
    """
    current_dim = vector.shape[-1]
    if current_dim == new_dim:
        return vector

    # Fail early with an actionable message; otherwise the slice assignment
    # below raises an opaque shape-mismatch error.
    if current_dim > new_dim:
        raise ValueError(
            f"Cannot pad vector with last dimension {current_dim} to a smaller "
            f"dimension {new_dim}."
        )

    shape = list(vector.shape)
    shape[-1] = new_dim
    new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
    new_vector[..., :current_dim] = vector
    return new_vector
|
||||
|
||||
|
||||
def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
images: torch.Tensor,
|
||||
height: int,
|
||||
@@ -175,8 +189,6 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
# Convert back to original format if needed
|
||||
if channels_last:
|
||||
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
if batch_size == 1 and images.shape[0] == 1:
|
||||
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
||||
|
||||
return padded_images
|
||||
|
||||
@@ -491,8 +503,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
precision=config.dtype,
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(32, action_expert_config.width)
|
||||
self.action_out_proj = nn.Linear(action_expert_config.width, 32)
|
||||
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
|
||||
|
||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
@@ -710,8 +722,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
# Sample noise with padded dimension (32) as expected by action_in_proj
|
||||
actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing
|
||||
# Sample noise with padded dimension as expected by action_in_proj
|
||||
actions_shape = (
|
||||
bsize,
|
||||
self.config.chunk_size,
|
||||
self.config.action_dim,
|
||||
) # Use config action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -748,10 +764,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
x_t = x_t + dt * v_t
|
||||
time += dt
|
||||
|
||||
# Truncate to actual action dimension before returning
|
||||
if self.config.action_dim < 32:
|
||||
x_t = x_t[:, :, : self.config.action_dim]
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
@@ -1002,7 +1014,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
}
|
||||
|
||||
def _preprocess_images(
|
||||
self, batch: dict[str, Tensor]
|
||||
self, batch: dict[str, Tensor], *, train: bool = False
|
||||
) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images`
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1015,51 +1027,156 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
|
||||
for key in self.config.image_keys:
|
||||
if key in batch:
|
||||
img = batch[key]
|
||||
# from lerobot pi0: Use dynamic image configuration
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
# Ensure tensor is on the same device as the model
|
||||
if img.device != device:
|
||||
img = img.to(device)
|
||||
# from lerobot pi0: Validation: Require at least one image to be present
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. "
|
||||
f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
|
||||
)
|
||||
|
||||
# Ensure float32 dtype for consistency
|
||||
if img.dtype != torch.float32:
|
||||
img = img.to(torch.float32)
|
||||
# from lerobot pi0: Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
# Check if image is in [B, C, H, W] format (channels first)
|
||||
if img.dim() == 4 and img.shape[1] in [1, 3]: # Grayscale or RGB
|
||||
# Already in correct format
|
||||
pass
|
||||
elif img.dim() == 4 and img.shape[-1] in [1, 3]: # [B, H, W, C] format
|
||||
# Convert to [B, C, H, W]
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
else:
|
||||
raise ValueError(f"Unexpected image shape {img.shape} for key {key}")
|
||||
# Ensure tensor is on the same device as the model
|
||||
if img.device != device:
|
||||
img = img.to(device)
|
||||
|
||||
# Resize with padding if needed
|
||||
if img.shape[-2:] != self.config.image_resolution:
|
||||
# resize_with_pad_torch handles both [B, C, H, W] and [B, H, W, C] formats
|
||||
# But we need to ensure we pass it in the right format
|
||||
img = resize_with_pad_torch(
|
||||
img.permute(0, 2, 3, 1), # Convert to [B, H, W, C] for resize function
|
||||
*self.config.image_resolution,
|
||||
).permute(0, 3, 1, 2) # Convert back to [B, C, H, W]
|
||||
# Ensure float32 dtype for consistency
|
||||
if img.dtype != torch.float32:
|
||||
img = img.to(torch.float32)
|
||||
|
||||
# Normalize from [0, 1] to [-1, 1] for SigLIP/PaliGemma
|
||||
# Check if normalization is needed
|
||||
if img.min() >= 0 and img.max() <= 1:
|
||||
img = img * 2.0 - 1.0
|
||||
elif img.min() >= -1 and img.max() <= 1:
|
||||
# Already normalized to [-1, 1]
|
||||
pass
|
||||
else:
|
||||
# Assume it's in [0, 255] range and normalize
|
||||
img = (img / 255.0) * 2.0 - 1.0
|
||||
# from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
|
||||
|
||||
images.append(img)
|
||||
# Create mask (all ones for real images)
|
||||
img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device))
|
||||
if is_channels_first:
|
||||
# Convert [B, C, H, W] to [B, H, W, C] for processing
|
||||
img = img.permute(0, 2, 3, 1)
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Resize with padding if needed
|
||||
if img.shape[1:3] != self.config.image_resolution:
|
||||
img = resize_with_pad_torch(img, *self.config.image_resolution)
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Training augmentations
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
||||
img = img / 2.0 + 0.5
|
||||
|
||||
# Apply PyTorch-based augmentations
|
||||
if "wrist" not in key:
|
||||
# Geometric augmentations for non-wrist cameras
|
||||
height, width = img.shape[1:3]
|
||||
|
||||
# Random crop and resize
|
||||
crop_height = int(height * 0.95)
|
||||
crop_width = int(width * 0.95)
|
||||
|
||||
# Random crop
|
||||
max_h = height - crop_height
|
||||
max_w = width - crop_width
|
||||
if max_h > 0 and max_w > 0:
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
start_h = torch.randint(0, max_h + 1, (1,), device=img.device)
|
||||
start_w = torch.randint(0, max_w + 1, (1,), device=img.device)
|
||||
img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
||||
|
||||
# Resize back to original size
|
||||
img = torch.nn.functional.interpolate(
|
||||
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Random rotation (small angles)
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
||||
# Convert to radians
|
||||
angle_rad = angle * torch.pi / 180.0
|
||||
|
||||
# Create rotation matrix
|
||||
cos_a = torch.cos(angle_rad)
|
||||
sin_a = torch.sin(angle_rad)
|
||||
|
||||
# Apply rotation using grid_sample
|
||||
grid_x = torch.linspace(-1, 1, width, device=img.device)
|
||||
grid_y = torch.linspace(-1, 1, height, device=img.device)
|
||||
|
||||
# Create meshgrid
|
||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
||||
|
||||
# Expand to batch dimension
|
||||
grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
|
||||
# Apply rotation transformation
|
||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
||||
|
||||
# Stack and reshape for grid_sample
|
||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
||||
|
||||
img = torch.nn.functional.grid_sample(
|
||||
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
grid,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Color augmentations for all cameras
|
||||
# Random brightness
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
brightness_factor = (
|
||||
0.7 + torch.rand(1, device=img.device) * 0.6
|
||||
) # Random factor between 0.7 and 1.3
|
||||
img = img * brightness_factor
|
||||
|
||||
# Random contrast
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
contrast_factor = (
|
||||
0.6 + torch.rand(1, device=img.device) * 0.8
|
||||
) # Random factor between 0.6 and 1.4
|
||||
mean = img.mean(dim=[1, 2, 3], keepdim=True)
|
||||
img = (img - mean) * contrast_factor + mean
|
||||
|
||||
# Random saturation (convert to HSV, modify S, convert back)
|
||||
# For simplicity, we'll just apply a random scaling to the color channels
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
saturation_factor = (
|
||||
0.5 + torch.rand(1, device=img.device) * 1.0
|
||||
) # Random factor between 0.5 and 1.5
|
||||
gray = img.mean(dim=-1, keepdim=True)
|
||||
img = gray + (img - gray) * saturation_factor
|
||||
|
||||
# Clamp values to [0, 1]
|
||||
img = torch.clamp(img, 0, 1)
|
||||
|
||||
else:
|
||||
# from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
|
||||
if is_channels_first:
|
||||
img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
|
||||
|
||||
images.append(img)
|
||||
# from lerobot pi0: Create mask (all ones for real images)
|
||||
bsize = img.shape[0]
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
img_masks.append(mask)
|
||||
|
||||
# from lerobot pi0: Create image features not present in the batch as fully 0 padded images
|
||||
for _num_empty_cameras in range(len(missing_img_keys)):
|
||||
img = torch.ones_like(img) * -1 # from lerobot pi0: padded with -1 for SigLIP
|
||||
mask = torch.zeros_like(mask) # from lerobot pi0: mask is zero for empty cameras
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
return images, img_masks
|
||||
|
||||
@@ -1098,6 +1215,16 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def prepare_state(self, batch):  # see lerobot pi0 `prepare_state`
    """Return the batch's state padded to the configured state dimension.

    Reads ``batch[OBS_STATE]`` and hands it to ``pad_vector`` with
    ``self.config.state_dim`` as the target size of the last dimension.
    """
    return pad_vector(batch[OBS_STATE], self.config.state_dim)
|
||||
|
||||
def prepare_action(self, batch):  # see lerobot pi0 `prepare_action`
    """Return the batch's actions padded to the configured action dimension.

    Reads ``batch[ACTION]`` and hands it to ``pad_vector`` with
    ``self.config.action_dim`` as the target size of the last dimension.
    """
    return pad_vector(batch[ACTION], self.config.action_dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor: # see lerobot pi0 `select_action`
|
||||
"""Select a single action given environment observations."""
|
||||
@@ -1105,7 +1232,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps]
|
||||
# Transpose to get shape (n_action_steps, batch_size, action_dim)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
@@ -1119,30 +1246,16 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
images, img_masks = self._preprocess_images(batch, train=False)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
state = batch[OBS_STATE]
|
||||
|
||||
# Validate state dimension
|
||||
if state.shape[-1] > 32:
|
||||
raise ValueError(
|
||||
f"State dimension {state.shape[-1]} exceeds maximum of 32. "
|
||||
f"Please reduce state dimension or modify the model."
|
||||
)
|
||||
|
||||
# Pad state to 32 dimensions if needed (PI05 expects fixed 32-dim); works similar to lerobot pi0 `prepare_state`
|
||||
if state.shape[-1] < 32:
|
||||
padding = torch.zeros(
|
||||
state.shape[0], 32 - state.shape[-1], device=state.device, dtype=state.dtype
|
||||
)
|
||||
state = torch.cat([state, padding], dim=-1)
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
# Sample actions using the model
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||
|
||||
# Truncate to actual action dimension, works similar to lerobot pi0 `prepare_action`
|
||||
if self.config.action_dim < 32:
|
||||
actions = actions[:, :, : self.config.action_dim]
|
||||
# Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action`
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
@@ -1153,35 +1266,11 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
images, img_masks = self._preprocess_images(batch, train=True)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
state = batch[OBS_STATE]
|
||||
actions = batch[ACTION]
|
||||
|
||||
# Validate state and action dimensions
|
||||
if state.shape[-1] > 32:
|
||||
raise ValueError(
|
||||
f"State dimension {state.shape[-1]} exceeds maximum of 32. "
|
||||
f"Please reduce state dimension or modify the model."
|
||||
)
|
||||
if actions.shape[-1] > 32:
|
||||
raise ValueError(
|
||||
f"Action dimension {actions.shape[-1]} exceeds maximum of 32. "
|
||||
f"Please reduce action dimension or modify the model."
|
||||
)
|
||||
|
||||
# Pad state and actions to 32 dimensions if needed (PI05 expects fixed 32-dim)
|
||||
if state.shape[-1] < 32:
|
||||
padding = torch.zeros(
|
||||
state.shape[0], 32 - state.shape[-1], device=state.device, dtype=state.dtype
|
||||
)
|
||||
state = torch.cat([state, padding], dim=-1)
|
||||
|
||||
if actions.shape[-1] < 32:
|
||||
padding = torch.zeros(
|
||||
*actions.shape[:-1], 32 - actions.shape[-1], device=actions.device, dtype=actions.dtype
|
||||
)
|
||||
actions = torch.cat([actions, padding], dim=-1)
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
@@ -105,6 +105,20 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):  # adapted from lerobot pi0 `pad_vector`
    """Zero-pad the last dimension of ``vector`` up to ``new_dim``.

    Args:
        vector: Tensor whose last axis is the feature axis. It can be
            (batch_size x sequence_length x features_dimension)
            or (batch_size x features_dimension).
        new_dim: Target size of the last dimension.

    Returns:
        ``vector`` itself (no copy) when its last dimension already equals
        ``new_dim``. Otherwise a new tensor with the same dtype and device,
        holding the original values first and zeros in the trailing entries.

    Raises:
        ValueError: If the last dimension of ``vector`` is larger than
            ``new_dim``; padding cannot shrink a tensor.
    """
    current_dim = vector.shape[-1]
    if current_dim == new_dim:
        return vector

    # Fail early with an actionable message; otherwise the slice assignment
    # below raises an opaque shape-mismatch error.
    if current_dim > new_dim:
        raise ValueError(
            f"Cannot pad vector with last dimension {current_dim} to a smaller "
            f"dimension {new_dim}."
        )

    shape = list(vector.shape)
    shape[-1] = new_dim
    new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
    new_vector[..., :current_dim] = vector
    return new_vector
|
||||
|
||||
|
||||
def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
images: torch.Tensor,
|
||||
height: int,
|
||||
@@ -175,8 +189,6 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
# Convert back to original format if needed
|
||||
if channels_last:
|
||||
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
if batch_size == 1 and images.shape[0] == 1:
|
||||
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
||||
|
||||
return padded_images
|
||||
|
||||
@@ -491,10 +503,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
precision=config.dtype,
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(32, action_expert_config.width)
|
||||
self.action_out_proj = nn.Linear(action_expert_config.width, 32)
|
||||
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
|
||||
|
||||
self.state_proj = nn.Linear(32, action_expert_config.width)
|
||||
self.state_proj = nn.Linear(config.state_dim, action_expert_config.width)
|
||||
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
|
||||
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
|
||||
@@ -727,8 +739,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
# Sample noise with padded dimension (32) as expected by action_in_proj
|
||||
actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing
|
||||
# Sample noise with padded dimension as expected by action_in_proj
|
||||
actions_shape = (
|
||||
bsize,
|
||||
self.config.chunk_size,
|
||||
self.config.action_dim,
|
||||
) # Use config action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -765,10 +781,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
x_t = x_t + dt * v_t
|
||||
time += dt
|
||||
|
||||
# Truncate to actual action dimension before returning
|
||||
if self.config.action_dim < 32:
|
||||
x_t = x_t[:, :, : self.config.action_dim]
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
@@ -1015,7 +1027,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
}
|
||||
|
||||
def _preprocess_images(
|
||||
self, batch: dict[str, Tensor]
|
||||
self, batch: dict[str, Tensor], *, train: bool = False
|
||||
) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images`
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1028,51 +1040,156 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
|
||||
for key in self.config.image_keys:
|
||||
if key in batch:
|
||||
img = batch[key]
|
||||
# from lerobot pi0: Use dynamic image configuration
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
# Ensure tensor is on the same device as the model
|
||||
if img.device != device:
|
||||
img = img.to(device)
|
||||
# from lerobot pi0: Validation: Require at least one image to be present
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. "
|
||||
f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
|
||||
)
|
||||
|
||||
# Ensure float32 dtype for consistency
|
||||
if img.dtype != torch.float32:
|
||||
img = img.to(torch.float32)
|
||||
# from lerobot pi0: Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
# Check if image is in [B, C, H, W] format (channels first)
|
||||
if img.dim() == 4 and img.shape[1] in [1, 3]: # Grayscale or RGB
|
||||
# Already in correct format
|
||||
pass
|
||||
elif img.dim() == 4 and img.shape[-1] in [1, 3]: # [B, H, W, C] format
|
||||
# Convert to [B, C, H, W]
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
else:
|
||||
raise ValueError(f"Unexpected image shape {img.shape} for key {key}")
|
||||
# Ensure tensor is on the same device as the model
|
||||
if img.device != device:
|
||||
img = img.to(device)
|
||||
|
||||
# Resize with padding if needed
|
||||
if img.shape[-2:] != self.config.image_resolution:
|
||||
# resize_with_pad_torch handles both [B, C, H, W] and [B, H, W, C] formats
|
||||
# But we need to ensure we pass it in the right format
|
||||
img = resize_with_pad_torch(
|
||||
img.permute(0, 2, 3, 1), # Convert to [B, H, W, C] for resize function
|
||||
*self.config.image_resolution,
|
||||
).permute(0, 3, 1, 2) # Convert back to [B, C, H, W]
|
||||
# Ensure float32 dtype for consistency
|
||||
if img.dtype != torch.float32:
|
||||
img = img.to(torch.float32)
|
||||
|
||||
# Normalize from [0, 1] to [-1, 1] for SigLIP/PaliGemma
|
||||
# Check if normalization is needed
|
||||
if img.min() >= 0 and img.max() <= 1:
|
||||
img = img * 2.0 - 1.0
|
||||
elif img.min() >= -1 and img.max() <= 1:
|
||||
# Already normalized to [-1, 1]
|
||||
pass
|
||||
else:
|
||||
# Assume it's in [0, 255] range and normalize
|
||||
img = (img / 255.0) * 2.0 - 1.0
|
||||
# from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
|
||||
|
||||
images.append(img)
|
||||
# Create mask (all ones for real images)
|
||||
img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device))
|
||||
if is_channels_first:
|
||||
# Convert [B, C, H, W] to [B, H, W, C] for processing
|
||||
img = img.permute(0, 2, 3, 1)
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Resize with padding if needed
|
||||
if img.shape[1:3] != self.config.image_resolution:
|
||||
img = resize_with_pad_torch(img, *self.config.image_resolution)
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Training augmentations
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
||||
img = img / 2.0 + 0.5
|
||||
|
||||
# Apply PyTorch-based augmentations
|
||||
if "wrist" not in key:
|
||||
# Geometric augmentations for non-wrist cameras
|
||||
height, width = img.shape[1:3]
|
||||
|
||||
# Random crop and resize
|
||||
crop_height = int(height * 0.95)
|
||||
crop_width = int(width * 0.95)
|
||||
|
||||
# Random crop
|
||||
max_h = height - crop_height
|
||||
max_w = width - crop_width
|
||||
if max_h > 0 and max_w > 0:
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
start_h = torch.randint(0, max_h + 1, (1,), device=img.device)
|
||||
start_w = torch.randint(0, max_w + 1, (1,), device=img.device)
|
||||
img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
||||
|
||||
# Resize back to original size
|
||||
img = torch.nn.functional.interpolate(
|
||||
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Random rotation (small angles)
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
||||
# Convert to radians
|
||||
angle_rad = angle * torch.pi / 180.0
|
||||
|
||||
# Create rotation matrix
|
||||
cos_a = torch.cos(angle_rad)
|
||||
sin_a = torch.sin(angle_rad)
|
||||
|
||||
# Apply rotation using grid_sample
|
||||
grid_x = torch.linspace(-1, 1, width, device=img.device)
|
||||
grid_y = torch.linspace(-1, 1, height, device=img.device)
|
||||
|
||||
# Create meshgrid
|
||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
||||
|
||||
# Expand to batch dimension
|
||||
grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
|
||||
# Apply rotation transformation
|
||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
||||
|
||||
# Stack and reshape for grid_sample
|
||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
||||
|
||||
img = torch.nn.functional.grid_sample(
|
||||
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
grid,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Color augmentations for all cameras
|
||||
# Random brightness
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
brightness_factor = (
|
||||
0.7 + torch.rand(1, device=img.device) * 0.6
|
||||
) # Random factor between 0.7 and 1.3
|
||||
img = img * brightness_factor
|
||||
|
||||
# Random contrast
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
contrast_factor = (
|
||||
0.6 + torch.rand(1, device=img.device) * 0.8
|
||||
) # Random factor between 0.6 and 1.4
|
||||
mean = img.mean(dim=[1, 2, 3], keepdim=True)
|
||||
img = (img - mean) * contrast_factor + mean
|
||||
|
||||
# Random saturation (convert to HSV, modify S, convert back)
|
||||
# For simplicity, we'll just apply a random scaling to the color channels
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
saturation_factor = (
|
||||
0.5 + torch.rand(1, device=img.device) * 1.0
|
||||
) # Random factor between 0.5 and 1.5
|
||||
gray = img.mean(dim=-1, keepdim=True)
|
||||
img = gray + (img - gray) * saturation_factor
|
||||
|
||||
# Clamp values to [0, 1]
|
||||
img = torch.clamp(img, 0, 1)
|
||||
|
||||
else:
|
||||
# from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
|
||||
if is_channels_first:
|
||||
img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
|
||||
|
||||
images.append(img)
|
||||
# from lerobot pi0: Create mask (all ones for real images)
|
||||
bsize = img.shape[0]
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
img_masks.append(mask)
|
||||
|
||||
# from lerobot pi0: Create image features not present in the batch as fully 0 padded images
|
||||
for _num_empty_cameras in range(len(missing_img_keys)):
|
||||
img = torch.ones_like(img) * -1 # from lerobot pi0: padded with -1 for SigLIP
|
||||
mask = torch.zeros_like(mask) # from lerobot pi0: mask is zero for empty cameras
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
return images, img_masks
|
||||
|
||||
@@ -1111,6 +1228,16 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def prepare_state(self, batch):  # see lerobot pi0 `prepare_state`
    """Return the batch's state padded to the configured state dimension.

    Reads ``batch[OBS_STATE]`` and hands it to ``pad_vector`` with
    ``self.config.state_dim`` as the target size of the last dimension.
    """
    return pad_vector(batch[OBS_STATE], self.config.state_dim)
|
||||
|
||||
def prepare_action(self, batch):  # see lerobot pi0 `prepare_action`
    """Return the batch's actions padded to the configured action dimension.

    Reads ``batch[ACTION]`` and hands it to ``pad_vector`` with
    ``self.config.action_dim`` as the target size of the last dimension.
    """
    return pad_vector(batch[ACTION], self.config.action_dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor: # see lerobot pi0 `select_action`
|
||||
"""Select a single action given environment observations."""
|
||||
@@ -1118,7 +1245,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps]
|
||||
# Transpose to get shape (n_action_steps, batch_size, action_dim)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
@@ -1132,30 +1259,16 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
images, img_masks = self._preprocess_images(batch, train=False)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
state = batch[OBS_STATE]
|
||||
|
||||
# Validate state dimension
|
||||
if state.shape[-1] > 32:
|
||||
raise ValueError(
|
||||
f"State dimension {state.shape[-1]} exceeds maximum of 32. "
|
||||
f"Please reduce state dimension or modify the model."
|
||||
)
|
||||
|
||||
# Pad state to 32 dimensions if needed (PI0 expects fixed 32-dim); works similar to lerobot pi0 `prepare_state`
|
||||
if state.shape[-1] < 32:
|
||||
padding = torch.zeros(
|
||||
state.shape[0], 32 - state.shape[-1], device=state.device, dtype=state.dtype
|
||||
)
|
||||
state = torch.cat([state, padding], dim=-1)
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
# Sample actions using the model
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||
|
||||
# Truncate to actual action dimension, works similar to lerobot pi0 `prepare_action`
|
||||
if self.config.action_dim < 32:
|
||||
actions = actions[:, :, : self.config.action_dim]
|
||||
# Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action`
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
@@ -1166,42 +1279,17 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
images, img_masks = self._preprocess_images(batch, train=True)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
state = batch[OBS_STATE]
|
||||
actions = batch[ACTION]
|
||||
|
||||
# Validate state and action dimensions
|
||||
if state.shape[-1] > 32:
|
||||
raise ValueError(
|
||||
f"State dimension {state.shape[-1]} exceeds maximum of 32. "
|
||||
f"Please reduce state dimension or modify the model."
|
||||
)
|
||||
if actions.shape[-1] > 32:
|
||||
raise ValueError(
|
||||
f"Action dimension {actions.shape[-1]} exceeds maximum of 32. "
|
||||
f"Please reduce action dimension or modify the model."
|
||||
)
|
||||
|
||||
# Pad state and actions to 32 dimensions if needed (PI0 expects fixed 32-dim)
|
||||
if state.shape[-1] < 32:
|
||||
padding = torch.zeros(
|
||||
state.shape[0], 32 - state.shape[-1], device=state.device, dtype=state.dtype
|
||||
)
|
||||
state = torch.cat([state, padding], dim=-1)
|
||||
|
||||
if actions.shape[-1] < 32:
|
||||
padding = torch.zeros(
|
||||
*actions.shape[:-1], 32 - actions.shape[-1], device=actions.device, dtype=actions.dtype
|
||||
)
|
||||
actions = torch.cat([actions, padding], dim=-1)
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
if self.config.action_dim < 32:
|
||||
losses = losses[:, :, : self.config.action_dim]
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
|
||||
loss = losses.mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user