From b3e76a92f2e520bf15fd461fae165493846006d6 Mon Sep 17 00:00:00 2001 From: Matteo Tiezzi Date: Tue, 14 Apr 2026 13:09:18 +0200 Subject: [PATCH] fix(groot): compatibility fixes for gr00t in v0.5 (#3182) * fix(groot): apply groot 0.5 fixes * fix(groot): correct indentation and add tile count in Eagle25VL processor * Fixed lint7/style --- .../action_head/flow_matching_action_head.py | 8 ++++-- .../eagle2_hg_model/processing_eagle2_5_vl.py | 28 +++++++++++++++++-- src/lerobot/policies/groot/groot_n1.py | 1 + 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py index 4fda21ca5..2c1ca6014 100644 --- a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py +++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py @@ -204,7 +204,9 @@ class FlowmatchingActionHead(nn.Module): self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim) nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) - self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta) + self._noise_beta_alpha = config.noise_beta_alpha + self._noise_beta_beta = config.noise_beta_beta + self._beta_dist = None self.num_timestep_buckets = config.num_timestep_buckets self.config = config self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model) @@ -249,7 +251,9 @@ class FlowmatchingActionHead(nn.Module): self.model.eval() def sample_time(self, batch_size, device, dtype): - sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype) + if self._beta_dist is None: + self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False) + sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype) return (self.config.noise_s - sample) / self.config.noise_s def prepare_input(self, batch: dict) -> BatchFeature: diff --git a/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py index 27f9b3345..7b1f67fef 100755 --- a/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py +++ b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py @@ -222,6 +222,13 @@ class Eagle25VLProcessor(ProcessorMixin): videos=None, **output_kwargs["images_kwargs"], ) + if isinstance(image_inputs["pixel_values"], list): + _pv = image_inputs["pixel_values"] + if _pv and isinstance(_pv[0], list): + _pv = [t for sub in _pv for t in sub] + image_inputs["pixel_values"] = torch.stack( + [t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv] + ) num_all_tiles = image_inputs["pixel_values"].shape[0] special_placeholder = f"{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}" unified_frame_list.append(image_inputs) @@ -233,6 +240,13 @@ class Eagle25VLProcessor(ProcessorMixin): videos=[video_list[idx_in_list]], **output_kwargs["videos_kwargs"], ) + if isinstance(video_inputs["pixel_values"], list): + _pv = video_inputs["pixel_values"] + if _pv and isinstance(_pv[0], list): + _pv = [t for sub in _pv for t in sub] + video_inputs["pixel_values"] = torch.stack( + [t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv] + ) num_all_tiles = video_inputs["pixel_values"].shape[0] image_sizes = video_inputs["image_sizes"] if timestamps_list is not None and -1 not in timestamps_list: @@ -288,8 +302,18 @@ class Eagle25VLProcessor(ProcessorMixin): text = replace_in_text(text) if len(unified_frame_list) > 0: - pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list]) - image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list]) + + def _to_tensor(v): + if isinstance(v, torch.Tensor): + return v + if isinstance(v, list): + if v and isinstance(v[0], list): + v = [t for sub in v for t in sub] + return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v]) + return torch.as_tensor(v) + + pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list]) + image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list]) else: pixel_values = None image_sizes = None diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index fc753839a..abcbb8a8c 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -221,6 +221,7 @@ class GR00TN15(PreTrainedModel): self.action_horizon = config.action_horizon self.action_dim = config.action_dim self.compute_dtype = config.compute_dtype + self.post_init() def validate_inputs(self, inputs): # NOTE -- this should be handled internally by the model