diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py
index 603067757..8fb5804a5 100644
--- a/src/lerobot/datasets/image_writer.py
+++ b/src/lerobot/datasets/image_writer.py
@@ -30,13 +30,13 @@ def safe_stop_image_writer(func):
     def wrapper(*args, **kwargs):
         try:
             return func(*args, **kwargs)
-        except Exception as e:
+        except BaseException:
             dataset = kwargs.get("dataset")
             writer = getattr(dataset, "writer", None) if dataset else None
             if writer is not None and writer.image_writer is not None:
                 logger.warning("Waiting for image writer to terminate...")
                 writer.image_writer.stop()
-            raise e
+            raise

     return wrapper
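Note on the `image_writer.py` hunk: `except Exception` does not catch `KeyboardInterrupt` or `SystemExit`, so a Ctrl-C during recording could skip the writer shutdown and leave worker threads alive; `except BaseException` runs the cleanup in those cases too. The bare `raise` also preserves the original traceback, where `raise e` would re-anchor it at this frame. A minimal self-contained sketch of the pattern (hypothetical `safe_cleanup` and `record_frames`, not lerobot APIs):

```python
import functools
import logging

logger = logging.getLogger(__name__)


def safe_cleanup(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except BaseException:
            # `except Exception` would miss KeyboardInterrupt/SystemExit,
            # so Ctrl-C would leave background writers running.
            logger.warning("Cleaning up before re-raising...")
            # Bare `raise` re-raises the active exception with its original
            # traceback; `raise e` would rebind the traceback to this frame.
            raise

    return wrapper


@safe_cleanup
def record_frames():
    raise KeyboardInterrupt  # the cleanup branch above still runs
```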
diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
index 4fda21ca5..2c1ca6014 100644
--- a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
+++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
@@ -204,7 +204,9 @@ class FlowmatchingActionHead(nn.Module):
         self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
         nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

-        self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
+        self._noise_beta_alpha = config.noise_beta_alpha
+        self._noise_beta_beta = config.noise_beta_beta
+        self._beta_dist = None
         self.num_timestep_buckets = config.num_timestep_buckets
         self.config = config
         self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
@@ -249,7 +251,9 @@ class FlowmatchingActionHead(nn.Module):
         self.model.eval()

     def sample_time(self, batch_size, device, dtype):
-        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
+        if self._beta_dist is None:
+            self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
+        sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
         return (self.config.noise_s - sample) / self.config.noise_s

     def prepare_input(self, batch: dict) -> BatchFeature:
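Note on the `flow_matching_action_head.py` hunks: a `torch.distributions.Beta` instance is a plain Python object, not a registered submodule or buffer, so one built in `__init__` is invisible to `module.to(...)`, `state_dict()`, and meta-device initialization. Deferring construction to the first `sample_time` call sidesteps that, and `validate_args=False` skips per-call argument validation. A minimal sketch of the lazy pattern (hypothetical `NoiseSampler`, not the actual action head):

```python
import torch
from torch import nn
from torch.distributions import Beta


class NoiseSampler(nn.Module):
    def __init__(self, alpha: float, beta: float):
        super().__init__()
        # Only plain floats are stored here; building the distribution now
        # would pin tensors that .to()/meta-device init never move.
        self._alpha, self._beta = alpha, beta
        self._dist = None

    def sample(self, batch_size: int, device, dtype):
        if self._dist is None:  # built on first use, after module placement
            self._dist = Beta(self._alpha, self._beta, validate_args=False)
        return self._dist.sample((batch_size,)).to(device, dtype=dtype)


sampler = NoiseSampler(alpha=1.5, beta=1.0)
t = sampler.sample(4, device="cpu", dtype=torch.float32)  # shape (4,)
```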
diff --git a/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py
index 27f9b3345..7b1f67fef 100755
--- a/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py
+++ b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py
@@ -222,6 +222,13 @@ class Eagle25VLProcessor(ProcessorMixin):
                     videos=None,
                     **output_kwargs["images_kwargs"],
                 )
+                if isinstance(image_inputs["pixel_values"], list):
+                    _pv = image_inputs["pixel_values"]
+                    if _pv and isinstance(_pv[0], list):
+                        _pv = [t for sub in _pv for t in sub]
+                    image_inputs["pixel_values"] = torch.stack(
+                        [t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
+                    )
                 num_all_tiles = image_inputs["pixel_values"].shape[0]
                 special_placeholder = f"{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
                 unified_frame_list.append(image_inputs)
@@ -233,6 +240,13 @@ class Eagle25VLProcessor(ProcessorMixin):
                     videos=[video_list[idx_in_list]],
                     **output_kwargs["videos_kwargs"],
                 )
+                if isinstance(video_inputs["pixel_values"], list):
+                    _pv = video_inputs["pixel_values"]
+                    if _pv and isinstance(_pv[0], list):
+                        _pv = [t for sub in _pv for t in sub]
+                    video_inputs["pixel_values"] = torch.stack(
+                        [t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
+                    )
                 num_all_tiles = video_inputs["pixel_values"].shape[0]
                 image_sizes = video_inputs["image_sizes"]
                 if timestamps_list is not None and -1 not in timestamps_list:
@@ -288,8 +302,18 @@ class Eagle25VLProcessor(ProcessorMixin):
            text = replace_in_text(text)

        if len(unified_frame_list) > 0:
-            pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
-            image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
+
+            def _to_tensor(v):
+                if isinstance(v, torch.Tensor):
+                    return v
+                if isinstance(v, list):
+                    if v and isinstance(v[0], list):
+                        v = [t for sub in v for t in sub]
+                    return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
+                return torch.as_tensor(v)
+
+            pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
+            image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
        else:
            pixel_values = None
            image_sizes = None
diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py
index fc753839a..abcbb8a8c 100644
--- a/src/lerobot/policies/groot/groot_n1.py
+++ b/src/lerobot/policies/groot/groot_n1.py
@@ -221,6 +221,7 @@ class GR00TN15(PreTrainedModel):
         self.action_horizon = config.action_horizon
         self.action_dim = config.action_dim
         self.compute_dtype = config.compute_dtype
+        self.post_init()

     def validate_inputs(self, inputs):
         # NOTE -- this should be handled internally by the model
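Note on the `processing_eagle2_5_vl.py` hunks: depending on the `transformers` version, an image processor can return `pixel_values` as a (possibly nested) list of per-tile tensors rather than one stacked tensor, which breaks the `.shape[0]` tile count and the later `torch.cat`. The three added guards all normalize that case the same way. A self-contained sketch of the normalization (the 448x448 tile size is an illustrative assumption, not taken from this PR):

```python
import torch


def to_stacked_tensor(pv):
    """Normalize pixel_values to a single stacked tensor."""
    if isinstance(pv, torch.Tensor):
        return pv
    if isinstance(pv, list):
        if pv and isinstance(pv[0], list):  # flatten one nesting level
            pv = [t for sub in pv for t in sub]
        # torch.stack requires equal shapes, which holds for same-size tiles
        return torch.stack(
            [t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in pv]
        )
    return torch.as_tensor(pv)


# e.g. three 3x448x448 tiles -> one (3, 3, 448, 448) tensor
tiles = [torch.rand(3, 448, 448) for _ in range(3)]
assert to_stacked_tensor(tiles).shape == (3, 3, 448, 448)
```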
diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py
index 97aaa9caa..05b8419bd 100644
--- a/src/lerobot/rl/buffer.py
+++ b/src/lerobot/rl/buffer.py
@@ -15,6 +15,7 @@
 # limitations under the License.

 import functools
+import threading
 from collections.abc import Callable, Sequence
 from contextlib import suppress
 from typing import TypedDict
@@ -115,6 +116,7 @@ class ReplayBuffer:
         self.size = 0
         self.initialized = False
         self.optimize_memory = optimize_memory
+        self._lock = threading.Lock()

         # Track episode boundaries for memory optimization
         self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
@@ -198,68 +200,75 @@ class ReplayBuffer:
         complementary_info: dict[str, torch.Tensor] | None = None,
     ):
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
-        # Initialize storage if this is the first transition
-        if not self.initialized:
-            self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
+        with self._lock:
+            # Initialize storage if this is the first transition
+            if not self.initialized:
+                self._initialize_storage(state=state, action=action, complementary_info=complementary_info)

-        # Store the transition in pre-allocated tensors
-        for key in self.states:
-            self.states[key][self.position].copy_(state[key].squeeze(dim=0))
+            # Store the transition in pre-allocated tensors
+            for key in self.states:
+                self.states[key][self.position].copy_(state[key].squeeze(dim=0))

-            if not self.optimize_memory:
-                # Only store next_states if not optimizing memory
-                self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
+                if not self.optimize_memory:
+                    # Only store next_states if not optimizing memory
+                    self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))

-        self.actions[self.position].copy_(action.squeeze(dim=0))
-        self.rewards[self.position] = reward
-        self.dones[self.position] = done
-        self.truncateds[self.position] = truncated
+            self.actions[self.position].copy_(action.squeeze(dim=0))
+            self.rewards[self.position] = reward
+            self.dones[self.position] = done
+            self.truncateds[self.position] = truncated

-        # Handle complementary_info if provided and storage is initialized
-        if complementary_info is not None and self.has_complementary_info:
-            # Store the complementary_info
-            for key in self.complementary_info_keys:
-                if key in complementary_info:
-                    value = complementary_info[key]
-                    if isinstance(value, torch.Tensor):
-                        self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
-                    elif isinstance(value, (int | float)):
-                        self.complementary_info[key][self.position] = value
+            # Handle complementary_info if provided and storage is initialized
+            if complementary_info is not None and self.has_complementary_info:
+                for key in self.complementary_info_keys:
+                    if key in complementary_info:
+                        value = complementary_info[key]
+                        if isinstance(value, torch.Tensor):
+                            self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
+                        elif isinstance(value, (int | float)):
+                            self.complementary_info[key][self.position] = value

-        self.position = (self.position + 1) % self.capacity
-        self.size = min(self.size + 1, self.capacity)
+            self.position = (self.position + 1) % self.capacity
+            self.size = min(self.size + 1, self.capacity)

     def sample(self, batch_size: int) -> BatchTransition:
         """Sample a random batch of transitions and collate them into batched tensors."""
         if not self.initialized:
             raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")

-        batch_size = min(batch_size, self.size)
-        high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
+        with self._lock:
+            batch_size = min(batch_size, self.size)
+            high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size

-        # Random indices for sampling - create on the same device as storage
-        idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
+            idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)

-        # Identify image keys that need augmentation
-        image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
+            image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []

-        # Create batched state and next_state
-        batch_state = {}
-        batch_next_state = {}
+            batch_state = {}
+            batch_next_state = {}

-        # First pass: load all state tensors to target device
-        for key in self.states:
-            batch_state[key] = self.states[key][idx].to(self.device)
+            for key in self.states:
+                batch_state[key] = self.states[key][idx].to(self.device)

-            if not self.optimize_memory:
-                # Standard approach - load next_states directly
-                batch_next_state[key] = self.next_states[key][idx].to(self.device)
-            else:
-                # Memory-optimized approach - get next_state from the next index
-                next_idx = (idx + 1) % self.capacity
-                batch_next_state[key] = self.states[key][next_idx].to(self.device)
+                if not self.optimize_memory:
+                    batch_next_state[key] = self.next_states[key][idx].to(self.device)
+                else:
+                    next_idx = (idx + 1) % self.capacity
+                    batch_next_state[key] = self.states[key][next_idx].to(self.device)
+
+            # Sample other tensors
+            batch_actions = self.actions[idx].to(self.device)
+            batch_rewards = self.rewards[idx].to(self.device)
+            batch_dones = self.dones[idx].to(self.device).float()
+            batch_truncateds = self.truncateds[idx].to(self.device).float()
+
+            # Sample complementary_info if available
+            batch_complementary_info = None
+            if self.has_complementary_info:
+                batch_complementary_info = {}
+                for key in self.complementary_info_keys:
+                    batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)

-        # Apply image augmentation in a batched way if needed
         if self.use_drq and image_keys:
             # Concatenate all images from state and next_state
             all_images = []
@@ -280,19 +289,6 @@ class ReplayBuffer:
             # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
             batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]

-        # Sample other tensors
-        batch_actions = self.actions[idx].to(self.device)
-        batch_rewards = self.rewards[idx].to(self.device)
-        batch_dones = self.dones[idx].to(self.device).float()
-        batch_truncateds = self.truncateds[idx].to(self.device).float()
-
-        # Sample complementary_info if available
-        batch_complementary_info = None
-        if self.has_complementary_info:
-            batch_complementary_info = {}
-            for key in self.complementary_info_keys:
-                batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
-
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py
index b6ff7155a..2190070f5 100644
--- a/src/lerobot/rl/gym_manipulator.py
+++ b/src/lerobot/rl/gym_manipulator.py
@@ -551,8 +551,8 @@ def step_env_and_process_transition(
     terminated = terminated or processed_action_transition[TransitionKey.DONE]
     truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
     complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
-    new_info = processed_action_transition[TransitionKey.INFO].copy()
-    new_info.update(info)
+    new_info = info.copy()
+    new_info.update(processed_action_transition[TransitionKey.INFO])

     new_transition = create_transition(
         observation=obs,
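Note on the `buffer.py` and `gym_manipulator.py` hunks: the lock makes `add()`'s write-then-advance of `position`/`size` atomic with respect to `sample()`, which matters when an actor thread records transitions while a learner thread samples; the tensor reads for actions, rewards, dones, and complementary info also move inside the lock, while the DrQ augmentation stays outside it because it only touches already-copied batch tensors. The `gym_manipulator.py` swap flips merge precedence so keys in the processed transition's `INFO` now override the raw env `info` instead of being clobbered by it. A toy lock-protected buffer showing the concurrency pattern (not the full `ReplayBuffer`):

```python
import threading
import torch


class TinyBuffer:
    def __init__(self, capacity: int):
        self.storage = torch.zeros(capacity)
        self.position, self.size, self.capacity = 0, 0, capacity
        self._lock = threading.Lock()

    def add(self, value: float):
        with self._lock:  # write and index advance happen atomically
            self.storage[self.position] = value
            self.position = (self.position + 1) % self.capacity
            self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int):
        with self._lock:  # size cannot change under us mid-sample
            idx = torch.randint(0, self.size, (min(batch_size, self.size),))
            return self.storage[idx]


buf = TinyBuffer(capacity=100)
buf.add(0.0)  # ensure there is something to sample
writer = threading.Thread(target=lambda: [buf.add(float(i)) for i in range(1000)])
writer.start()
print(buf.sample(4))  # safe while the writer thread is still adding
writer.join()
```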
diff --git a/tests/policies/multi_task_dit/test_multi_task_dit.py b/tests/policies/multi_task_dit/test_multi_task_dit.py
index 5b70422d4..e4d456d19 100644
--- a/tests/policies/multi_task_dit/test_multi_task_dit.py
+++ b/tests/policies/multi_task_dit/test_multi_task_dit.py
@@ -147,6 +147,7 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
     )

     policy = MultiTaskDiTPolicy(config=config)
+    policy.to(config.device)
     policy.train()

     # Use preprocessor to handle tokenization
@@ -336,6 +337,7 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac
     )

     policy = MultiTaskDiTPolicy(config=config)
+    policy.to(config.device)
     policy.eval()
     policy.reset()  # Reset queues before inference

@@ -390,6 +392,7 @@ def test_multi_task_dit_policy_diffusion_objective():
     config.validate_features()

     policy = MultiTaskDiTPolicy(config=config)
+    policy.to(config.device)
     policy.train()

     # Use preprocessor to handle tokenization
@@ -468,6 +471,7 @@ def test_multi_task_dit_policy_flow_matching_objective():
     config.validate_features()

     policy = MultiTaskDiTPolicy(config=config)
+    policy.to(config.device)
     policy.train()

     # Use preprocessor to handle tokenization
@@ -533,16 +537,12 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
     )

     policy = MultiTaskDiTPolicy(config=config)
+    policy.to(config.device)
     policy.eval()

-    # Get device before saving
-    device = next(policy.parameters()).device
-
     policy.save_pretrained(root)
     loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
-
-    # Explicitly move loaded_policy to the same device
-    loaded_policy.to(device)
+    loaded_policy.to(config.device)
     loaded_policy.eval()

     batch = create_train_batch(
@@ -565,10 +565,6 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
     with seeded_context(12):
         # Process batch through preprocessor
         processed_batch = preprocessor(batch)
-        # Move batch to the same device as the policy
-        for key in processed_batch:
-            if isinstance(processed_batch[key], torch.Tensor):
-                processed_batch[key] = processed_batch[key].to(device)

         # Collect policy values before saving
         loss, _ = policy.forward(processed_batch)
@@ -608,6 +604,7 @@ def test_multi_task_dit_policy_get_optim_params():
     )

     policy = MultiTaskDiTPolicy(config=config)
+    policy.to(config.device)
     param_groups = policy.get_optim_params()

     # Should have 2 parameter groups: non-vision and vision encoder
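Note on the test changes: every test now constructs the policy and moves it to `config.device` explicitly, and the save/load round-trip restores to `config.device` instead of sniffing the device from `next(policy.parameters())`; the removed per-key `.to(device)` loop suggests the preprocessor already places batch tensors on the configured device. A minimal sketch of the construct-then-move convention (hypothetical `DummyPolicy`, not the actual test fixture):

```python
import torch
from torch import nn


class DummyPolicy(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.net = nn.Linear(dim, dim)


# Mirrors config.device resolution: pick an accelerator when available.
cfg_device = "cuda" if torch.cuda.is_available() else "cpu"

policy = DummyPolicy().to(cfg_device)    # one explicit move after construction
x = torch.rand(2, 8, device=cfg_device)  # inputs created on the same device
assert policy.net(x).device.type == torch.device(cfg_device).type
```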