mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
b74a551d38
* chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchmark for pi0 and pi05 * feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
import torch
|
|
import torch.nn.functional as F # noqa: N812
|
|
|
|
|
|
def resize_with_pad_torch(
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad.

    Resizes images to a target height and width without distortion: the aspect ratio is kept and the
    remaining area is filled with a constant. uint8 images are padded with 0; float32 images are padded
    with -1.0 and clamped to [-1, 1], so float32 input must already be in the range [-1, 1].

    The memory layout is guessed from the last dimension: a size of at most 4 is treated as the channel
    axis (channels-last), anything else as channels-first. A channels-first image whose width is <= 4 is
    therefore misread as channels-last.

    Args:
        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] with dtype uint8 or float32. Any number of
            leading batch dimensions (including none) is accepted.
        height: Target height, must be positive.
        width: Target width, must be positive.
        mode: Interpolation mode ('bilinear', 'nearest', etc.) forwarded to ``F.interpolate``.

    Returns:
        Resized and padded tensor with the same layout, leading dimensions and dtype as the input, and
        spatial size (height, width).

    Raises:
        ValueError: If the dtype is neither uint8 nor float32, if the input has fewer than 3 dimensions,
            if a spatial dimension of the input is empty, or if the target size is not positive.
    """
    # Validate up front so unsupported inputs fail before any expensive work is done.
    if images.dtype not in (torch.uint8, torch.float32):
        raise ValueError(f"Unsupported image dtype: {images.dtype}")
    if images.dim() < 3:
        raise ValueError(f"Expected at least 3 dimensions, got shape {tuple(images.shape)}")
    if height <= 0 or width <= 0:
        raise ValueError(f"Target size must be positive, got height={height}, width={width}")

    is_uint8 = images.dtype == torch.uint8
    channels_last = images.shape[-1] <= 4  # Heuristic: a tiny trailing dim is assumed to be channels

    # Remember the caller's leading dims so the exact same batch structure can be restored at the end
    # (a genuine batch of one must stay batched; an unbatched image must stay unbatched).
    lead_shape = tuple(images.shape[:-3])
    if images.dim() == 3:
        images = images.unsqueeze(0)  # Add batch dimension
    elif images.dim() > 4:
        images = images.flatten(0, -4)  # Collapse all leading dims into a single batch dim
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]

    cur_height, cur_width = images.shape[-2:]
    if cur_height == 0 or cur_width == 0:
        raise ValueError(f"Cannot resize an image with an empty spatial dimension: {cur_height}x{cur_width}")

    # Scale by the more constraining axis so the result fits inside (height, width).
    ratio = max(cur_width / width, cur_height / height)
    # int() truncation can hit 0 for extreme aspect ratios; keep at least one pixel and never exceed the target.
    resized_height = min(height, max(1, int(cur_height / ratio)))
    resized_width = min(width, max(1, int(cur_width / ratio)))

    # Interpolate in floating point: integer interpolation is not supported on every backend.
    source = images.to(torch.float32) if is_uint8 else images
    resized_images = F.interpolate(
        source,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode in ("bilinear", "bicubic") else None,
    )

    # Bring values back into the valid range of the input dtype.
    if is_uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    else:
        resized_images = resized_images.clamp(-1.0, 1.0)

    # Center the resized image; when the slack is odd the extra pixel goes to the bottom/right.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=0 if is_uint8 else -1.0,
    )

    # Convert back to the caller's layout and leading dimensions.
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    return padded_images.reshape(lead_shape + tuple(padded_images.shape[1:]))
|