mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
pi0 hack
This commit is contained in:
@@ -139,6 +139,8 @@ class DSRLConfig(PreTrainedConfig):
|
|||||||
# Training parameter
|
# Training parameter
|
||||||
# Number of steps for online training
|
# Number of steps for online training
|
||||||
online_steps: int = 1000000
|
online_steps: int = 1000000
|
||||||
|
# Number of steps for offline training
|
||||||
|
offline_steps: int = 100000
|
||||||
# Capacity of the online replay buffer
|
# Capacity of the online replay buffer
|
||||||
online_buffer_capacity: int = 100000
|
online_buffer_capacity: int = 100000
|
||||||
# Capacity of the offline replay buffer
|
# Capacity of the offline replay buffer
|
||||||
|
|||||||
@@ -33,7 +33,13 @@ from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig, is_image_featur
|
|||||||
from lerobot.policies.factory import get_policy_class
|
from lerobot.policies.factory import get_policy_class
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
from lerobot.utils.constants import (
|
||||||
|
ACTION,
|
||||||
|
OBS_ENV_STATE,
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_STATE,
|
||||||
|
)
|
||||||
|
|
||||||
VALID_ACTION_POLICIES = ["diffusion", "smolvla", "pi0", "pi05"]
|
VALID_ACTION_POLICIES = ["diffusion", "smolvla", "pi0", "pi05"]
|
||||||
|
|
||||||
@@ -132,7 +138,9 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
noise, _, _ = self.noise_actor(batch, observations_features)
|
noise, _, _ = self.noise_actor(batch, observations_features)
|
||||||
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
||||||
actions = self.action_policy.predict_action_chunk(batch, noise=noise)
|
# Ensure observations have language tokens if action policy requires them
|
||||||
|
batch_with_lang = self._add_language_tokens_if_needed(batch)
|
||||||
|
actions = self.action_policy.predict_action_chunk(batch_with_lang, noise=noise)
|
||||||
return actions[:, 0, :]
|
return actions[:, 0, :]
|
||||||
|
|
||||||
def action_critic_forward(
|
def action_critic_forward(
|
||||||
@@ -295,7 +303,9 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Generate next actions
|
# Generate next actions
|
||||||
# a' = πW_dp(s', w')
|
# a' = πW_dp(s', w')
|
||||||
next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise)
|
# Ensure next_observations have language tokens if action policy requires them
|
||||||
|
next_obs_with_lang = self._add_language_tokens_if_needed(next_observations)
|
||||||
|
next_actions_chunk = self.action_policy.predict_action_chunk(next_obs_with_lang, next_noise)
|
||||||
next_action_preds = next_actions_chunk[:, 0, :]
|
next_action_preds = next_actions_chunk[:, 0, :]
|
||||||
|
|
||||||
# Compute target Q-values: Q̄A(s', a')
|
# Compute target Q-values: Q̄A(s', a')
|
||||||
@@ -370,8 +380,10 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
|
noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
|
||||||
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# Ensure observations have language tokens if action policy requires them
|
||||||
|
obs_with_lang = self._add_language_tokens_if_needed(observations)
|
||||||
# Generate action using base policy: a = πW_dp(s, w)
|
# Generate action using base policy: a = πW_dp(s, w)
|
||||||
actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise)
|
actions_chunk = self.action_policy.predict_action_chunk(obs_with_lang, noise=noise)
|
||||||
actions = actions_chunk[:, 0, :]
|
actions = actions_chunk[:, 0, :]
|
||||||
|
|
||||||
# Get target Q-values from action critic: QA(s, a)
|
# Get target Q-values from action critic: QA(s, a)
|
||||||
@@ -443,7 +455,12 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
return noise_actor_loss
|
return noise_actor_loss
|
||||||
|
|
||||||
def _init_action_policy(self):
|
def _init_action_policy(self):
|
||||||
"""Initialize the action policy."""
|
"""Initialize the action policy and freeze it completely.
|
||||||
|
|
||||||
|
The action policy is a pretrained model that should never be updated
|
||||||
|
during DSRL training. All parameters are frozen (requires_grad=False)
|
||||||
|
and the model is set to eval mode.
|
||||||
|
"""
|
||||||
|
|
||||||
action_policy = get_policy_class(self.config.action_policy_name)
|
action_policy = get_policy_class(self.config.action_policy_name)
|
||||||
if self.config.action_policy_weights is not None:
|
if self.config.action_policy_weights is not None:
|
||||||
@@ -452,7 +469,90 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
self.action_policy = action_policy(self.config)
|
self.action_policy = action_policy(self.config)
|
||||||
self.action_policy.to(self.config.device)
|
self.action_policy.to(self.config.device)
|
||||||
self.action_policy.eval()
|
self.action_policy.eval()
|
||||||
|
|
||||||
|
# Freeze all parameters - action policy should never be updated
|
||||||
|
for param in self.action_policy.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Cache for lazy preprocessor creation
|
||||||
|
self._action_policy_preprocessor = None
|
||||||
|
|
||||||
|
def _get_action_policy_preprocessor(self):
    """Return the tokenizer used to build language inputs for the action policy.

    Despite the ``_action_policy_preprocessor`` attribute name, the cached object
    is a bare tokenizer, not a full preprocessor pipeline. It is created the
    first time it is needed and reused on every later call.

    Returns:
        The cached tokenizer when ``config.action_policy_name`` is ``"pi0"`` or
        ``"pi05"``. ``None`` for every other action policy, so callers must
        handle the ``None`` case before calling the result.
    """
    cached = self._action_policy_preprocessor
    if cached is not None:
        return cached

    # Only PI0/PI05 policies get a tokenizer; nothing is created or cached otherwise.
    if self.config.action_policy_name not in ("pi0", "pi05"):
        return None

    # Function-scope import: `transformers` is only loaded once this branch is reached.
    from transformers import AutoTokenizer

    # NOTE(review): the checkpoint name is hard-coded; the original comment says it
    # follows processor_pi0.py - confirm the two stay in sync.
    tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
    self._action_policy_preprocessor = tokenizer
    return tokenizer
