mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
pi0 hack
This commit is contained in:
@@ -139,6 +139,8 @@ class DSRLConfig(PreTrainedConfig):
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Number of steps for offline training
|
||||
offline_steps: int = 100000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
|
||||
@@ -33,7 +33,13 @@ from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig, is_image_featur
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_ENV_STATE,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
|
||||
VALID_ACTION_POLICIES = ["diffusion", "smolvla", "pi0", "pi05"]
|
||||
|
||||
@@ -132,7 +138,9 @@ class DSRLPolicy(PreTrainedPolicy):
|
||||
|
||||
noise, _, _ = self.noise_actor(batch, observations_features)
|
||||
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
||||
actions = self.action_policy.predict_action_chunk(batch, noise=noise)
|
||||
# Ensure observations have language tokens if action policy requires them
|
||||
batch_with_lang = self._add_language_tokens_if_needed(batch)
|
||||
actions = self.action_policy.predict_action_chunk(batch_with_lang, noise=noise)
|
||||
return actions[:, 0, :]
|
||||
|
||||
def action_critic_forward(
|
||||
@@ -295,7 +303,9 @@ class DSRLPolicy(PreTrainedPolicy):
|
||||
|
||||
# Generate next actions
|
||||
# a' = πW_dp(s', w')
|
||||
next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise)
|
||||
# Ensure next_observations have language tokens if action policy requires them
|
||||
next_obs_with_lang = self._add_language_tokens_if_needed(next_observations)
|
||||
next_actions_chunk = self.action_policy.predict_action_chunk(next_obs_with_lang, next_noise)
|
||||
next_action_preds = next_actions_chunk[:, 0, :]
|
||||
|
||||
# Compute target Q-values: Q̄A(s', a')
|
||||
@@ -370,8 +380,10 @@ class DSRLPolicy(PreTrainedPolicy):
|
||||
noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
|
||||
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
||||
with torch.no_grad():
|
||||
# Ensure observations have language tokens if action policy requires them
|
||||
obs_with_lang = self._add_language_tokens_if_needed(observations)
|
||||
# Generate action using base policy: a = πW_dp(s, w)
|
||||
actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise)
|
||||
actions_chunk = self.action_policy.predict_action_chunk(obs_with_lang, noise=noise)
|
||||
actions = actions_chunk[:, 0, :]
|
||||
|
||||
# Get target Q-values from action critic: QA(s, a)
|
||||
@@ -443,7 +455,12 @@ class DSRLPolicy(PreTrainedPolicy):
|
||||
return noise_actor_loss
|
||||
|
||||
def _init_action_policy(self):
|
||||
"""Initialize the action policy."""
|
||||
"""Initialize the action policy and freeze it completely.
|
||||
|
||||
The action policy is a pretrained model that should never be updated
|
||||
during DSRL training. All parameters are frozen (requires_grad=False)
|
||||
and the model is set to eval mode.
|
||||
"""
|
||||
|
||||
action_policy = get_policy_class(self.config.action_policy_name)
|
||||
if self.config.action_policy_weights is not None:
|
||||
@@ -452,7 +469,90 @@ class DSRLPolicy(PreTrainedPolicy):
|
||||
self.action_policy = action_policy(self.config)
|
||||
self.action_policy.to(self.config.device)
|
||||
self.action_policy.eval()
|
||||
|
||||
# Freeze all parameters - action policy should never be updated
|
||||
for param in self.action_policy.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Cache for lazy preprocessor creation
|
||||
self._action_policy_preprocessor = None
|
||||
|
||||
def _get_action_policy_preprocessor(self):
|
||||
"""Lazily create and cache the tokenizer for the action policy.
|
||||
|
||||
The tokenizer is created on-demand when first needed to tokenize task descriptions
|
||||
for PI0/PI05 policies during training. It's cached for efficiency.
|
||||
|
||||
Note: This caches just the tokenizer, not a full preprocessor pipeline, because
|
||||
replay buffer observations don't have complementary task data that the preprocessor
|
||||
would need.
|
||||
"""
|
||||
if self._action_policy_preprocessor is not None:
|
||||
return self._action_policy_preprocessor
|
||||
|
||||
# Only create tokenizer for PI0/PI05 policies
|
||||
if self.config.action_policy_name not in ["pi0", "pi05"]:
|
||||
return None
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Cache the tokenizer for the action policy
|
||||
# Using the same tokenizer as PI0 uses (following processor_pi0.py)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self._action_policy_preprocessor = tokenizer
|
||||
|
||||
return tokenizer
|
||||
|
||||
def _add_language_tokens_if_needed(self, observations: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Add language tokens to observations if missing and action policy requires them.
|
||||
|
||||
For PI0/PI05 policies, adds default language tokens using the same tokenizer
|
||||
as the action policy. This is used during training when replay buffer observations
|
||||
don't have complementary task data.
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observation tensors
|
||||
|
||||
Returns:
|
||||
Observations dict with language tokens added if needed
|
||||
"""
|
||||
tokenizer = self._get_action_policy_preprocessor()
|
||||
# If language tokens already exist, return as-is
|
||||
if OBS_LANGUAGE_TOKENS in observations and OBS_LANGUAGE_ATTENTION_MASK in observations:
|
||||
return observations
|
||||
|
||||
# Get batch properties
|
||||
device = next(iter(observations.values())).device
|
||||
batch_size = next(iter(observations.values())).shape[0]
|
||||
|
||||
# Get tokenizer max length from action policy config
|
||||
tokenizer_max_length = getattr(self.action_policy.config, "tokenizer_max_length", 48)
|
||||
|
||||
# Use default task description with newline (as PI0 expects)
|
||||
default_task = "Pick up the object\n"
|
||||
tasks = [default_task] * batch_size
|
||||
|
||||
# Tokenize tasks using the action policy's tokenizer
|
||||
tokenized = tokenizer(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
# Add to observations dict
|
||||
observations = observations.copy()
|
||||
observations[OBS_LANGUAGE_TOKENS] = lang_tokens
|
||||
observations[OBS_LANGUAGE_ATTENTION_MASK] = lang_masks
|
||||
|
||||
return observations
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for noise actor and critics."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
@@ -503,11 +603,16 @@ class DSRLPolicy(PreTrainedPolicy):
|
||||
|
||||
def _init_noise_actor(self):
|
||||
"""Initialize noise actor network and default target entropy."""
|
||||
# Filter out use_layer_norm since MLP doesn't accept it (MLP always uses LayerNorm)
|
||||
mlp_kwargs = {
|
||||
k: v for k, v in asdict(self.config.noise_actor_network_kwargs).items()
|
||||
if k != "use_layer_norm"
|
||||
}
|
||||
self.noise_actor = NoiseActorPolicy(
|
||||
encoder=self.encoder_noise_actor,
|
||||
network=MLP(
|
||||
input_dim=self.encoder_noise_actor.output_dim,
|
||||
**asdict(self.config.noise_actor_network_kwargs),
|
||||
**mlp_kwargs,
|
||||
),
|
||||
action_dim=self.config.output_features[ACTION].shape[0],
|
||||
encoder_is_shared=self.shared_encoder,
|
||||
|
||||
@@ -86,7 +86,7 @@ class ReplayBuffer:
|
||||
image_augmentation_function: Callable | None = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
optimize_memory: bool = True
|
||||
):
|
||||
"""
|
||||
Replay buffer for storing transitions.
|
||||
@@ -136,6 +136,7 @@ class ReplayBuffer:
|
||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||
):
|
||||
"""Initialize the storage tensors based on the first transition."""
|
||||
self.capacity = 1000
|
||||
# Determine shapes from the first transition
|
||||
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
||||
action_shape = action.squeeze(0).shape
|
||||
@@ -444,7 +445,7 @@ class ReplayBuffer:
|
||||
if capacity is None:
|
||||
capacity = len(lerobot_dataset)
|
||||
|
||||
if capacity < len(lerobot_dataset):
|
||||
if capacity < 1000: #len(lerobot_dataset):
|
||||
raise ValueError(
|
||||
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
||||
)
|
||||
@@ -476,13 +477,14 @@ class ReplayBuffer:
|
||||
and first_transition["complementary_info"] is not None
|
||||
):
|
||||
first_complementary_info = {
|
||||
k: v.to(device) for k, v in first_transition["complementary_info"].items()
|
||||
k: v.to for k, v in first_transition["complementary_info"].items()
|
||||
}
|
||||
|
||||
replay_buffer._initialize_storage(
|
||||
state=first_state, action=first_action, complementary_info=first_complementary_info
|
||||
)
|
||||
|
||||
num_samples = 0
|
||||
# Fill the buffer with all transitions
|
||||
for data in list_transition:
|
||||
for k, v in data.items():
|
||||
@@ -503,6 +505,9 @@ class ReplayBuffer:
|
||||
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
||||
complementary_info=data.get("complementary_info", None),
|
||||
)
|
||||
num_samples += 1
|
||||
if num_samples >= 1000:
|
||||
return replay_buffer
|
||||
|
||||
return replay_buffer
|
||||
|
||||
@@ -645,7 +650,7 @@ class ReplayBuffer:
|
||||
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
||||
|
||||
transitions = []
|
||||
num_frames = len(dataset)
|
||||
num_frames = 1000 # len(dataset)
|
||||
|
||||
# Check if the dataset has "next.done" key
|
||||
sample = dataset[0]
|
||||
@@ -659,7 +664,7 @@ class ReplayBuffer:
|
||||
if not has_done_key:
|
||||
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
for i in tqdm(range(1000)): # num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
|
||||
Reference in New Issue
Block a user