This commit is contained in:
Michel Aractingi
2025-10-30 14:17:36 +01:00
parent 1594ae60a7
commit e5e1c97a6c
3 changed files with 123 additions and 11 deletions
@@ -139,6 +139,8 @@ class DSRLConfig(PreTrainedConfig):
# Training parameter # Training parameter
# Number of steps for online training # Number of steps for online training
online_steps: int = 1000000 online_steps: int = 1000000
# Number of steps for offline training
offline_steps: int = 100000
# Capacity of the online replay buffer # Capacity of the online replay buffer
online_buffer_capacity: int = 100000 online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer # Capacity of the offline replay buffer
+111 -6
View File
@@ -33,7 +33,13 @@ from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig, is_image_featur
from lerobot.policies.factory import get_policy_class from lerobot.policies.factory import get_policy_class
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE from lerobot.utils.constants import (
ACTION,
OBS_ENV_STATE,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
VALID_ACTION_POLICIES = ["diffusion", "smolvla", "pi0", "pi05"] VALID_ACTION_POLICIES = ["diffusion", "smolvla", "pi0", "pi05"]
@@ -132,7 +138,9 @@ class DSRLPolicy(PreTrainedPolicy):
noise, _, _ = self.noise_actor(batch, observations_features) noise, _, _ = self.noise_actor(batch, observations_features)
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1) noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
actions = self.action_policy.predict_action_chunk(batch, noise=noise) # Ensure observations have language tokens if action policy requires them
batch_with_lang = self._add_language_tokens_if_needed(batch)
actions = self.action_policy.predict_action_chunk(batch_with_lang, noise=noise)
return actions[:, 0, :] return actions[:, 0, :]
def action_critic_forward( def action_critic_forward(
@@ -295,7 +303,9 @@ class DSRLPolicy(PreTrainedPolicy):
# Generate next actions # Generate next actions
# a' = πW_dp(s', w') # a' = πW_dp(s', w')
next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise) # Ensure next_observations have language tokens if action policy requires them
next_obs_with_lang = self._add_language_tokens_if_needed(next_observations)
next_actions_chunk = self.action_policy.predict_action_chunk(next_obs_with_lang, next_noise)
next_action_preds = next_actions_chunk[:, 0, :] next_action_preds = next_actions_chunk[:, 0, :]
# Compute target Q-values: Q̄A(s', a') # Compute target Q-values: Q̄A(s', a')
@@ -370,8 +380,10 @@ class DSRLPolicy(PreTrainedPolicy):
noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self)) noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1) noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
with torch.no_grad(): with torch.no_grad():
# Ensure observations have language tokens if action policy requires them
obs_with_lang = self._add_language_tokens_if_needed(observations)
# Generate action using base policy: a = πW_dp(s, w) # Generate action using base policy: a = πW_dp(s, w)
actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise) actions_chunk = self.action_policy.predict_action_chunk(obs_with_lang, noise=noise)
actions = actions_chunk[:, 0, :] actions = actions_chunk[:, 0, :]
# Get target Q-values from action critic: QA(s, a) # Get target Q-values from action critic: QA(s, a)
@@ -443,7 +455,12 @@ class DSRLPolicy(PreTrainedPolicy):
return noise_actor_loss return noise_actor_loss
def _init_action_policy(self): def _init_action_policy(self):
"""Initialize the action policy.""" """Initialize the action policy and freeze it completely.
The action policy is a pretrained model that should never be updated
during DSRL training. All parameters are frozen (requires_grad=False)
and the model is set to eval mode.
"""
action_policy = get_policy_class(self.config.action_policy_name) action_policy = get_policy_class(self.config.action_policy_name)
if self.config.action_policy_weights is not None: if self.config.action_policy_weights is not None:
@@ -452,7 +469,90 @@ class DSRLPolicy(PreTrainedPolicy):
self.action_policy = action_policy(self.config) self.action_policy = action_policy(self.config)
self.action_policy.to(self.config.device) self.action_policy.to(self.config.device)
self.action_policy.eval() self.action_policy.eval()
# Freeze all parameters - action policy should never be updated
for param in self.action_policy.parameters():
param.requires_grad = False
# Cache for lazy preprocessor creation
self._action_policy_preprocessor = None
def _get_action_policy_preprocessor(self):
"""Lazily create and cache the tokenizer for the action policy.
The tokenizer is created on-demand when first needed to tokenize task descriptions
for PI0/PI05 policies during training. It's cached for efficiency.
Note: This caches just the tokenizer, not a full preprocessor pipeline, because
replay buffer observations don't have complementary task data that the preprocessor
would need.
"""
if self._action_policy_preprocessor is not None:
return self._action_policy_preprocessor
# Only create tokenizer for PI0/PI05 policies
if self.config.action_policy_name not in ["pi0", "pi05"]:
return None
# Import here to avoid circular imports
from transformers import AutoTokenizer
# Cache the tokenizer for the action policy
# Using the same tokenizer as PI0 uses (following processor_pi0.py)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self._action_policy_preprocessor = tokenizer
return tokenizer
def _add_language_tokens_if_needed(self, observations: dict[str, Tensor]) -> dict[str, Tensor]:
"""Add language tokens to observations if missing and action policy requires them.
For PI0/PI05 policies, adds default language tokens using the same tokenizer
as the action policy. This is used during training when replay buffer observations
don't have complementary task data.
Args:
observations: Dictionary of observation tensors
Returns:
Observations dict with language tokens added if needed
"""
tokenizer = self._get_action_policy_preprocessor()
# If language tokens already exist, return as-is
if OBS_LANGUAGE_TOKENS in observations and OBS_LANGUAGE_ATTENTION_MASK in observations:
return observations
# Get batch properties
device = next(iter(observations.values())).device
batch_size = next(iter(observations.values())).shape[0]
# Get tokenizer max length from action policy config
tokenizer_max_length = getattr(self.action_policy.config, "tokenizer_max_length", 48)
# Use default task description with newline (as PI0 expects)
default_task = "Pick up the object\n"
tasks = [default_task] * batch_size
# Tokenize tasks using the action policy's tokenizer
tokenized = tokenizer(
tasks,
padding="max_length",
padding_side="right",
truncation=True,
max_length=tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
# Add to observations dict
observations = observations.copy()
observations[OBS_LANGUAGE_TOKENS] = lang_tokens
observations[OBS_LANGUAGE_ATTENTION_MASK] = lang_masks
return observations
def _init_encoders(self): def _init_encoders(self):
"""Initialize shared or separate encoders for noise actor and critics.""" """Initialize shared or separate encoders for noise actor and critics."""
self.shared_encoder = self.config.shared_encoder self.shared_encoder = self.config.shared_encoder
@@ -503,11 +603,16 @@ class DSRLPolicy(PreTrainedPolicy):
def _init_noise_actor(self): def _init_noise_actor(self):
"""Initialize noise actor network and default target entropy.""" """Initialize noise actor network and default target entropy."""
# Filter out use_layer_norm since MLP doesn't accept it (MLP always uses LayerNorm)
mlp_kwargs = {
k: v for k, v in asdict(self.config.noise_actor_network_kwargs).items()
if k != "use_layer_norm"
}
self.noise_actor = NoiseActorPolicy( self.noise_actor = NoiseActorPolicy(
encoder=self.encoder_noise_actor, encoder=self.encoder_noise_actor,
network=MLP( network=MLP(
input_dim=self.encoder_noise_actor.output_dim, input_dim=self.encoder_noise_actor.output_dim,
**asdict(self.config.noise_actor_network_kwargs), **mlp_kwargs,
), ),
action_dim=self.config.output_features[ACTION].shape[0], action_dim=self.config.output_features[ACTION].shape[0],
encoder_is_shared=self.shared_encoder, encoder_is_shared=self.shared_encoder,
+10 -5
View File
@@ -86,7 +86,7 @@ class ReplayBuffer:
image_augmentation_function: Callable | None = None, image_augmentation_function: Callable | None = None,
use_drq: bool = True, use_drq: bool = True,
storage_device: str = "cpu", storage_device: str = "cpu",
optimize_memory: bool = False, optimize_memory: bool = True
): ):
""" """
Replay buffer for storing transitions. Replay buffer for storing transitions.
@@ -136,6 +136,7 @@ class ReplayBuffer:
complementary_info: dict[str, torch.Tensor] | None = None, complementary_info: dict[str, torch.Tensor] | None = None,
): ):
"""Initialize the storage tensors based on the first transition.""" """Initialize the storage tensors based on the first transition."""
self.capacity = 1000
# Determine shapes from the first transition # Determine shapes from the first transition
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
action_shape = action.squeeze(0).shape action_shape = action.squeeze(0).shape
@@ -444,7 +445,7 @@ class ReplayBuffer:
if capacity is None: if capacity is None:
capacity = len(lerobot_dataset) capacity = len(lerobot_dataset)
if capacity < len(lerobot_dataset): if capacity < 1000: #len(lerobot_dataset):
raise ValueError( raise ValueError(
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset." "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
) )
@@ -476,13 +477,14 @@ class ReplayBuffer:
and first_transition["complementary_info"] is not None and first_transition["complementary_info"] is not None
): ):
first_complementary_info = { first_complementary_info = {
k: v.to(device) for k, v in first_transition["complementary_info"].items() k: v.to for k, v in first_transition["complementary_info"].items()
} }
replay_buffer._initialize_storage( replay_buffer._initialize_storage(
state=first_state, action=first_action, complementary_info=first_complementary_info state=first_state, action=first_action, complementary_info=first_complementary_info
) )
num_samples = 0
# Fill the buffer with all transitions # Fill the buffer with all transitions
for data in list_transition: for data in list_transition:
for k, v in data.items(): for k, v in data.items():
@@ -503,6 +505,9 @@ class ReplayBuffer:
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
complementary_info=data.get("complementary_info", None), complementary_info=data.get("complementary_info", None),
) )
num_samples += 1
if num_samples >= 1000:
return replay_buffer
return replay_buffer return replay_buffer
@@ -645,7 +650,7 @@ class ReplayBuffer:
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.") raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
transitions = [] transitions = []
num_frames = len(dataset) num_frames = 1000 # len(dataset)
# Check if the dataset has "next.done" key # Check if the dataset has "next.done" key
sample = dataset[0] sample = dataset[0]
@@ -659,7 +664,7 @@ class ReplayBuffer:
if not has_done_key: if not has_done_key:
print("'next.done' key not found in dataset. Inferring from episode boundaries...") print("'next.done' key not found in dataset. Inferring from episode boundaries...")
for i in tqdm(range(num_frames)): for i in tqdm(range(1000)): # num_frames)):
current_sample = dataset[i] current_sample = dataset[i]
# ----- 1) Current state ----- # ----- 1) Current state -----