|
||||||
|
|
||||||
|
def _add_language_tokens_if_needed(self, observations: dict[str, Tensor]) -> dict[str, Tensor]:
    """Add placeholder language tokens to observations when the action policy needs them.

    When ``_get_action_policy_preprocessor()`` returns a tokenizer and the
    observations do not already hold both language entries, a fixed default task
    string is tokenized once per batch element and added to a shallow copy of
    the dict. The input dict is never mutated.

    Args:
        observations: Dictionary of observation tensors. The first value is used
            to read the target device and the batch size (``shape[0]``).

    Returns:
        The same ``observations`` object when no tokenizer is available or both
        language entries are already present; otherwise a shallow copy with
        ``OBS_LANGUAGE_TOKENS`` and ``OBS_LANGUAGE_ATTENTION_MASK`` added.

    Raises:
        ValueError: If tokens have to be generated but ``observations`` is empty,
            so neither batch size nor device can be determined.
    """
    tokenizer = self._get_action_policy_preprocessor()

    # The getter returns None for action policies other than PI0/PI05. Calling
    # None below would raise TypeError, so pass the observations through untouched.
    if tokenizer is None:
        return observations

    # If language tokens already exist, return as-is
    if OBS_LANGUAGE_TOKENS in observations and OBS_LANGUAGE_ATTENTION_MASK in observations:
        return observations

    if not observations:
        raise ValueError("Cannot add language tokens to an empty observations dict.")

    # Batch properties come from whichever tensor is first in the dict.
    reference = next(iter(observations.values()))
    device = reference.device
    batch_size = reference.shape[0]

    # Falls back to 48 when the action policy config has no `tokenizer_max_length`.
    tokenizer_max_length = getattr(self.action_policy.config, "tokenizer_max_length", 48)

    # NOTE(review): every sample gets this same fixed prompt, so the tokens do not
    # reflect the task the data was actually collected for - confirm this is acceptable.
    default_task = "Pick up the object\n"
    tasks = [default_task] * batch_size

    tokenized = tokenizer(
        tasks,
        padding="max_length",
        padding_side="right",
        truncation=True,
        max_length=tokenizer_max_length,
        return_tensors="pt",
    )

    lang_tokens = tokenized["input_ids"].to(device)
    lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)

    # Shallow copy so the caller's dict is left unchanged.
    observations = observations.copy()
    observations[OBS_LANGUAGE_TOKENS] = lang_tokens
    observations[OBS_LANGUAGE_ATTENTION_MASK] = lang_masks

    return observations
