Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-16 09:09:48 +00:00)
* Change the Diffusion policy to use `chunk_size` notation instead of `horizon` to standardize variable names across policies
* Reshape noise after taking it as output of the network
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
-        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
+        chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
         n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
             See `DiffusionPolicy.select_action` for more details.
         input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):

     # Inputs / output structure.
     n_obs_steps: int = 2
-    horizon: int = 16
+    chunk_size: int = 16
     n_action_steps: int = 8

     normalization_mapping: dict[str, NormalizationMode] = field(
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):

     # The original implementation doesn't sample frames for the last 7 steps,
     # which avoids excessive padding and leads to improved training results.
-    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1
+    drop_n_last_frames: int = 7  # chunk_size - n_action_steps - n_obs_steps + 1

     # Architecture / modeling.
     # Vision backbone.
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
                 f"Got {self.noise_scheduler_type}."
             )

-        # Check that the horizon size and U-Net downsampling is compatible.
+        # Check that the chunk size and U-Net downsampling is compatible.
         # U-Net downsamples by 2 with each stage.
         downsampling_factor = 2 ** len(self.down_dims)
-        if self.horizon % downsampling_factor != 0:
+        if self.chunk_size % downsampling_factor != 0:
             raise ValueError(
-                "The horizon should be an integer multiple of the downsampling factor (which is determined "
-                f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
+                "The chunk_size should be an integer multiple of the downsampling factor (which is determined "
+                f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
             )

     def get_optimizer_preset(self) -> AdamConfig:
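Not part of the commit, but for illustration: a minimal standalone sketch of the compatibility check above. The U-Net halves the temporal dimension once per stage, so `chunk_size` must be divisible by `2 ** len(down_dims)`; the `down_dims` value below is an assumed example, not taken from the config file.

```python
# Illustrative sketch of the chunk_size / U-Net downsampling check.
down_dims = (512, 1024, 2048)  # assumed example value
chunk_size = 16

downsampling_factor = 2 ** len(down_dims)  # each U-Net stage downsamples by 2
if chunk_size % downsampling_factor != 0:
    raise ValueError(
        "The chunk_size should be an integer multiple of the downsampling factor (which is determined "
        f"by `len(down_dims)`). Got {chunk_size=} and {down_dims=}"
    )
print(f"{chunk_size=} is compatible with {downsampling_factor=}")
```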
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):

     @property
     def action_delta_indices(self) -> list:
-        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
+        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))

     @property
     def reward_delta_indices(self) -> None:
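As a quick illustration (not from the commit) of what `action_delta_indices` returns with the defaults above, `n_obs_steps=2` and `chunk_size=16` yield 16 indices starting one step in the past:

```python
# Illustrative only: recompute action_delta_indices with the default values above.
n_obs_steps = 2
chunk_size = 16

action_delta_indices = list(range(1 - n_obs_steps, 1 - n_obs_steps + chunk_size))
print(action_delta_indices)  # [-1, 0, 1, ..., 14] -> 16 indices, starting one step in the past
```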
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
         return actions

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Select a single action given environment observations.

         This method handles caching a history of observations and an action trajectory generated by the
         underlying diffusion model. Here's how it works:
          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
            copied `n_obs_steps` times to fill the cache).
-         - The diffusion model generates `horizon` steps worth of actions.
+         - The diffusion model generates `chunk_size` steps worth of actions.
          - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
         Schematically this looks like:
             ----------------------------------------------------------------------------------------------
-            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
+            (legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
             ----------------------------------------------------------------------------------------------
-        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
-        "horizon" may not the best name to describe what the variable actually means, because this period is
+        Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
+        this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
         # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
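A small numerical sketch of the timing described in that docstring (illustrative, not part of the commit; step values are arbitrary):

```python
# Illustrative timing for one invocation of select_action.
n_obs_steps, chunk_size, n_action_steps = 2, 16, 8

# Only actions starting at the current step can be executed, hence the constraint:
assert n_action_steps <= chunk_size - n_obs_steps + 1

n = 100  # arbitrary current environment step
# The chunk covers steps n - o + 1 ... n - o + c (it starts back at the first cached observation).
generated_steps = list(range(n - n_obs_steps + 1, n - n_obs_steps + 1 + chunk_size))
# Only n_action_steps of them, starting at the current step, are queued for execution.
executed_steps = list(range(n, n + n_action_steps))
print(generated_steps[0], generated_steps[-1])  # 99 114
print(executed_steps)                           # [100, 101, ..., 107]
```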
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
             noise
             if noise is not None
             else torch.randn(
-                size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
+                size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
                 dtype=dtype,
                 device=device,
                 generator=generator,
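To make the expected shape of an externally supplied `noise` tensor concrete (an illustrative sketch with assumed sizes, not code from the commit): when no noise is passed in, the model samples it with shape `(batch_size, chunk_size, action_dim)`, so a caller must match that shape.

```python
import torch

batch_size, chunk_size, action_dim = 4, 16, 7  # assumed illustrative sizes

# Shape a caller (e.g. DSRL below) must provide when passing noise explicitly.
noise = torch.randn(batch_size, chunk_size, action_dim)
print(noise.shape)  # torch.Size([4, 16, 7])
```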
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
             AND/OR
             "observation.environment_state": (B, n_obs_steps, environment_dim)

-            "action": (B, horizon, action_dim)
-            "action_is_pad": (B, horizon)
+            "action": (B, chunk_size, action_dim)
+            "action_is_pad": (B, chunk_size)
         }
         """
         # Input validation.
         assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
         assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
         n_obs_steps = batch[OBS_STATE].shape[1]
-        horizon = batch[ACTION].shape[1]
-        assert horizon == self.config.horizon
+        chunk_size = batch[ACTION].shape[1]
+        assert chunk_size == self.config.chunk_size
         assert n_obs_steps == self.config.n_obs_steps

         # Encode image features and concatenate them all together along with the state vector.
@@ -75,7 +75,7 @@ class DSRLPolicy(PreTrainedPolicy):

     def __init__(
         self,
-        config: DSRLConfig | None = None,
+        config: DSRLConfig,
     ):
         super().__init__(config)
         config.validate_features()
@@ -91,11 +91,16 @@ class DSRLPolicy(PreTrainedPolicy):
         self._init_noise_actor()
         self._init_temperature()

+        # Set chunk size for action policy
+        if not hasattr(self.action_policy.config, "chunk_size"):
+            raise ValueError("Action policy config does not have a chunk size")
+        self.chunk_size = self.action_policy.config.chunk_size
+
     def get_optim_params(self) -> dict:
         optim_params = {
             "noise_actor": [
                 p
-                for n, p in self.actor.named_parameters()
+                for n, p in self.noise_actor.named_parameters()
                 if not n.startswith("encoder") or not self.shared_encoder
             ],
             "critic_action": self.action_critic_ensemble.parameters(),
@@ -109,12 +114,12 @@ class DSRLPolicy(PreTrainedPolicy):
         pass

     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("DSRLPolicy does not support action chunking. It returns single actions!")

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Select noise vector for inference/evaluation,
         pass it through the action policy to get the action.

@@ -126,8 +131,9 @@ class DSRLPolicy(PreTrainedPolicy):
         observations_features = self.noise_actor.encoder.get_cached_image_features(batch)

         noise, _, _ = self.noise_actor(batch, observations_features)

-        return self.action_policy(batch, noise)
+        noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
+        actions = self.action_policy.predict_action_chunk(batch, noise=noise)
+        return actions[:, 0, :]

     def action_critic_forward(
         self,
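A shape-level sketch of the new noise handling in `select_action` (illustrative, with assumed sizes): the noise actor emits one latent-noise vector per sample, which is tiled across the chunk dimension to match the diffusion policy's expected noise shape, and only the first action of the predicted chunk is executed.

```python
import torch

batch_size, chunk_size, action_dim = 4, 16, 7  # assumed illustrative sizes

# The noise actor outputs one noise vector per sample: (B, action_dim).
noise = torch.randn(batch_size, action_dim)

# Tile it across the chunk so it matches the (B, chunk_size, action_dim) shape
# that the diffusion policy expects for its initial noise.
noise = noise.unsqueeze(1).repeat(1, chunk_size, 1)
print(noise.shape)  # torch.Size([4, 16, 7])

# Stand-in for the predict_action_chunk output; only the first action is kept.
actions_chunk = torch.randn(batch_size, chunk_size, action_dim)
action = actions_chunk[:, 0, :]
print(action.shape)  # torch.Size([4, 7])
```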
@@ -203,7 +209,7 @@ class DSRLPolicy(PreTrainedPolicy):
         observation_features: Tensor | None = batch.get("observation_feature")

         if model == "critic_action":
-            # 1. Action Critic: TD-learning on action space
+            # Action Critic: TD-learning on action space
             # Extract critic-specific components
             actions: Tensor = batch[ACTION]
             rewards: Tensor = batch["reward"]
@@ -223,7 +229,7 @@ class DSRLPolicy(PreTrainedPolicy):
             return {"loss_critic": loss_critic}

         if model == "critic_noise":
-            # 2. Noise Critic: Distillation from action critic
+            # Noise Critic: Distillation from action critic
             loss_critic_noise = self.compute_loss_critic_noise(
                 observations=observations,
                 observation_features=observation_features,
@@ -231,7 +237,7 @@ class DSRLPolicy(PreTrainedPolicy):
             return {"loss_critic_noise": loss_critic_noise}

         if model == "noise_actor":
-            # 3. Noise Actor: Maximize Q-values in noise space
+            # Noise Actor: Maximize Q-values in noise space
             loss_noise_actor = self.compute_loss_noise_actor(
                 observations=observations,
                 observation_features=observation_features,
@@ -283,14 +289,16 @@ class DSRLPolicy(PreTrainedPolicy):

         """
         with torch.no_grad():
-            # 1. Sample noise from noise actor: w' ~ πW(s')
+            # Sample noise from noise actor: w' ~ πW(s')
             next_noise, next_log_probs, _ = self.noise_actor(next_observations, next_observation_features)
+            next_noise = next_noise.unsqueeze(1).repeat(1, self.chunk_size, 1)

-            # 2. Generate next actions
+            # Generate next actions
             # a' = πW_dp(s', w')
-            next_action_preds = self.action_policy(next_observations, next_noise)
+            next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise)
+            next_action_preds = next_actions_chunk[:, 0, :]

-            # 3. Compute target Q-values: Q̄A(s', a')
+            # Compute target Q-values: Q̄A(s', a')
             q_targets = self.action_critic_forward(
                 observations=next_observations,
                 actions=next_action_preds,
@@ -312,7 +320,7 @@ class DSRLPolicy(PreTrainedPolicy):

         td_target = rewards + (1 - done) * self.config.discount * min_q

-        # 3- compute predicted qs
+        # compute predicted qs
         q_preds = self.action_critic_forward(
             observations=observations,
             actions=actions,
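A minimal numerical sketch of the TD target above (assumed shapes and discount; `min_q` stands for the element-wise minimum over the target critic ensemble):

```python
import torch

batch_size = 4
discount = 0.99  # assumed value of self.config.discount

rewards = torch.randn(batch_size)
done = torch.zeros(batch_size)   # 1.0 where the episode terminated
min_q = torch.randn(batch_size)  # min over the target critic ensemble, one value per sample

# Bellman backup: the bootstrapped term is cut off at terminal transitions.
td_target = rewards + (1 - done) * discount * min_q
print(td_target.shape)  # torch.Size([4])
```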
@@ -320,7 +328,7 @@ class DSRLPolicy(PreTrainedPolicy):
             observation_features=observation_features,
         )

-        # 4- Calculate loss
+        # Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
         # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
@@ -358,14 +366,15 @@ class DSRLPolicy(PreTrainedPolicy):
         batch_size = next(iter(observations.values())).shape[0]
         action_dim = self.config.output_features[ACTION].shape[0]

-        # 1. Sample noise w ~ N(0, I)
+        # Sample noise w ~ N(0, I)
         noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))

+        noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
         with torch.no_grad():
-            # 2. Generate action using base policy: a = πW_dp(s, w)
-            actions = self.action_policy(observations, noise)
+            # Generate action using base policy: a = πW_dp(s, w)
+            actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise)
+            actions = actions_chunk[:, 0, :]

-            # 3. Get target Q-values from action critic: QA(s, a)
+            # Get target Q-values from action critic: QA(s, a)
             q_targets = self.action_critic_forward(
                 observations=observations,
                 actions=actions,
@@ -375,14 +384,14 @@ class DSRLPolicy(PreTrainedPolicy):
         # Average over ensemble critics
         q_targets = q_targets.mean(dim=0, keepdim=True)  # (1, batch_size)

-        # 4. Get predicted Q-values from noise critic: QW(s, w)
+        # Get predicted Q-values from noise critic: QW(s, w)
         q_preds = self.noise_critic_forward(
             observations=observations,
             noise=noise,
             observation_features=observation_features,
         )  # (batch_size, 1)

-        # 5. Compute MSE loss
+        # Compute MSE loss
         loss = F.mse_loss(q_preds.squeeze(-1), q_targets.squeeze(0))

         return loss
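A shape-level sketch of the distillation step above (illustrative, with an assumed ensemble size): the noise critic is regressed onto the ensemble-averaged action-critic values for the same noise samples.

```python
import torch
import torch.nn.functional as F

num_critics, batch_size = 2, 4  # assumed sizes

# Action-critic targets, one row per ensemble member: (E, B).
q_targets = torch.randn(num_critics, batch_size)
q_targets = q_targets.mean(dim=0, keepdim=True)  # (1, B): average over the ensemble

# Noise-critic predictions for the same (s, w) pairs: (B, 1).
q_preds = torch.randn(batch_size, 1)

# MSE regression of the noise critic onto the averaged action critic.
loss = F.mse_loss(q_preds.squeeze(-1), q_targets.squeeze(0))
print(loss.item())
```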
@@ -417,17 +426,17 @@ class DSRLPolicy(PreTrainedPolicy):
         Returns:
             Noise actor loss with entropy regularization
         """
-        # 1. Sample noise w ~ πW(s) from noise actor
+        # Sample noise w ~ πW(s) from noise actor
         noise, log_probs, _ = self.noise_actor(observations, observation_features)

-        # 2. Evaluate QW(s, w) using noise critic
+        # Evaluate QW(s, w) using noise critic
         q_values = self.noise_critic_forward(
             observations=observations,
             noise=noise,
             observation_features=observation_features,
         )  # (batch_size, 1)

-        # 3. Compute loss: minimize (temperature * log_prob - Q_value)
+        # Compute loss: minimize (temperature * log_prob - Q_value)
         # This is equivalent to maximizing (Q_value - temperature * log_prob)
         noise_actor_loss = (self.temperature * log_probs - q_values.squeeze(-1)).mean()
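For completeness, a standalone sketch of the entropy-regularized actor objective above (assumed shapes and temperature; minimizing this is equivalent to maximizing Q minus the entropy penalty):

```python
import torch

batch_size = 4
temperature = 0.1                      # assumed value of self.temperature

log_probs = torch.randn(batch_size)    # log πW(w | s) from the noise actor
q_values = torch.randn(batch_size, 1)  # QW(s, w) from the noise critic

noise_actor_loss = (temperature * log_probs - q_values.squeeze(-1)).mean()
print(noise_actor_loss.item())
```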
@@ -437,7 +446,10 @@ class DSRLPolicy(PreTrainedPolicy):
         """Initialize the action policy."""

         action_policy = get_policy_class(self.config.action_policy_name)
-        self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
+        if self.config.action_policy_weights is not None:
+            self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
+        else:
+            self.action_policy = action_policy(self.config)
         self.action_policy.to(self.config.device)
         self.action_policy.eval()

@@ -491,7 +503,7 @@ class DSRLPolicy(PreTrainedPolicy):

     def _init_noise_actor(self):
         """Initialize noise actor network and default target entropy."""
-        self.noise_actor = Policy(
+        self.noise_actor = NoiseActorPolicy(
             encoder=self.encoder_noise_actor,
             network=MLP(
                 input_dim=self.encoder_noise_actor.output_dim,
@@ -704,11 +716,9 @@ class MLP(nn.Module):
         total = len(hidden_dims)

         for idx, out_dim in enumerate(hidden_dims):
-            # 1) linear transform
             layers.append(nn.Linear(in_dim, out_dim))

             is_last = idx == total - 1
-            # 2-4) optionally add dropout, normalization, and activation
             if not is_last or activate_final:
                 if dropout_rate and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
@@ -805,7 +815,7 @@ class CriticEnsemble(nn.Module):
         return q_values


-class Policy(nn.Module):
+class NoiseActorPolicy(nn.Module):
     """Noise Actor (πW) that maps states to noise distributions.

     This is the noise actor πW: S → △W that outputs noise vectors in the latent-noise space.