commit e5e1c97a6c
parent 1594ae60a7
Author: Michel Aractingi
Date: 2025-10-30 14:17:36 +01:00

3 changed files with 123 additions and 11 deletions
@@ -139,6 +139,8 @@ class DSRLConfig(PreTrainedConfig):
    # Training parameter
    # Number of steps for online training
    online_steps: int = 1000000
+    # Number of steps for offline training
+    offline_steps: int = 100000
    # Capacity of the online replay buffer
    online_buffer_capacity: int = 100000
    # Capacity of the offline replay buffer
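The new `offline_steps` field pairs with `online_steps` to split training into two phases. A minimal sketch of how a training loop might consume these fields (hypothetical helper, not from this commit):

```python
from dataclasses import dataclass


@dataclass
class Schedule:
    online_steps: int = 1_000_000
    offline_steps: int = 100_000


def run(schedule: Schedule) -> None:
    # Phase 1: update only from the offline replay buffer.
    for _ in range(schedule.offline_steps):
        ...  # sample an offline batch, update critics and noise actor
    # Phase 2: interact with the environment and update online.
    for _ in range(schedule.online_steps):
        ...  # collect a transition, then sample an online batch and update
```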
+111 -6
@@ -33,7 +33,13 @@ from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig, is_image_featur
from lerobot.policies.factory import get_policy_class
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
-from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
+from lerobot.utils.constants import (
+    ACTION,
+    OBS_ENV_STATE,
+    OBS_LANGUAGE_ATTENTION_MASK,
+    OBS_LANGUAGE_TOKENS,
+    OBS_STATE,
+)

VALID_ACTION_POLICIES = ["diffusion", "smolvla", "pi0", "pi05"]
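For reference, a guard over this whitelist might look like the following (illustrative helper, not part of the commit):

```python
VALID_ACTION_POLICIES = ["diffusion", "smolvla", "pi0", "pi05"]


def validate_action_policy(name: str) -> None:
    # Reject any action policy name outside the supported whitelist.
    if name not in VALID_ACTION_POLICIES:
        raise ValueError(f"Unknown action policy '{name}'; expected one of {VALID_ACTION_POLICIES}")


validate_action_policy("pi05")  # passes silently
```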
@@ -132,7 +138,9 @@ class DSRLPolicy(PreTrainedPolicy):
        noise, _, _ = self.noise_actor(batch, observations_features)
        noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
-        actions = self.action_policy.predict_action_chunk(batch, noise=noise)
+        # Ensure observations have language tokens if action policy requires them
+        batch_with_lang = self._add_language_tokens_if_needed(batch)
+        actions = self.action_policy.predict_action_chunk(batch_with_lang, noise=noise)
        return actions[:, 0, :]

    def action_critic_forward(
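The `unsqueeze`/`repeat` above broadcasts one latent noise vector per sample across the whole action chunk before the frozen policy decodes it. A quick shape check:

```python
import torch

batch_size, chunk_size, action_dim = 4, 8, 7
noise = torch.randn(batch_size, action_dim)          # one noise vector per sample
noise = noise.unsqueeze(1).repeat(1, chunk_size, 1)  # repeated across the chunk
assert noise.shape == (batch_size, chunk_size, action_dim)
```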
@@ -295,7 +303,9 @@ class DSRLPolicy(PreTrainedPolicy):
        # Generate next actions
        # a' = πW_dp(s', w')
-        next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise)
+        # Ensure next_observations have language tokens if action policy requires them
+        next_obs_with_lang = self._add_language_tokens_if_needed(next_observations)
+        next_actions_chunk = self.action_policy.predict_action_chunk(next_obs_with_lang, next_noise)
        next_action_preds = next_actions_chunk[:, 0, :]

        # Compute target Q-values: Q̄A(s', a')
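The target Q̄A(s', a') feeds a standard TD(0) target. A hedged sketch under the usual definitions (the exact critic-ensemble reduction in DSRL may differ):

```python
import torch


def td_target(rewards: torch.Tensor, dones: torch.Tensor, q_next: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    # y = r + gamma * (1 - done) * Q̄(s', a')
    return rewards + gamma * (1.0 - dones.float()) * q_next


print(td_target(torch.tensor([1.0, 0.0]), torch.tensor([False, True]), torch.tensor([10.0, 10.0])))
# tensor([10.9000,  0.0000])
```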
@@ -370,8 +380,10 @@ class DSRLPolicy(PreTrainedPolicy):
        noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
        noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
        with torch.no_grad():
+            # Ensure observations have language tokens if action policy requires them
+            obs_with_lang = self._add_language_tokens_if_needed(observations)
            # Generate action using base policy: a = πW_dp(s, w)
-            actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise)
+            actions_chunk = self.action_policy.predict_action_chunk(obs_with_lang, noise=noise)
            actions = actions_chunk[:, 0, :]

            # Get target Q-values from action critic: QA(s, a)
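The noise actor is then trained against these critic values. A minimal sketch of the objective this implies (loss = -Q, so gradient descent on the loss maximizes the critic's score; any entropy term DSRL adds is omitted here):

```python
import torch


def noise_actor_loss(q_values: torch.Tensor) -> torch.Tensor:
    # q_values: (B,) critic scores QA(s, a) for actions decoded from the actor's noise.
    # Minimizing -Q pushes the noise actor toward high-value latent noise.
    return (-q_values).mean()
```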
@@ -443,7 +455,12 @@ class DSRLPolicy(PreTrainedPolicy):
        return noise_actor_loss

    def _init_action_policy(self):
-        """Initialize the action policy."""
+        """Initialize the action policy and freeze it completely.
+
+        The action policy is a pretrained model that should never be updated
+        during DSRL training. All parameters are frozen (requires_grad=False)
+        and the model is set to eval mode.
+        """
        action_policy = get_policy_class(self.config.action_policy_name)

        if self.config.action_policy_weights is not None:
@@ -452,7 +469,90 @@ class DSRLPolicy(PreTrainedPolicy):
            self.action_policy = action_policy(self.config)

        self.action_policy.to(self.config.device)
        self.action_policy.eval()
+        # Freeze all parameters - the action policy should never be updated
+        for param in self.action_policy.parameters():
+            param.requires_grad = False
+
+        # Cache for lazy preprocessor creation
+        self._action_policy_preprocessor = None
+
+    def _get_action_policy_preprocessor(self):
+        """Lazily create and cache the tokenizer for the action policy.
+
+        The tokenizer is created on demand when first needed to tokenize task
+        descriptions for PI0/PI05 policies during training. It's cached for efficiency.
+
+        Note: This caches just the tokenizer, not a full preprocessor pipeline, because
+        replay buffer observations don't have the complementary task data that the
+        preprocessor would need.
+        """
+        if self._action_policy_preprocessor is not None:
+            return self._action_policy_preprocessor
+
+        # Only create a tokenizer for PI0/PI05 policies
+        if self.config.action_policy_name not in ["pi0", "pi05"]:
+            return None
+
+        # Import here to avoid circular imports
+        from transformers import AutoTokenizer
+
+        # Cache the tokenizer for the action policy
+        # (the same tokenizer PI0 uses, following processor_pi0.py)
+        tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
+        self._action_policy_preprocessor = tokenizer
+        return tokenizer
+
+    def _add_language_tokens_if_needed(self, observations: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Add language tokens to observations if missing and the action policy requires them.
+
+        For PI0/PI05 policies, adds default language tokens using the same tokenizer
+        as the action policy. This is used during training when replay buffer
+        observations don't have complementary task data.
+
+        Args:
+            observations: Dictionary of observation tensors
+
+        Returns:
+            Observations dict with language tokens added if needed
+        """
+        tokenizer = self._get_action_policy_preprocessor()
+
+        # If language tokens already exist, return as-is
+        if OBS_LANGUAGE_TOKENS in observations and OBS_LANGUAGE_ATTENTION_MASK in observations:
+            return observations
+
+        # Get batch properties
+        device = next(iter(observations.values())).device
+        batch_size = next(iter(observations.values())).shape[0]
+
+        # Get the tokenizer max length from the action policy config
+        tokenizer_max_length = getattr(self.action_policy.config, "tokenizer_max_length", 48)
+
+        # Use a default task description with a trailing newline (as PI0 expects)
+        default_task = "Pick up the object\n"
+        tasks = [default_task] * batch_size
+
+        # Tokenize tasks using the action policy's tokenizer
+        tokenized = tokenizer(
+            tasks,
+            padding="max_length",
+            padding_side="right",
+            truncation=True,
+            max_length=tokenizer_max_length,
+            return_tensors="pt",
+        )
+        lang_tokens = tokenized["input_ids"].to(device)
+        lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
+
+        # Add to a copy of the observations dict
+        observations = observations.copy()
+        observations[OBS_LANGUAGE_TOKENS] = lang_tokens
+        observations[OBS_LANGUAGE_ATTENTION_MASK] = lang_masks
+
+        return observations

    def _init_encoders(self):
        """Initialize shared or separate encoders for noise actor and critics."""
        self.shared_encoder = self.config.shared_encoder
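What `_add_language_tokens_if_needed` produces can be checked in isolation; this sketch mirrors the tokenizer call above (the PaliGemma checkpoint is gated on the Hugging Face Hub, so loading it requires accepted access):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
tokenized = tokenizer(
    ["Pick up the object\n"] * 2,  # batch of two default task strings
    padding="max_length",
    padding_side="right",
    truncation=True,
    max_length=48,
    return_tensors="pt",
)
lang_tokens = tokenized["input_ids"]                     # shape (2, 48)
lang_masks = tokenized["attention_mask"].to(torch.bool)  # shape (2, 48)
```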
@@ -503,11 +603,16 @@ class DSRLPolicy(PreTrainedPolicy):
    def _init_noise_actor(self):
        """Initialize noise actor network and default target entropy."""
+        # Filter out use_layer_norm since MLP doesn't accept it (MLP always uses LayerNorm)
+        mlp_kwargs = {
+            k: v for k, v in asdict(self.config.noise_actor_network_kwargs).items()
+            if k != "use_layer_norm"
+        }
        self.noise_actor = NoiseActorPolicy(
            encoder=self.encoder_noise_actor,
            network=MLP(
                input_dim=self.encoder_noise_actor.output_dim,
-                **asdict(self.config.noise_actor_network_kwargs),
+                **mlp_kwargs,
            ),
            action_dim=self.config.output_features[ACTION].shape[0],
            encoder_is_shared=self.shared_encoder,
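The dict-comprehension filter above is a general pattern: drop config keys a constructor does not accept before splatting the rest. In isolation (names illustrative):

```python
from dataclasses import asdict, dataclass


@dataclass
class NetworkKwargs:
    hidden_dims: tuple = (256, 256)
    use_layer_norm: bool = True  # not accepted by the target constructor


# Keep every field except the unsupported one, then pass via **mlp_kwargs.
mlp_kwargs = {k: v for k, v in asdict(NetworkKwargs()).items() if k != "use_layer_norm"}
print(mlp_kwargs)  # {'hidden_dims': (256, 256)}
```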
+10 -5
@@ -86,7 +86,7 @@ class ReplayBuffer:
        image_augmentation_function: Callable | None = None,
        use_drq: bool = True,
        storage_device: str = "cpu",
-        optimize_memory: bool = False,
+        optimize_memory: bool = True,
    ):
        """
        Replay buffer for storing transitions.
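Flipping `optimize_memory` to True by default suggests the memory-saving storage path is now standard. The flag's name points at the usual trick of not duplicating next-state storage; a rough sketch of that idea (details in lerobot's ReplayBuffer may differ):

```python
import torch

capacity, state_dim = 8, 3
states = torch.zeros(capacity, state_dim)  # single state buffer, no next_states copy


def next_state(index: int) -> torch.Tensor:
    # Read the successor slot instead of storing next_state separately;
    # only valid while index + 1 still belongs to the same episode.
    return states[(index + 1) % capacity]
```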
@@ -136,6 +136,7 @@ class ReplayBuffer:
        complementary_info: dict[str, torch.Tensor] | None = None,
    ):
        """Initialize the storage tensors based on the first transition."""
+        self.capacity = 1000
        # Determine shapes from the first transition
        state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
        action_shape = action.squeeze(0).shape
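For context, `_initialize_storage` derives tensor shapes from the first transition and preallocates fixed-capacity storage. A minimal sketch of that pattern (names illustrative):

```python
import torch

capacity = 1000
first_state = {"observation.state": torch.zeros(1, 6)}  # batched first transition
# Drop the batch dim, then preallocate one slot per buffer entry.
state_shapes = {k: v.squeeze(0).shape for k, v in first_state.items()}
storage = {k: torch.empty((capacity, *shape)) for k, shape in state_shapes.items()}
print(storage["observation.state"].shape)  # torch.Size([1000, 6])
```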
@@ -444,7 +445,7 @@ class ReplayBuffer:
        if capacity is None:
            capacity = len(lerobot_dataset)

-        if capacity < len(lerobot_dataset):
+        if capacity < 1000:  # len(lerobot_dataset):
            raise ValueError(
                "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
            )
@@ -476,13 +477,14 @@ class ReplayBuffer:
            and first_transition["complementary_info"] is not None
        ):
            first_complementary_info = {
-                k: v.to(device) for k, v in first_transition["complementary_info"].items()
+                k: v.to for k, v in first_transition["complementary_info"].items()
            }

        replay_buffer._initialize_storage(
            state=first_state, action=first_action, complementary_info=first_complementary_info
        )
+        num_samples = 0

        # Fill the buffer with all transitions
        for data in list_transition:
            for k, v in data.items():
@@ -503,6 +505,9 @@ class ReplayBuffer:
                truncated=False,  # NOTE: Truncations are not supported yet in lerobot dataset
                complementary_info=data.get("complementary_info", None),
            )
+            num_samples += 1
+            if num_samples >= 1000:
+                return replay_buffer

        return replay_buffer
@@ -645,7 +650,7 @@ class ReplayBuffer:
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
transitions = []
num_frames = len(dataset)
num_frames = 1000 # len(dataset)
# Check if the dataset has "next.done" key
sample = dataset[0]
@@ -659,7 +664,7 @@ class ReplayBuffer:
        if not has_done_key:
            print("'next.done' key not found in dataset. Inferring from episode boundaries...")

-        for i in tqdm(range(num_frames)):
+        for i in tqdm(range(1000)):  # num_frames)):
            current_sample = dataset[i]

            # ----- 1) Current state -----
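When 'next.done' is absent, done flags can be inferred from episode boundaries, as the print above says: a frame is terminal when the episode index changes or at the last frame. A self-contained sketch of that inference (field names follow LeRobotDataset conventions):

```python
def infer_done(episode_indices: list[int]) -> list[bool]:
    # A frame is 'done' if it is the last frame overall, or if the next
    # frame belongs to a different episode.
    n = len(episode_indices)
    return [i == n - 1 or episode_indices[i] != episode_indices[i + 1] for i in range(n)]


print(infer_done([0, 0, 0, 1, 1]))  # [False, False, True, False, True]
```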