fix(groot): compatibility fixes for gr00t in v0.5 (#3182)

* fix(groot): apply groot 0.5 fixes

* fix(groot): correct indentation and add tile count in Eagle25VL processor

* Fixed lint7/style
This commit is contained in:
Matteo Tiezzi
2026-04-14 13:09:18 +02:00
committed by GitHub
parent f5c801fd34
commit b3e76a92f2
3 changed files with 33 additions and 4 deletions
@@ -204,7 +204,9 @@ class FlowmatchingActionHead(nn.Module):
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim) self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta) self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets self.num_timestep_buckets = config.num_timestep_buckets
self.config = config self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model) self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
@@ -249,7 +251,9 @@ class FlowmatchingActionHead(nn.Module):
self.model.eval() self.model.eval()
def sample_time(self, batch_size, device, dtype): def sample_time(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype) if self._beta_dist is None:
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature: def prepare_input(self, batch: dict) -> BatchFeature:
@@ -222,6 +222,13 @@ class Eagle25VLProcessor(ProcessorMixin):
videos=None, videos=None,
**output_kwargs["images_kwargs"], **output_kwargs["images_kwargs"],
) )
if isinstance(image_inputs["pixel_values"], list):
_pv = image_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
image_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = image_inputs["pixel_values"].shape[0] num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}" special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs) unified_frame_list.append(image_inputs)
@@ -233,6 +240,13 @@ class Eagle25VLProcessor(ProcessorMixin):
videos=[video_list[idx_in_list]], videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"], **output_kwargs["videos_kwargs"],
) )
if isinstance(video_inputs["pixel_values"], list):
_pv = video_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
video_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = video_inputs["pixel_values"].shape[0] num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"] image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list: if timestamps_list is not None and -1 not in timestamps_list:
@@ -288,8 +302,18 @@ class Eagle25VLProcessor(ProcessorMixin):
text = replace_in_text(text) text = replace_in_text(text)
if len(unified_frame_list) > 0: if len(unified_frame_list) > 0:
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list]) def _to_tensor(v):
if isinstance(v, torch.Tensor):
return v
if isinstance(v, list):
if v and isinstance(v[0], list):
v = [t for sub in v for t in sub]
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
return torch.as_tensor(v)
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
else: else:
pixel_values = None pixel_values = None
image_sizes = None image_sizes = None
+1
View File
@@ -221,6 +221,7 @@ class GR00TN15(PreTrainedModel):
self.action_horizon = config.action_horizon self.action_horizon = config.action_horizon
self.action_dim = config.action_dim self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype self.compute_dtype = config.compute_dtype
self.post_init()
def validate_inputs(self, inputs): def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model # NOTE -- this should be handled internally by the model