* Change Diffusion policy to use chunk_size notation instead of horizon to standerize the variable names across policies

* reshape noise after taking it as output of the network
This commit is contained in:
Michel Aractingi
2025-10-29 15:22:27 +01:00
parent 7cd710857d
commit 1594ae60a7
3 changed files with 58 additions and 48 deletions
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
Args: Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back). current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy. n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details. See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):
# Inputs / output structure. # Inputs / output structure.
n_obs_steps: int = 2 n_obs_steps: int = 2
horizon: int = 16 chunk_size: int = 16
n_action_steps: int = 8 n_action_steps: int = 8
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):
# The original implementation doesn't sample frames for the last 7 steps, # The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results. # which avoids excessive padding and leads to improved training results.
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1 drop_n_last_frames: int = 7 # chunk_size - n_action_steps - n_obs_steps + 1
# Architecture / modeling. # Architecture / modeling.
# Vision backbone. # Vision backbone.
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
f"Got {self.noise_scheduler_type}." f"Got {self.noise_scheduler_type}."
) )
# Check that the horizon size and U-Net downsampling is compatible. # Check that the chunk size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage. # U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims) downsampling_factor = 2 ** len(self.down_dims)
if self.horizon % downsampling_factor != 0: if self.chunk_size % downsampling_factor != 0:
raise ValueError( raise ValueError(
"The horizon should be an integer multiple of the downsampling factor (which is determined " "The chunk_size should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}" f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
) )
def get_optimizer_preset(self) -> AdamConfig: def get_optimizer_preset(self) -> AdamConfig:
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):
@property @property
def action_delta_indices(self) -> list: def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon)) return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))
@property @property
def reward_delta_indices(self) -> None: def reward_delta_indices(self) -> None:
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
return actions return actions
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
"""Select a single action given environment observations. """Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the This method handles caching a history of observations and an action trajectory generated by the
underlying diffusion model. Here's how it works: underlying diffusion model. Here's how it works:
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
copied `n_obs_steps` times to fill the cache). copied `n_obs_steps` times to fill the cache).
- The diffusion model generates `horizon` steps worth of actions. - The diffusion model generates `chunk_size` steps worth of actions.
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step. - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
Schematically this looks like: Schematically this looks like:
---------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------
(legend: o = n_obs_steps, h = horizon, a = n_action_steps) (legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h | |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO | |observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES | |action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO | |action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
---------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
"horizon" may not the best name to describe what the variable actually means, because this period is this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
""" """
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
noise noise
if noise is not None if noise is not None
else torch.randn( else torch.randn(
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]), size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
dtype=dtype, dtype=dtype,
device=device, device=device,
generator=generator, generator=generator,
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
AND/OR AND/OR
"observation.environment_state": (B, n_obs_steps, environment_dim) "observation.environment_state": (B, n_obs_steps, environment_dim)
"action": (B, horizon, action_dim) "action": (B, chunk_size, action_dim)
"action_is_pad": (B, horizon) "action_is_pad": (B, chunk_size)
} }
""" """
# Input validation. # Input validation.
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"}) assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
n_obs_steps = batch[OBS_STATE].shape[1] n_obs_steps = batch[OBS_STATE].shape[1]
horizon = batch[ACTION].shape[1] chunk_size = batch[ACTION].shape[1]
assert horizon == self.config.horizon assert chunk_size == self.config.chunk_size
assert n_obs_steps == self.config.n_obs_steps assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector. # Encode image features and concatenate them all together along with the state vector.
+40 -30
View File
@@ -75,7 +75,7 @@ class DSRLPolicy(PreTrainedPolicy):
def __init__( def __init__(
self, self,
config: DSRLConfig | None = None, config: DSRLConfig,
): ):
super().__init__(config) super().__init__(config)
config.validate_features() config.validate_features()
@@ -91,11 +91,16 @@ class DSRLPolicy(PreTrainedPolicy):
self._init_noise_actor() self._init_noise_actor()
self._init_temperature() self._init_temperature()
# Set chunk size for action policy
if not hasattr(self.action_policy.config, "chunk_size"):
raise ValueError("Action policy config does not have a chunk size")
self.chunk_size = self.action_policy.config.chunk_size
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
optim_params = { optim_params = {
"noise_actor": [ "noise_actor": [
p p
for n, p in self.actor.named_parameters() for n, p in self.noise_actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder if not n.startswith("encoder") or not self.shared_encoder
], ],
"critic_action": self.action_critic_ensemble.parameters(), "critic_action": self.action_critic_ensemble.parameters(),
@@ -109,12 +114,12 @@ class DSRLPolicy(PreTrainedPolicy):
pass pass
@torch.no_grad() @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
"""Predict a chunk of actions given environment observations.""" """Predict a chunk of actions given environment observations."""
raise NotImplementedError("DSRLPolicy does not support action chunking. It returns single actions!") raise NotImplementedError("DSRLPolicy does not support action chunking. It returns single actions!")
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
"""Select noise vector for inference/evaluation, """Select noise vector for inference/evaluation,
pass it through the action policy to get the action. pass it through the action policy to get the action.
@@ -126,8 +131,9 @@ class DSRLPolicy(PreTrainedPolicy):
observations_features = self.noise_actor.encoder.get_cached_image_features(batch) observations_features = self.noise_actor.encoder.get_cached_image_features(batch)
noise, _, _ = self.noise_actor(batch, observations_features) noise, _, _ = self.noise_actor(batch, observations_features)
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
return self.action_policy(batch, noise) actions = self.action_policy.predict_action_chunk(batch, noise=noise)
return actions[:, 0, :]
def action_critic_forward( def action_critic_forward(
self, self,
@@ -203,7 +209,7 @@ class DSRLPolicy(PreTrainedPolicy):
observation_features: Tensor | None = batch.get("observation_feature") observation_features: Tensor | None = batch.get("observation_feature")
if model == "critic_action": if model == "critic_action":
# 1. Action Critic: TD-learning on action space # Action Critic: TD-learning on action space
# Extract critic-specific components # Extract critic-specific components
actions: Tensor = batch[ACTION] actions: Tensor = batch[ACTION]
rewards: Tensor = batch["reward"] rewards: Tensor = batch["reward"]
@@ -223,7 +229,7 @@ class DSRLPolicy(PreTrainedPolicy):
return {"loss_critic": loss_critic} return {"loss_critic": loss_critic}
if model == "critic_noise": if model == "critic_noise":
# 2. Noise Critic: Distillation from action critic # Noise Critic: Distillation from action critic
loss_critic_noise = self.compute_loss_critic_noise( loss_critic_noise = self.compute_loss_critic_noise(
observations=observations, observations=observations,
observation_features=observation_features, observation_features=observation_features,
@@ -231,7 +237,7 @@ class DSRLPolicy(PreTrainedPolicy):
return {"loss_critic_noise": loss_critic_noise} return {"loss_critic_noise": loss_critic_noise}
if model == "noise_actor": if model == "noise_actor":
# 3. Noise Actor: Maximize Q-values in noise space # Noise Actor: Maximize Q-values in noise space
loss_noise_actor = self.compute_loss_noise_actor( loss_noise_actor = self.compute_loss_noise_actor(
observations=observations, observations=observations,
observation_features=observation_features, observation_features=observation_features,
@@ -283,14 +289,16 @@ class DSRLPolicy(PreTrainedPolicy):
""" """
with torch.no_grad(): with torch.no_grad():
# 1. Sample noise from noise actor: w' ~ πW(s') # Sample noise from noise actor: w' ~ πW(s')
next_noise, next_log_probs, _ = self.noise_actor(next_observations, next_observation_features) next_noise, next_log_probs, _ = self.noise_actor(next_observations, next_observation_features)
next_noise = next_noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
# 2. Generate next actions # Generate next actions
# a' = πW_dp(s', w') # a' = πW_dp(s', w')
next_action_preds = self.action_policy(next_observations, next_noise) next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise)
next_action_preds = next_actions_chunk[:, 0, :]
# 3. Compute target Q-values: Q̄A(s', a') # Compute target Q-values: Q̄A(s', a')
q_targets = self.action_critic_forward( q_targets = self.action_critic_forward(
observations=next_observations, observations=next_observations,
actions=next_action_preds, actions=next_action_preds,
@@ -312,7 +320,7 @@ class DSRLPolicy(PreTrainedPolicy):
td_target = rewards + (1 - done) * self.config.discount * min_q td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs # compute predicted qs
q_preds = self.action_critic_forward( q_preds = self.action_critic_forward(
observations=observations, observations=observations,
actions=actions, actions=actions,
@@ -320,7 +328,7 @@ class DSRLPolicy(PreTrainedPolicy):
observation_features=observation_features, observation_features=observation_features,
) )
# 4- Calculate loss # Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
@@ -358,14 +366,15 @@ class DSRLPolicy(PreTrainedPolicy):
batch_size = next(iter(observations.values())).shape[0] batch_size = next(iter(observations.values())).shape[0]
action_dim = self.config.output_features[ACTION].shape[0] action_dim = self.config.output_features[ACTION].shape[0]
# 1. Sample noise w ~ N(0, I) # Sample noise w ~ N(0, I)
noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self)) noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
with torch.no_grad(): with torch.no_grad():
# 2. Generate action using base policy: a = πW_dp(s, w) # Generate action using base policy: a = πW_dp(s, w)
actions = self.action_policy(observations, noise) actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise)
actions = actions_chunk[:, 0, :]
# 3. Get target Q-values from action critic: QA(s, a) # Get target Q-values from action critic: QA(s, a)
q_targets = self.action_critic_forward( q_targets = self.action_critic_forward(
observations=observations, observations=observations,
actions=actions, actions=actions,
@@ -375,14 +384,14 @@ class DSRLPolicy(PreTrainedPolicy):
# Average over ensemble critics # Average over ensemble critics
q_targets = q_targets.mean(dim=0, keepdim=True) # (1, batch_size) q_targets = q_targets.mean(dim=0, keepdim=True) # (1, batch_size)
# 4. Get predicted Q-values from noise critic: QW(s, w) # Get predicted Q-values from noise critic: QW(s, w)
q_preds = self.noise_critic_forward( q_preds = self.noise_critic_forward(
observations=observations, observations=observations,
noise=noise, noise=noise,
observation_features=observation_features, observation_features=observation_features,
) # (batch_size, 1) ) # (batch_size, 1)
# 5. Compute MSE loss # Compute MSE loss
loss = F.mse_loss(q_preds.squeeze(-1), q_targets.squeeze(0)) loss = F.mse_loss(q_preds.squeeze(-1), q_targets.squeeze(0))
return loss return loss
@@ -417,17 +426,17 @@ class DSRLPolicy(PreTrainedPolicy):
Returns: Returns:
Noise actor loss with entropy regularization Noise actor loss with entropy regularization
""" """
# 1. Sample noise w ~ πW(s) from noise actor # Sample noise w ~ πW(s) from noise actor
noise, log_probs, _ = self.noise_actor(observations, observation_features) noise, log_probs, _ = self.noise_actor(observations, observation_features)
# 2. Evaluate QW(s, w) using noise critic # Evaluate QW(s, w) using noise critic
q_values = self.noise_critic_forward( q_values = self.noise_critic_forward(
observations=observations, observations=observations,
noise=noise, noise=noise,
observation_features=observation_features, observation_features=observation_features,
) # (batch_size, 1) ) # (batch_size, 1)
# 3. Compute loss: minimize (temperature * log_prob - Q_value) # Compute loss: minimize (temperature * log_prob - Q_value)
# This is equivalent to maximizing (Q_value - temperature * log_prob) # This is equivalent to maximizing (Q_value - temperature * log_prob)
noise_actor_loss = (self.temperature * log_probs - q_values.squeeze(-1)).mean() noise_actor_loss = (self.temperature * log_probs - q_values.squeeze(-1)).mean()
@@ -437,7 +446,10 @@ class DSRLPolicy(PreTrainedPolicy):
"""Initialize the action policy.""" """Initialize the action policy."""
action_policy = get_policy_class(self.config.action_policy_name) action_policy = get_policy_class(self.config.action_policy_name)
self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights) if self.config.action_policy_weights is not None:
self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
else:
self.action_policy = action_policy(self.config)
self.action_policy.to(self.config.device) self.action_policy.to(self.config.device)
self.action_policy.eval() self.action_policy.eval()
@@ -491,7 +503,7 @@ class DSRLPolicy(PreTrainedPolicy):
def _init_noise_actor(self): def _init_noise_actor(self):
"""Initialize noise actor network and default target entropy.""" """Initialize noise actor network and default target entropy."""
self.noise_actor = Policy( self.noise_actor = NoiseActorPolicy(
encoder=self.encoder_noise_actor, encoder=self.encoder_noise_actor,
network=MLP( network=MLP(
input_dim=self.encoder_noise_actor.output_dim, input_dim=self.encoder_noise_actor.output_dim,
@@ -704,11 +716,9 @@ class MLP(nn.Module):
total = len(hidden_dims) total = len(hidden_dims)
for idx, out_dim in enumerate(hidden_dims): for idx, out_dim in enumerate(hidden_dims):
# 1) linear transform
layers.append(nn.Linear(in_dim, out_dim)) layers.append(nn.Linear(in_dim, out_dim))
is_last = idx == total - 1 is_last = idx == total - 1
# 2-4) optionally add dropout, normalization, and activation
if not is_last or activate_final: if not is_last or activate_final:
if dropout_rate and dropout_rate > 0: if dropout_rate and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.Dropout(p=dropout_rate))
@@ -805,7 +815,7 @@ class CriticEnsemble(nn.Module):
return q_values return q_values
class Policy(nn.Module): class NoiseActorPolicy(nn.Module):
"""Noise Actor (πW) that maps states to noise distributions. """Noise Actor (πW) that maps states to noise distributions.
This is the noise actor πW: S W that outputs noise vectors in the latent-noise space. This is the noise actor πW: S W that outputs noise vectors in the latent-noise space.