|
|
|
@@ -23,8 +23,10 @@ from typing import TYPE_CHECKING, Any
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
import torchvision.transforms.v2.functional as tv_functional
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from torchvision.transforms import InterpolationMode
|
|
|
|
|
|
|
|
|
|
from lerobot.utils.import_utils import _transformers_available
|
|
|
|
|
|
|
|
|
@@ -57,11 +59,14 @@ from lerobot.utils.constants import (
|
|
|
|
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
|
|
|
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
|
|
|
|
)
|
|
|
|
|
from lerobot.utils.device_utils import get_safe_torch_device
|
|
|
|
|
|
|
|
|
|
from .configuration_groot import (
|
|
|
|
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
|
|
|
|
GROOT_N1_5_REMOVAL_GUIDANCE,
|
|
|
|
|
GROOT_N1_7_BACKBONE_MODEL,
|
|
|
|
|
N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
|
|
|
|
N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
|
|
|
|
GrootConfig,
|
|
|
|
|
is_raw_groot_n1_7_checkpoint,
|
|
|
|
|
)
|
|
|
|
@@ -729,21 +734,36 @@ def make_groot_pre_post_processors(
|
|
|
|
|
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
|
|
|
|
|
# when it provides an image_target_size; otherwise fall back to the geometry the
|
|
|
|
|
# N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no
|
|
|
|
|
# processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new
|
|
|
|
|
# embodiment, where checkpoint_assets is None) would patchify full-resolution camera
|
|
|
|
|
# frames, inflating the VLM token count and feeding the model a resolution it was not trained on.
|
|
|
|
|
if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None:
|
|
|
|
|
image_target_size = checkpoint_assets.image_target_size
|
|
|
|
|
image_crop_size = checkpoint_assets.image_crop_size
|
|
|
|
|
shortest_image_edge = checkpoint_assets.shortest_image_edge
|
|
|
|
|
crop_fraction = checkpoint_assets.crop_fraction
|
|
|
|
|
else:
|
|
|
|
|
image_target_size = list(N1_7_DEFAULT_IMAGE_TARGET_SIZE)
|
|
|
|
|
image_crop_size = list(N1_7_DEFAULT_IMAGE_CROP_SIZE)
|
|
|
|
|
shortest_image_edge = None
|
|
|
|
|
crop_fraction = None
|
|
|
|
|
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
|
|
|
|
|
|
|
|
|
|
input_steps: list[ProcessorStep] = [
|
|
|
|
|
RenameObservationsProcessorStep(rename_map={}),
|
|
|
|
|
AddBatchDimensionProcessorStep(),
|
|
|
|
|
pack_step,
|
|
|
|
|
GrootN17VLMEncodeStep(
|
|
|
|
|
model_name=config.n1_7_backbone_model,
|
|
|
|
|
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
|
|
|
|
|
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
|
|
|
|
|
shortest_image_edge=checkpoint_assets.shortest_image_edge
|
|
|
|
|
if checkpoint_assets is not None
|
|
|
|
|
else None,
|
|
|
|
|
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
|
|
|
|
|
use_albumentations=checkpoint_assets.use_albumentations
|
|
|
|
|
if checkpoint_assets is not None
|
|
|
|
|
else False,
|
|
|
|
|
image_crop_size=image_crop_size,
|
|
|
|
|
image_target_size=image_target_size,
|
|
|
|
|
shortest_image_edge=shortest_image_edge,
|
|
|
|
|
crop_fraction=crop_fraction,
|
|
|
|
|
use_albumentations=use_albumentations,
|
|
|
|
|
device=config.device,
|
|
|
|
|
),
|
|
|
|
|
DeviceProcessorStep(device=config.device),
|
|
|
|
|
]
|
|
|
|
@@ -899,15 +919,22 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
|
|
|
|
return proc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _transform_n1_7_image_for_vlm(
|
|
|
|
|
def _transform_n1_7_image_for_vlm_albumentations(
|
|
|
|
|
image: Image.Image,
|
|
|
|
|
*,
|
|
|
|
|
image_crop_size: list[int] | None,
|
|
|
|
|
image_target_size: list[int] | None,
|
|
|
|
|
shortest_image_edge: int | None,
|
|
|
|
|
crop_fraction: float | None,
|
|
|
|
|
use_albumentations: bool = False,
|
|
|
|
|
) -> Image.Image:
|
|
|
|
|
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
|
|
|
|
|
|
|
|
|
|
Used only for checkpoints saved with ``use_albumentations=True``. cv2 is
|
|
|
|
|
CPU/numpy-only so this path cannot run on GPU; the default (non-albumentations)
|
|
|
|
|
geometry is handled on-device by :func:`_transform_n1_7_image_for_vlm_torch`. The
|
|
|
|
|
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
|
|
|
|
|
torch path and must stay bit-exact to the upstream reference.
|
|
|
|
|
"""
|
|
|
|
|
if image_target_size is None:
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
@@ -915,70 +942,101 @@ def _transform_n1_7_image_for_vlm(
|
|
|
|
|
if image.mode != "RGB":
|
|
|
|
|
image = image.convert("RGB")
|
|
|
|
|
|
|
|
|
|
if use_albumentations:
|
|
|
|
|
try:
|
|
|
|
|
import cv2
|
|
|
|
|
except ImportError as exc:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
|
|
|
|
|
) from exc
|
|
|
|
|
try:
|
|
|
|
|
import cv2
|
|
|
|
|
except ImportError as exc:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
|
|
|
|
|
) from exc
|
|
|
|
|
|
|
|
|
|
image_np = np.asarray(image)
|
|
|
|
|
height, width = image_np.shape[:2]
|
|
|
|
|
if height != width:
|
|
|
|
|
square_edge = max(height, width)
|
|
|
|
|
pad_h = square_edge - height
|
|
|
|
|
pad_w = square_edge - width
|
|
|
|
|
image_np = cv2.copyMakeBorder(
|
|
|
|
|
image_np,
|
|
|
|
|
pad_h // 2,
|
|
|
|
|
pad_h - pad_h // 2,
|
|
|
|
|
pad_w // 2,
|
|
|
|
|
pad_w - pad_w // 2,
|
|
|
|
|
cv2.BORDER_CONSTANT,
|
|
|
|
|
value=(0, 0, 0),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
resize_edge = shortest_image_edge or target_h
|
|
|
|
|
if image_np.shape[:2] != (resize_edge, resize_edge):
|
|
|
|
|
image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA)
|
|
|
|
|
|
|
|
|
|
if crop_fraction is None and image_crop_size is not None:
|
|
|
|
|
crop_fraction = image_crop_size[0] / float(target_h)
|
|
|
|
|
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
|
|
|
|
height, width = image_np.shape[:2]
|
|
|
|
|
crop_h = max(1, int(height * crop_fraction))
|
|
|
|
|
crop_w = max(1, int(width * crop_fraction))
|
|
|
|
|
top = max(0, (height - crop_h) // 2)
|
|
|
|
|
left = max(0, (width - crop_w) // 2)
|
|
|
|
|
image_np = image_np[top : top + crop_h, left : left + crop_w]
|
|
|
|
|
|
|
|
|
|
if image_np.shape[:2] != (target_h, target_w):
|
|
|
|
|
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
|
|
|
|
|
return Image.fromarray(image_np)
|
|
|
|
|
|
|
|
|
|
square_edge = max(image.width, image.height)
|
|
|
|
|
if image.width != image.height:
|
|
|
|
|
padded = Image.new("RGB", (square_edge, square_edge))
|
|
|
|
|
left = (square_edge - image.width) // 2
|
|
|
|
|
top = (square_edge - image.height) // 2
|
|
|
|
|
padded.paste(image, (left, top))
|
|
|
|
|
image = padded
|
|
|
|
|
image_np = np.asarray(image)
|
|
|
|
|
height, width = image_np.shape[:2]
|
|
|
|
|
if height != width:
|
|
|
|
|
square_edge = max(height, width)
|
|
|
|
|
pad_h = square_edge - height
|
|
|
|
|
pad_w = square_edge - width
|
|
|
|
|
image_np = cv2.copyMakeBorder(
|
|
|
|
|
image_np,
|
|
|
|
|
pad_h // 2,
|
|
|
|
|
pad_h - pad_h // 2,
|
|
|
|
|
pad_w // 2,
|
|
|
|
|
pad_w - pad_w // 2,
|
|
|
|
|
cv2.BORDER_CONSTANT,
|
|
|
|
|
value=(0, 0, 0),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
resize_edge = shortest_image_edge or target_h
|
|
|
|
|
image = image.resize((resize_edge, resize_edge), Image.Resampling.BICUBIC)
|
|
|
|
|
if image_np.shape[:2] != (resize_edge, resize_edge):
|
|
|
|
|
image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA)
|
|
|
|
|
|
|
|
|
|
if crop_fraction is None and image_crop_size is not None:
|
|
|
|
|
crop_fraction = image_crop_size[0] / float(target_h)
|
|
|
|
|
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
|
|
|
|
crop_w = max(1, int(round(image.width * crop_fraction)))
|
|
|
|
|
crop_h = max(1, int(round(image.height * crop_fraction)))
|
|
|
|
|
left = max(0, (image.width - crop_w) // 2)
|
|
|
|
|
top = max(0, (image.height - crop_h) // 2)
|
|
|
|
|
image = image.crop((left, top, left + crop_w, top + crop_h))
|
|
|
|
|
height, width = image_np.shape[:2]
|
|
|
|
|
crop_h = max(1, int(height * crop_fraction))
|
|
|
|
|
crop_w = max(1, int(width * crop_fraction))
|
|
|
|
|
top = max(0, (height - crop_h) // 2)
|
|
|
|
|
left = max(0, (width - crop_w) // 2)
|
|
|
|
|
image_np = image_np[top : top + crop_h, left : left + crop_w]
|
|
|
|
|
|
|
|
|
|
if image.size != (target_w, target_h):
|
|
|
|
|
image = image.resize((target_w, target_h), Image.Resampling.BICUBIC)
|
|
|
|
|
if image_np.shape[:2] != (target_h, target_w):
|
|
|
|
|
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
|
|
|
|
|
return Image.fromarray(image_np)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _transform_n1_7_image_for_vlm_torch(
|
|
|
|
|
image: torch.Tensor,
|
|
|
|
|
*,
|
|
|
|
|
image_crop_size: list[int] | None,
|
|
|
|
|
image_target_size: list[int] | None,
|
|
|
|
|
shortest_image_edge: int | None,
|
|
|
|
|
crop_fraction: float | None,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
"""Default (non-albumentations) N1.7 image transform: pad-to-square, resize to
|
|
|
|
|
``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``.
|
|
|
|
|
|
|
|
|
|
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
|
|
|
|
|
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
|
|
|
|
|
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
|
|
|
|
|
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
|
|
|
|
|
cv2/INTER_AREA path has no torch equivalent and stays on
|
|
|
|
|
:func:`_transform_n1_7_image_for_vlm_albumentations`.
|
|
|
|
|
"""
|
|
|
|
|
if image_target_size is None:
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
target_h, target_w = image_target_size
|
|
|
|
|
_, height, width = image.shape
|
|
|
|
|
|
|
|
|
|
square_edge = max(height, width)
|
|
|
|
|
if height != width:
|
|
|
|
|
left = (square_edge - width) // 2
|
|
|
|
|
top = (square_edge - height) // 2
|
|
|
|
|
image = tv_functional.pad(
|
|
|
|
|
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
resize_edge = shortest_image_edge or target_h
|
|
|
|
|
image = tv_functional.resize(
|
|
|
|
|
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if crop_fraction is None and image_crop_size is not None:
|
|
|
|
|
crop_fraction = image_crop_size[0] / float(target_h)
|
|
|
|
|
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
|
|
|
|
# Match the PIL helper's center crop exactly: round() the crop size but
|
|
|
|
|
# floor() the offset (torchvision.center_crop rounds the offset, which
|
|
|
|
|
# shifts the region by 1px when (edge - crop) is odd).
|
|
|
|
|
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
|
|
|
|
|
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
|
|
|
|
|
top = max(0, (image.shape[-2] - crop_h) // 2)
|
|
|
|
|
left = max(0, (image.shape[-1] - crop_w) // 2)
|
|
|
|
|
image = image[..., top : top + crop_h, left : left + crop_w]
|
|
|
|
|
|
|
|
|
|
if tuple(image.shape[-2:]) != (target_h, target_w):
|
|
|
|
|
image = tv_functional.resize(
|
|
|
|
|
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
|
|
|
|
|
)
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1280,6 +1338,12 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|
|
|
|
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
|
|
|
|
|
an image item in the same chat message so the resulting image tokens match
|
|
|
|
|
the temporal VLM packing used by Isaac-GR00T.
|
|
|
|
|
|
|
|
|
|
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
|
|
|
|
|
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
|
|
|
|
|
CUDA device, the resize/rescale/normalize/patchify run there. This keeps the
|
|
|
|
|
output bit-identical on CPU and moves the dominant preprocessing cost off
|
|
|
|
|
the critical path on GPU.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
model_name: str = GROOT_N1_7_BACKBONE_MODEL
|
|
|
|
@@ -1288,6 +1352,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|
|
|
|
shortest_image_edge: int | None = None
|
|
|
|
|
crop_fraction: float | None = None
|
|
|
|
|
use_albumentations: bool = False
|
|
|
|
|
device: str | None = None
|
|
|
|
|
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@@ -1296,6 +1361,69 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|
|
|
|
self._proc = _build_n1_7_processor(self.model_name)
|
|
|
|
|
return self._proc
|
|
|
|
|
|
|
|
|
|
def _target_device(self) -> torch.device | None:
|
|
|
|
|
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
|
|
|
|
|
if self.device is None or self.use_albumentations:
|
|
|
|
|
return None
|
|
|
|
|
try:
|
|
|
|
|
return get_safe_torch_device(self.device)
|
|
|
|
|
except (AssertionError, RuntimeError):
|
|
|
|
|
# A device serialized at train time (e.g. "cuda") may be unavailable
|
|
|
|
|
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
|
|
|
|
|
# this step is not in the standard device-override set. Fall back to
|
|
|
|
|
# the CPU path, which is bit-identical, instead of crashing.
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _build_sample_images(
|
|
|
|
|
self, video: Any, batch_size: int, target_device: torch.device | None
|
|
|
|
|
) -> list[list[Any]]:
|
|
|
|
|
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
|
|
|
|
|
|
|
|
|
|
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
|
|
|
|
|
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
|
|
|
|
|
``target_device`` when set) for the torchvision-backed Qwen processor.
|
|
|
|
|
"""
|
|
|
|
|
if self.use_albumentations:
|
|
|
|
|
video_np = np.asarray(video)
|
|
|
|
|
return [
|
|
|
|
|
[
|
|
|
|
|
_transform_n1_7_image_for_vlm_albumentations(
|
|
|
|
|
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
|
|
|
|
|
image_crop_size=self.image_crop_size,
|
|
|
|
|
image_target_size=self.image_target_size,
|
|
|
|
|
shortest_image_edge=self.shortest_image_edge,
|
|
|
|
|
crop_fraction=self.crop_fraction,
|
|
|
|
|
)
|
|
|
|
|
for timestep in range(video_np.shape[1])
|
|
|
|
|
for view_idx in range(video_np.shape[2])
|
|
|
|
|
]
|
|
|
|
|
for batch_idx in range(batch_size)
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
|
|
|
|
|
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
|
|
|
|
|
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
|
|
|
|
|
if target_device is not None and video_t.device != target_device:
|
|
|
|
|
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
|
|
|
|
|
|
|
|
|
|
frames_per_sample: list[list[Any]] = []
|
|
|
|
|
for batch_idx in range(batch_size):
|
|
|
|
|
sample = video_t[batch_idx] # (T, V, C, H, W)
|
|
|
|
|
frames_per_sample.append(
|
|
|
|
|
[
|
|
|
|
|
_transform_n1_7_image_for_vlm_torch(
|
|
|
|
|
sample[timestep, view_idx],
|
|
|
|
|
image_crop_size=self.image_crop_size,
|
|
|
|
|
image_target_size=self.image_target_size,
|
|
|
|
|
shortest_image_edge=self.shortest_image_edge,
|
|
|
|
|
crop_fraction=self.crop_fraction,
|
|
|
|
|
)
|
|
|
|
|
for timestep in range(sample.shape[0])
|
|
|
|
|
for view_idx in range(sample.shape[1])
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
return frames_per_sample
|
|
|
|
|
|
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
|
|
|
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
|
|
|
|
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
|
|
|
@@ -1303,33 +1431,25 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|
|
|
|
if video is None:
|
|
|
|
|
return transition
|
|
|
|
|
|
|
|
|
|
batch_size = int(video.shape[0])
|
|
|
|
|
languages = _prepare_n1_7_language_batch(
|
|
|
|
|
comp.get("language"),
|
|
|
|
|
video.shape[0],
|
|
|
|
|
batch_size,
|
|
|
|
|
formalize_language=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
target_device = self._target_device()
|
|
|
|
|
sample_images = self._build_sample_images(video, batch_size, target_device)
|
|
|
|
|
|
|
|
|
|
texts: list[str] = []
|
|
|
|
|
images: list[Image.Image] = []
|
|
|
|
|
for batch_idx in range(video.shape[0]):
|
|
|
|
|
sample = video[batch_idx] # (T, V, H, W, C)
|
|
|
|
|
sample_images = [
|
|
|
|
|
_transform_n1_7_image_for_vlm(
|
|
|
|
|
Image.fromarray(sample[timestep, view_idx]),
|
|
|
|
|
image_crop_size=self.image_crop_size,
|
|
|
|
|
image_target_size=self.image_target_size,
|
|
|
|
|
shortest_image_edge=self.shortest_image_edge,
|
|
|
|
|
crop_fraction=self.crop_fraction,
|
|
|
|
|
use_albumentations=self.use_albumentations,
|
|
|
|
|
)
|
|
|
|
|
for timestep in range(sample.shape[0])
|
|
|
|
|
for view_idx in range(sample.shape[1])
|
|
|
|
|
]
|
|
|
|
|
images: list[Any] = []
|
|
|
|
|
for batch_idx in range(batch_size):
|
|
|
|
|
frames = sample_images[batch_idx]
|
|
|
|
|
conversation = [
|
|
|
|
|
{
|
|
|
|
|
"role": "user",
|
|
|
|
|
"content": [
|
|
|
|
|
*[{"type": "image", "image": image} for image in sample_images],
|
|
|
|
|
*[{"type": "image", "image": image} for image in frames],
|
|
|
|
|
{"type": "text", "text": languages[batch_idx]},
|
|
|
|
|
],
|
|
|
|
|
}
|
|
|
|
@@ -1341,9 +1461,17 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|
|
|
|
add_generation_prompt=False,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
images.extend(sample_images)
|
|
|
|
|
images.extend(frames)
|
|
|
|
|
|
|
|
|
|
encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
|
|
|
|
|
proc_kwargs: dict[str, Any] = {
|
|
|
|
|
"text": texts,
|
|
|
|
|
"images": images,
|
|
|
|
|
"return_tensors": "pt",
|
|
|
|
|
"padding": True,
|
|
|
|
|
}
|
|
|
|
|
if target_device is not None:
|
|
|
|
|
proc_kwargs["device"] = str(target_device)
|
|
|
|
|
encoded = self.proc(**proc_kwargs)
|
|
|
|
|
for key, value in encoded.items():
|
|
|
|
|
comp[key] = value
|
|
|
|
|
obs.pop("video", None)
|
|
|
|
@@ -1362,6 +1490,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|
|
|
|
"shortest_image_edge": self.shortest_image_edge,
|
|
|
|
|
"crop_fraction": self.crop_fraction,
|
|
|
|
|
"use_albumentations": self.use_albumentations,
|
|
|
|
|
"device": self.device,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|