Merge branch 'main' into feat/decouple_record_script

This commit is contained in:
Steven Palma
2026-04-14 21:35:26 +02:00
committed by GitHub
7 changed files with 98 additions and 76 deletions
+2 -2
View File
@@ -30,13 +30,13 @@ def safe_stop_image_writer(func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except Exception as e: except BaseException:
dataset = kwargs.get("dataset") dataset = kwargs.get("dataset")
writer = getattr(dataset, "writer", None) if dataset else None writer = getattr(dataset, "writer", None) if dataset else None
if writer is not None and writer.image_writer is not None: if writer is not None and writer.image_writer is not None:
logger.warning("Waiting for image writer to terminate...") logger.warning("Waiting for image writer to terminate...")
writer.image_writer.stop() writer.image_writer.stop()
raise e raise
return wrapper return wrapper
@@ -204,7 +204,9 @@ class FlowmatchingActionHead(nn.Module):
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim) self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta) self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets self.num_timestep_buckets = config.num_timestep_buckets
self.config = config self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model) self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
@@ -249,7 +251,9 @@ class FlowmatchingActionHead(nn.Module):
self.model.eval() self.model.eval()
def sample_time(self, batch_size, device, dtype): def sample_time(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype) if self._beta_dist is None:
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature: def prepare_input(self, batch: dict) -> BatchFeature:
@@ -222,6 +222,13 @@ class Eagle25VLProcessor(ProcessorMixin):
videos=None, videos=None,
**output_kwargs["images_kwargs"], **output_kwargs["images_kwargs"],
) )
if isinstance(image_inputs["pixel_values"], list):
_pv = image_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
image_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = image_inputs["pixel_values"].shape[0] num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}" special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs) unified_frame_list.append(image_inputs)
@@ -233,6 +240,13 @@ class Eagle25VLProcessor(ProcessorMixin):
videos=[video_list[idx_in_list]], videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"], **output_kwargs["videos_kwargs"],
) )
if isinstance(video_inputs["pixel_values"], list):
_pv = video_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
video_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = video_inputs["pixel_values"].shape[0] num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"] image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list: if timestamps_list is not None and -1 not in timestamps_list:
@@ -288,8 +302,18 @@ class Eagle25VLProcessor(ProcessorMixin):
text = replace_in_text(text) text = replace_in_text(text)
if len(unified_frame_list) > 0: if len(unified_frame_list) > 0:
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list]) def _to_tensor(v):
if isinstance(v, torch.Tensor):
return v
if isinstance(v, list):
if v and isinstance(v[0], list):
v = [t for sub in v for t in sub]
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
return torch.as_tensor(v)
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
else: else:
pixel_values = None pixel_values = None
image_sizes = None image_sizes = None
+1
View File
@@ -221,6 +221,7 @@ class GR00TN15(PreTrainedModel):
self.action_horizon = config.action_horizon self.action_horizon = config.action_horizon
self.action_dim = config.action_dim self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype self.compute_dtype = config.compute_dtype
self.post_init()
def validate_inputs(self, inputs): def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model # NOTE -- this should be handled internally by the model
+54 -58
View File
@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import functools import functools
import threading
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from contextlib import suppress from contextlib import suppress
from typing import TypedDict from typing import TypedDict
@@ -115,6 +116,7 @@ class ReplayBuffer:
self.size = 0 self.size = 0
self.initialized = False self.initialized = False
self.optimize_memory = optimize_memory self.optimize_memory = optimize_memory
self._lock = threading.Lock()
# Track episode boundaries for memory optimization # Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device) self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
@@ -198,68 +200,75 @@ class ReplayBuffer:
complementary_info: dict[str, torch.Tensor] | None = None, complementary_info: dict[str, torch.Tensor] | None = None,
): ):
"""Saves a transition, ensuring tensors are stored on the designated storage device.""" """Saves a transition, ensuring tensors are stored on the designated storage device."""
# Initialize storage if this is the first transition with self._lock:
if not self.initialized: # Initialize storage if this is the first transition
self._initialize_storage(state=state, action=action, complementary_info=complementary_info) if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Store the transition in pre-allocated tensors # Store the transition in pre-allocated tensors
for key in self.states: for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0)) self.states[key][self.position].copy_(state[key].squeeze(dim=0))
if not self.optimize_memory: if not self.optimize_memory:
# Only store next_states if not optimizing memory # Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
self.actions[self.position].copy_(action.squeeze(dim=0)) self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward self.rewards[self.position] = reward
self.dones[self.position] = done self.dones[self.position] = done
self.truncateds[self.position] = truncated self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized # Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info: if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info for key in self.complementary_info_keys:
for key in self.complementary_info_keys: if key in complementary_info:
if key in complementary_info: value = complementary_info[key]
value = complementary_info[key] if isinstance(value, torch.Tensor):
if isinstance(value, torch.Tensor): self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) elif isinstance(value, (int | float)):
elif isinstance(value, (int | float)): self.complementary_info[key][self.position] = value
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity) self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size: int) -> BatchTransition: def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors.""" """Sample a random batch of transitions and collate them into batched tensors."""
if not self.initialized: if not self.initialized:
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
batch_size = min(batch_size, self.size) with self._lock:
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
# Random indices for sampling - create on the same device as storage idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
# Create batched state and next_state batch_state = {}
batch_state = {} batch_next_state = {}
batch_next_state = {}
# First pass: load all state tensors to target device for key in self.states:
for key in self.states: batch_state[key] = self.states[key][idx].to(self.device)
batch_state[key] = self.states[key][idx].to(self.device)
if not self.optimize_memory: if not self.optimize_memory:
# Standard approach - load next_states directly batch_next_state[key] = self.next_states[key][idx].to(self.device)
batch_next_state[key] = self.next_states[key][idx].to(self.device) else:
else: next_idx = (idx + 1) % self.capacity
# Memory-optimized approach - get next_state from the next index batch_next_state[key] = self.states[key][next_idx].to(self.device)
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device) # Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
# Apply image augmentation in a batched way if needed
if self.use_drq and image_keys: if self.use_drq and image_keys:
# Concatenate all images from state and next_state # Concatenate all images from state and next_state
all_images = [] all_images = []
@@ -280,19 +289,6 @@ class ReplayBuffer:
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition( return BatchTransition(
state=batch_state, state=batch_state,
action=batch_actions, action=batch_actions,
+2 -2
View File
@@ -551,8 +551,8 @@ def step_env_and_process_transition(
terminated = terminated or processed_action_transition[TransitionKey.DONE] terminated = terminated or processed_action_transition[TransitionKey.DONE]
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED] truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy() complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
new_info = processed_action_transition[TransitionKey.INFO].copy() new_info = info.copy()
new_info.update(info) new_info.update(processed_action_transition[TransitionKey.INFO])
new_transition = create_transition( new_transition = create_transition(
observation=obs, observation=obs,
@@ -147,6 +147,7 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
) )
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train() policy.train()
# Use preprocessor to handle tokenization # Use preprocessor to handle tokenization
@@ -336,6 +337,7 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac
) )
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.eval() policy.eval()
policy.reset() # Reset queues before inference policy.reset() # Reset queues before inference
@@ -390,6 +392,7 @@ def test_multi_task_dit_policy_diffusion_objective():
config.validate_features() config.validate_features()
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train() policy.train()
# Use preprocessor to handle tokenization # Use preprocessor to handle tokenization
@@ -468,6 +471,7 @@ def test_multi_task_dit_policy_flow_matching_objective():
config.validate_features() config.validate_features()
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train() policy.train()
# Use preprocessor to handle tokenization # Use preprocessor to handle tokenization
@@ -533,16 +537,12 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
) )
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.eval() policy.eval()
# Get device before saving
device = next(policy.parameters()).device
policy.save_pretrained(root) policy.save_pretrained(root)
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config) loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
loaded_policy.to(config.device)
# Explicitly move loaded_policy to the same device
loaded_policy.to(device)
loaded_policy.eval() loaded_policy.eval()
batch = create_train_batch( batch = create_train_batch(
@@ -565,10 +565,6 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
with seeded_context(12): with seeded_context(12):
# Process batch through preprocessor # Process batch through preprocessor
processed_batch = preprocessor(batch) processed_batch = preprocessor(batch)
# Move batch to the same device as the policy
for key in processed_batch:
if isinstance(processed_batch[key], torch.Tensor):
processed_batch[key] = processed_batch[key].to(device)
# Collect policy values before saving # Collect policy values before saving
loss, _ = policy.forward(processed_batch) loss, _ = policy.forward(processed_batch)
@@ -608,6 +604,7 @@ def test_multi_task_dit_policy_get_optim_params():
) )
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
param_groups = policy.get_optim_params() param_groups = policy.get_optim_params()
# Should have 2 parameter groups: non-vision and vision encoder # Should have 2 parameter groups: non-vision and vision encoder