mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
Merge branch 'main' into feat/decouple_record_script
This commit is contained in:
@@ -30,13 +30,13 @@ def safe_stop_image_writer(func):
|
|||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
except Exception as e:
|
except BaseException:
|
||||||
dataset = kwargs.get("dataset")
|
dataset = kwargs.get("dataset")
|
||||||
writer = getattr(dataset, "writer", None) if dataset else None
|
writer = getattr(dataset, "writer", None) if dataset else None
|
||||||
if writer is not None and writer.image_writer is not None:
|
if writer is not None and writer.image_writer is not None:
|
||||||
logger.warning("Waiting for image writer to terminate...")
|
logger.warning("Waiting for image writer to terminate...")
|
||||||
writer.image_writer.stop()
|
writer.image_writer.stop()
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|||||||
@@ -204,7 +204,9 @@ class FlowmatchingActionHead(nn.Module):
|
|||||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||||
|
|
||||||
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
|
self._noise_beta_alpha = config.noise_beta_alpha
|
||||||
|
self._noise_beta_beta = config.noise_beta_beta
|
||||||
|
self._beta_dist = None
|
||||||
self.num_timestep_buckets = config.num_timestep_buckets
|
self.num_timestep_buckets = config.num_timestep_buckets
|
||||||
self.config = config
|
self.config = config
|
||||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||||
@@ -249,7 +251,9 @@ class FlowmatchingActionHead(nn.Module):
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
def sample_time(self, batch_size, device, dtype):
|
def sample_time(self, batch_size, device, dtype):
|
||||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
if self._beta_dist is None:
|
||||||
|
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
|
||||||
|
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||||
return (self.config.noise_s - sample) / self.config.noise_s
|
return (self.config.noise_s - sample) / self.config.noise_s
|
||||||
|
|
||||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||||
|
|||||||
@@ -222,6 +222,13 @@ class Eagle25VLProcessor(ProcessorMixin):
|
|||||||
videos=None,
|
videos=None,
|
||||||
**output_kwargs["images_kwargs"],
|
**output_kwargs["images_kwargs"],
|
||||||
)
|
)
|
||||||
|
if isinstance(image_inputs["pixel_values"], list):
|
||||||
|
_pv = image_inputs["pixel_values"]
|
||||||
|
if _pv and isinstance(_pv[0], list):
|
||||||
|
_pv = [t for sub in _pv for t in sub]
|
||||||
|
image_inputs["pixel_values"] = torch.stack(
|
||||||
|
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||||
|
)
|
||||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||||
unified_frame_list.append(image_inputs)
|
unified_frame_list.append(image_inputs)
|
||||||
@@ -233,6 +240,13 @@ class Eagle25VLProcessor(ProcessorMixin):
|
|||||||
videos=[video_list[idx_in_list]],
|
videos=[video_list[idx_in_list]],
|
||||||
**output_kwargs["videos_kwargs"],
|
**output_kwargs["videos_kwargs"],
|
||||||
)
|
)
|
||||||
|
if isinstance(video_inputs["pixel_values"], list):
|
||||||
|
_pv = video_inputs["pixel_values"]
|
||||||
|
if _pv and isinstance(_pv[0], list):
|
||||||
|
_pv = [t for sub in _pv for t in sub]
|
||||||
|
video_inputs["pixel_values"] = torch.stack(
|
||||||
|
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||||
|
)
|
||||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||||
image_sizes = video_inputs["image_sizes"]
|
image_sizes = video_inputs["image_sizes"]
|
||||||
if timestamps_list is not None and -1 not in timestamps_list:
|
if timestamps_list is not None and -1 not in timestamps_list:
|
||||||
@@ -288,8 +302,18 @@ class Eagle25VLProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
text = replace_in_text(text)
|
text = replace_in_text(text)
|
||||||
if len(unified_frame_list) > 0:
|
if len(unified_frame_list) > 0:
|
||||||
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
|
|
||||||
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
|
def _to_tensor(v):
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
return v
|
||||||
|
if isinstance(v, list):
|
||||||
|
if v and isinstance(v[0], list):
|
||||||
|
v = [t for sub in v for t in sub]
|
||||||
|
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
|
||||||
|
return torch.as_tensor(v)
|
||||||
|
|
||||||
|
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
|
||||||
|
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
|
||||||
else:
|
else:
|
||||||
pixel_values = None
|
pixel_values = None
|
||||||
image_sizes = None
|
image_sizes = None
|
||||||
|
|||||||
@@ -221,6 +221,7 @@ class GR00TN15(PreTrainedModel):
|
|||||||
self.action_horizon = config.action_horizon
|
self.action_horizon = config.action_horizon
|
||||||
self.action_dim = config.action_dim
|
self.action_dim = config.action_dim
|
||||||
self.compute_dtype = config.compute_dtype
|
self.compute_dtype = config.compute_dtype
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
def validate_inputs(self, inputs):
|
def validate_inputs(self, inputs):
|
||||||
# NOTE -- this should be handled internally by the model
|
# NOTE -- this should be handled internally by the model
|
||||||
|
|||||||
+54
-58
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import threading
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
@@ -115,6 +116,7 @@ class ReplayBuffer:
|
|||||||
self.size = 0
|
self.size = 0
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
self.optimize_memory = optimize_memory
|
self.optimize_memory = optimize_memory
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
# Track episode boundaries for memory optimization
|
# Track episode boundaries for memory optimization
|
||||||
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
||||||
@@ -198,68 +200,75 @@ class ReplayBuffer:
|
|||||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||||
):
|
):
|
||||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||||
# Initialize storage if this is the first transition
|
with self._lock:
|
||||||
if not self.initialized:
|
# Initialize storage if this is the first transition
|
||||||
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
if not self.initialized:
|
||||||
|
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
||||||
|
|
||||||
# Store the transition in pre-allocated tensors
|
# Store the transition in pre-allocated tensors
|
||||||
for key in self.states:
|
for key in self.states:
|
||||||
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
|
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
|
||||||
|
|
||||||
if not self.optimize_memory:
|
if not self.optimize_memory:
|
||||||
# Only store next_states if not optimizing memory
|
# Only store next_states if not optimizing memory
|
||||||
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
|
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
|
||||||
|
|
||||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||||
self.rewards[self.position] = reward
|
self.rewards[self.position] = reward
|
||||||
self.dones[self.position] = done
|
self.dones[self.position] = done
|
||||||
self.truncateds[self.position] = truncated
|
self.truncateds[self.position] = truncated
|
||||||
|
|
||||||
# Handle complementary_info if provided and storage is initialized
|
# Handle complementary_info if provided and storage is initialized
|
||||||
if complementary_info is not None and self.has_complementary_info:
|
if complementary_info is not None and self.has_complementary_info:
|
||||||
# Store the complementary_info
|
for key in self.complementary_info_keys:
|
||||||
for key in self.complementary_info_keys:
|
if key in complementary_info:
|
||||||
if key in complementary_info:
|
value = complementary_info[key]
|
||||||
value = complementary_info[key]
|
if isinstance(value, torch.Tensor):
|
||||||
if isinstance(value, torch.Tensor):
|
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
elif isinstance(value, (int | float)):
|
||||||
elif isinstance(value, (int | float)):
|
self.complementary_info[key][self.position] = value
|
||||||
self.complementary_info[key][self.position] = value
|
|
||||||
|
|
||||||
self.position = (self.position + 1) % self.capacity
|
self.position = (self.position + 1) % self.capacity
|
||||||
self.size = min(self.size + 1, self.capacity)
|
self.size = min(self.size + 1, self.capacity)
|
||||||
|
|
||||||
def sample(self, batch_size: int) -> BatchTransition:
|
def sample(self, batch_size: int) -> BatchTransition:
|
||||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
||||||
|
|
||||||
batch_size = min(batch_size, self.size)
|
with self._lock:
|
||||||
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
batch_size = min(batch_size, self.size)
|
||||||
|
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||||
|
|
||||||
# Random indices for sampling - create on the same device as storage
|
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
|
||||||
|
|
||||||
# Identify image keys that need augmentation
|
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
|
||||||
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
|
|
||||||
|
|
||||||
# Create batched state and next_state
|
batch_state = {}
|
||||||
batch_state = {}
|
batch_next_state = {}
|
||||||
batch_next_state = {}
|
|
||||||
|
|
||||||
# First pass: load all state tensors to target device
|
for key in self.states:
|
||||||
for key in self.states:
|
batch_state[key] = self.states[key][idx].to(self.device)
|
||||||
batch_state[key] = self.states[key][idx].to(self.device)
|
|
||||||
|
|
||||||
if not self.optimize_memory:
|
if not self.optimize_memory:
|
||||||
# Standard approach - load next_states directly
|
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
||||||
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
else:
|
||||||
else:
|
next_idx = (idx + 1) % self.capacity
|
||||||
# Memory-optimized approach - get next_state from the next index
|
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
||||||
next_idx = (idx + 1) % self.capacity
|
|
||||||
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
# Sample other tensors
|
||||||
|
batch_actions = self.actions[idx].to(self.device)
|
||||||
|
batch_rewards = self.rewards[idx].to(self.device)
|
||||||
|
batch_dones = self.dones[idx].to(self.device).float()
|
||||||
|
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||||
|
|
||||||
|
# Sample complementary_info if available
|
||||||
|
batch_complementary_info = None
|
||||||
|
if self.has_complementary_info:
|
||||||
|
batch_complementary_info = {}
|
||||||
|
for key in self.complementary_info_keys:
|
||||||
|
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
||||||
|
|
||||||
# Apply image augmentation in a batched way if needed
|
|
||||||
if self.use_drq and image_keys:
|
if self.use_drq and image_keys:
|
||||||
# Concatenate all images from state and next_state
|
# Concatenate all images from state and next_state
|
||||||
all_images = []
|
all_images = []
|
||||||
@@ -280,19 +289,6 @@ class ReplayBuffer:
|
|||||||
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
||||||
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
||||||
|
|
||||||
# Sample other tensors
|
|
||||||
batch_actions = self.actions[idx].to(self.device)
|
|
||||||
batch_rewards = self.rewards[idx].to(self.device)
|
|
||||||
batch_dones = self.dones[idx].to(self.device).float()
|
|
||||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
|
||||||
|
|
||||||
# Sample complementary_info if available
|
|
||||||
batch_complementary_info = None
|
|
||||||
if self.has_complementary_info:
|
|
||||||
batch_complementary_info = {}
|
|
||||||
for key in self.complementary_info_keys:
|
|
||||||
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
|
||||||
|
|
||||||
return BatchTransition(
|
return BatchTransition(
|
||||||
state=batch_state,
|
state=batch_state,
|
||||||
action=batch_actions,
|
action=batch_actions,
|
||||||
|
|||||||
@@ -551,8 +551,8 @@ def step_env_and_process_transition(
|
|||||||
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
||||||
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
||||||
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
||||||
new_info = processed_action_transition[TransitionKey.INFO].copy()
|
new_info = info.copy()
|
||||||
new_info.update(info)
|
new_info.update(processed_action_transition[TransitionKey.INFO])
|
||||||
|
|
||||||
new_transition = create_transition(
|
new_transition = create_transition(
|
||||||
observation=obs,
|
observation=obs,
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
# Use preprocessor to handle tokenization
|
# Use preprocessor to handle tokenization
|
||||||
@@ -336,6 +337,7 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
policy.reset() # Reset queues before inference
|
policy.reset() # Reset queues before inference
|
||||||
|
|
||||||
@@ -390,6 +392,7 @@ def test_multi_task_dit_policy_diffusion_objective():
|
|||||||
config.validate_features()
|
config.validate_features()
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
# Use preprocessor to handle tokenization
|
# Use preprocessor to handle tokenization
|
||||||
@@ -468,6 +471,7 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
|||||||
config.validate_features()
|
config.validate_features()
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
# Use preprocessor to handle tokenization
|
# Use preprocessor to handle tokenization
|
||||||
@@ -533,16 +537,12 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
# Get device before saving
|
|
||||||
device = next(policy.parameters()).device
|
|
||||||
|
|
||||||
policy.save_pretrained(root)
|
policy.save_pretrained(root)
|
||||||
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
|
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
|
||||||
|
loaded_policy.to(config.device)
|
||||||
# Explicitly move loaded_policy to the same device
|
|
||||||
loaded_policy.to(device)
|
|
||||||
loaded_policy.eval()
|
loaded_policy.eval()
|
||||||
|
|
||||||
batch = create_train_batch(
|
batch = create_train_batch(
|
||||||
@@ -565,10 +565,6 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
|||||||
with seeded_context(12):
|
with seeded_context(12):
|
||||||
# Process batch through preprocessor
|
# Process batch through preprocessor
|
||||||
processed_batch = preprocessor(batch)
|
processed_batch = preprocessor(batch)
|
||||||
# Move batch to the same device as the policy
|
|
||||||
for key in processed_batch:
|
|
||||||
if isinstance(processed_batch[key], torch.Tensor):
|
|
||||||
processed_batch[key] = processed_batch[key].to(device)
|
|
||||||
# Collect policy values before saving
|
# Collect policy values before saving
|
||||||
loss, _ = policy.forward(processed_batch)
|
loss, _ = policy.forward(processed_batch)
|
||||||
|
|
||||||
@@ -608,6 +604,7 @@ def test_multi_task_dit_policy_get_optim_params():
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
param_groups = policy.get_optim_params()
|
param_groups = policy.get_optim_params()
|
||||||
|
|
||||||
# Should have 2 parameter groups: non-vision and vision encoder
|
# Should have 2 parameter groups: non-vision and vision encoder
|
||||||
|
|||||||
Reference in New Issue
Block a user