|
||||||
|
|
||||||
def _init_encoders(self):
|
def _init_encoders(self):
|
||||||
"""Initialize shared or separate encoders for noise actor and critics."""
|
"""Initialize shared or separate encoders for noise actor and critics."""
|
||||||
self.shared_encoder = self.config.shared_encoder
|
self.shared_encoder = self.config.shared_encoder
|
||||||
@@ -503,11 +603,16 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def _init_noise_actor(self):
|
def _init_noise_actor(self):
|
||||||
"""Initialize noise actor network and default target entropy."""
|
"""Initialize noise actor network and default target entropy."""
|
||||||
|
# Filter out use_layer_norm since MLP doesn't accept it (MLP always uses LayerNorm)
|
||||||
|
mlp_kwargs = {
|
||||||
|
k: v for k, v in asdict(self.config.noise_actor_network_kwargs).items()
|
||||||
|
if k != "use_layer_norm"
|
||||||
|
}
|
||||||
self.noise_actor = NoiseActorPolicy(
|
self.noise_actor = NoiseActorPolicy(
|
||||||
encoder=self.encoder_noise_actor,
|
encoder=self.encoder_noise_actor,
|
||||||
network=MLP(
|
network=MLP(
|
||||||
input_dim=self.encoder_noise_actor.output_dim,
|
input_dim=self.encoder_noise_actor.output_dim,
|
||||||
**asdict(self.config.noise_actor_network_kwargs),
|
**mlp_kwargs,
|
||||||
),
|
),
|
||||||
action_dim=self.config.output_features[ACTION].shape[0],
|
action_dim=self.config.output_features[ACTION].shape[0],
|
||||||
encoder_is_shared=self.shared_encoder,
|
encoder_is_shared=self.shared_encoder,
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class ReplayBuffer:
|
|||||||
image_augmentation_function: Callable | None = None,
|
image_augmentation_function: Callable | None = None,
|
||||||
use_drq: bool = True,
|
use_drq: bool = True,
|
||||||
storage_device: str = "cpu",
|
storage_device: str = "cpu",
|
||||||
optimize_memory: bool = False,
|
optimize_memory: bool = True
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Replay buffer for storing transitions.
|
Replay buffer for storing transitions.
|
||||||
@@ -136,6 +136,7 @@ class ReplayBuffer:
|
|||||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize the storage tensors based on the first transition."""
|
"""Initialize the storage tensors based on the first transition."""
|
||||||
|
self.capacity = 1000
|
||||||
# Determine shapes from the first transition
|
# Determine shapes from the first transition
|
||||||
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
||||||
action_shape = action.squeeze(0).shape
|
action_shape = action.squeeze(0).shape
|
||||||
@@ -444,7 +445,7 @@ class ReplayBuffer:
|
|||||||
if capacity is None:
|
if capacity is None:
|
||||||
capacity = len(lerobot_dataset)
|
capacity = len(lerobot_dataset)
|
||||||
|
|
||||||
if capacity < len(lerobot_dataset):
|
if capacity < 1000: #len(lerobot_dataset):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
||||||
)
|
)
|
||||||
@@ -476,13 +477,14 @@ class ReplayBuffer:
|
|||||||
and first_transition["complementary_info"] is not None
|
and first_transition["complementary_info"] is not None
|
||||||
):
|
):
|
||||||
first_complementary_info = {
|
first_complementary_info = {
|
||||||
k: v.to(device) for k, v in first_transition["complementary_info"].items()
|
k: v.to for k, v in first_transition["complementary_info"].items()
|
||||||
}
|
}
|
||||||
|
|
||||||
replay_buffer._initialize_storage(
|
replay_buffer._initialize_storage(
|
||||||
state=first_state, action=first_action, complementary_info=first_complementary_info
|
state=first_state, action=first_action, complementary_info=first_complementary_info
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_samples = 0
|
||||||
# Fill the buffer with all transitions
|
# Fill the buffer with all transitions
|
||||||
for data in list_transition:
|
for data in list_transition:
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
@@ -503,6 +505,9 @@ class ReplayBuffer:
|
|||||||
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
||||||
complementary_info=data.get("complementary_info", None),
|
complementary_info=data.get("complementary_info", None),
|
||||||
)
|
)
|
||||||
|
num_samples += 1
|
||||||
|
if num_samples >= 1000:
|
||||||
|
return replay_buffer
|
||||||
|
|
||||||
return replay_buffer
|
return replay_buffer
|
||||||
|
|
||||||
@@ -645,7 +650,7 @@ class ReplayBuffer:
|
|||||||
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
||||||
|
|
||||||
transitions = []
|
transitions = []
|
||||||
num_frames = len(dataset)
|
num_frames = 1000 # len(dataset)
|
||||||
|
|
||||||
# Check if the dataset has "next.done" key
|
# Check if the dataset has "next.done" key
|
||||||
sample = dataset[0]
|
sample = dataset[0]
|
||||||
@@ -659,7 +664,7 @@ class ReplayBuffer:
|
|||||||
if not has_done_key:
|
if not has_done_key:
|
||||||
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
|
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
|
||||||
|
|
||||||
for i in tqdm(range(num_frames)):
|
for i in tqdm(range(1000)): # num_frames)):
|
||||||
current_sample = dataset[i]
|
current_sample = dataset[i]
|
||||||
|
|
||||||
# ----- 1) Current state -----
|
# ----- 1) Current state -----
|
||||||
|
|||||||
Reference in New Issue
Block a user