mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix(groot): compatibility fixes for gr00t in v0.5 (#3182)
* fix(groot): apply groot 0.5 fixes * fix(groot): correct indentation and add tile count in Eagle25VL processor * Fixed lint7/style
This commit is contained in:
@@ -204,7 +204,9 @@ class FlowmatchingActionHead(nn.Module):
|
|||||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||||
|
|
||||||
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
|
self._noise_beta_alpha = config.noise_beta_alpha
|
||||||
|
self._noise_beta_beta = config.noise_beta_beta
|
||||||
|
self._beta_dist = None
|
||||||
self.num_timestep_buckets = config.num_timestep_buckets
|
self.num_timestep_buckets = config.num_timestep_buckets
|
||||||
self.config = config
|
self.config = config
|
||||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||||
@@ -249,7 +251,9 @@ class FlowmatchingActionHead(nn.Module):
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
def sample_time(self, batch_size, device, dtype):
|
def sample_time(self, batch_size, device, dtype):
|
||||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
if self._beta_dist is None:
|
||||||
|
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
|
||||||
|
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||||
return (self.config.noise_s - sample) / self.config.noise_s
|
return (self.config.noise_s - sample) / self.config.noise_s
|
||||||
|
|
||||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||||
|
|||||||
@@ -222,6 +222,13 @@ class Eagle25VLProcessor(ProcessorMixin):
|
|||||||
videos=None,
|
videos=None,
|
||||||
**output_kwargs["images_kwargs"],
|
**output_kwargs["images_kwargs"],
|
||||||
)
|
)
|
||||||
|
if isinstance(image_inputs["pixel_values"], list):
|
||||||
|
_pv = image_inputs["pixel_values"]
|
||||||
|
if _pv and isinstance(_pv[0], list):
|
||||||
|
_pv = [t for sub in _pv for t in sub]
|
||||||
|
image_inputs["pixel_values"] = torch.stack(
|
||||||
|
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||||
|
)
|
||||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||||
unified_frame_list.append(image_inputs)
|
unified_frame_list.append(image_inputs)
|
||||||
@@ -233,6 +240,13 @@ class Eagle25VLProcessor(ProcessorMixin):
|
|||||||
videos=[video_list[idx_in_list]],
|
videos=[video_list[idx_in_list]],
|
||||||
**output_kwargs["videos_kwargs"],
|
**output_kwargs["videos_kwargs"],
|
||||||
)
|
)
|
||||||
|
if isinstance(video_inputs["pixel_values"], list):
|
||||||
|
_pv = video_inputs["pixel_values"]
|
||||||
|
if _pv and isinstance(_pv[0], list):
|
||||||
|
_pv = [t for sub in _pv for t in sub]
|
||||||
|
video_inputs["pixel_values"] = torch.stack(
|
||||||
|
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||||
|
)
|
||||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||||
image_sizes = video_inputs["image_sizes"]
|
image_sizes = video_inputs["image_sizes"]
|
||||||
if timestamps_list is not None and -1 not in timestamps_list:
|
if timestamps_list is not None and -1 not in timestamps_list:
|
||||||
@@ -288,8 +302,18 @@ class Eagle25VLProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
text = replace_in_text(text)
|
text = replace_in_text(text)
|
||||||
if len(unified_frame_list) > 0:
|
if len(unified_frame_list) > 0:
|
||||||
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
|
|
||||||
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
|
def _to_tensor(v):
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
return v
|
||||||
|
if isinstance(v, list):
|
||||||
|
if v and isinstance(v[0], list):
|
||||||
|
v = [t for sub in v for t in sub]
|
||||||
|
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
|
||||||
|
return torch.as_tensor(v)
|
||||||
|
|
||||||
|
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
|
||||||
|
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
|
||||||
else:
|
else:
|
||||||
pixel_values = None
|
pixel_values = None
|
||||||
image_sizes = None
|
image_sizes = None
|
||||||
|
|||||||
@@ -221,6 +221,7 @@ class GR00TN15(PreTrainedModel):
|
|||||||
self.action_horizon = config.action_horizon
|
self.action_horizon = config.action_horizon
|
||||||
self.action_dim = config.action_dim
|
self.action_dim = config.action_dim
|
||||||
self.compute_dtype = config.compute_dtype
|
self.compute_dtype = config.compute_dtype
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
def validate_inputs(self, inputs):
|
def validate_inputs(self, inputs):
|
||||||
# NOTE -- this should be handled internally by the model
|
# NOTE -- this should be handled internally by the model
|
||||||
|
|||||||
Reference in New Issue
Block a user