* Change Diffusion policy to use chunk_size notation instead of horizon to standerize the variable names across policies

* reshape noise after taking it as output of the network
This commit is contained in:
Michel Aractingi
2025-10-29 15:22:27 +01:00
parent 7cd710857d
commit 1594ae60a7
3 changed files with 58 additions and 48 deletions
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
chunk_size: int = 16
n_action_steps: int = 8
normalization_mapping: dict[str, NormalizationMode] = field(
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
drop_n_last_frames: int = 7 # chunk_size - n_action_steps - n_obs_steps + 1
# Architecture / modeling.
# Vision backbone.
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
f"Got {self.noise_scheduler_type}."
)
# Check that the horizon size and U-Net downsampling is compatible.
# Check that the chunk size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
if self.horizon % downsampling_factor != 0:
if self.chunk_size % downsampling_factor != 0:
raise ValueError(
"The horizon should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
"The chunk_size should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
)
def get_optimizer_preset(self) -> AdamConfig:
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))
@property
def reward_delta_indices(self) -> None:
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
underlying diffusion model. Here's how it works:
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
copied `n_obs_steps` times to fill the cache).
- The diffusion model generates `horizon` steps worth of actions.
- The diffusion model generates `chunk_size` steps worth of actions.
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
Schematically this looks like:
----------------------------------------------------------------------------------------------
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
(legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
----------------------------------------------------------------------------------------------
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
"horizon" may not the best name to describe what the variable actually means, because this period is
Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
noise
if noise is not None
else torch.randn(
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
AND/OR
"observation.environment_state": (B, n_obs_steps, environment_dim)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
"action": (B, chunk_size, action_dim)
"action_is_pad": (B, chunk_size)
}
"""
# Input validation.
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
n_obs_steps = batch[OBS_STATE].shape[1]
horizon = batch[ACTION].shape[1]
assert horizon == self.config.horizon
chunk_size = batch[ACTION].shape[1]
assert chunk_size == self.config.chunk_size
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
+40 -30
View File
@@ -75,7 +75,7 @@ class DSRLPolicy(PreTrainedPolicy):
def __init__(
self,
config: DSRLConfig | None = None,
config: DSRLConfig,
):
super().__init__(config)
config.validate_features()
@@ -91,11 +91,16 @@ class DSRLPolicy(PreTrainedPolicy):
self._init_noise_actor()
self._init_temperature()
# Set chunk size for action policy
if not hasattr(self.action_policy.config, "chunk_size"):
raise ValueError("Action policy config does not have a chunk size")
self.chunk_size = self.action_policy.config.chunk_size
def get_optim_params(self) -> dict:
optim_params = {
"noise_actor": [
p
for n, p in self.actor.named_parameters()
for n, p in self.noise_actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic_action": self.action_critic_ensemble.parameters(),
@@ -109,12 +114,12 @@ class DSRLPolicy(PreTrainedPolicy):
pass
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("DSRLPolicy does not support action chunking. It returns single actions!")
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
"""Select noise vector for inference/evaluation,
pass it through the action policy to get the action.
@@ -126,8 +131,9 @@ class DSRLPolicy(PreTrainedPolicy):
observations_features = self.noise_actor.encoder.get_cached_image_features(batch)
noise, _, _ = self.noise_actor(batch, observations_features)
return self.action_policy(batch, noise)
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
actions = self.action_policy.predict_action_chunk(batch, noise=noise)
return actions[:, 0, :]
def action_critic_forward(
self,
@@ -203,7 +209,7 @@ class DSRLPolicy(PreTrainedPolicy):
observation_features: Tensor | None = batch.get("observation_feature")
if model == "critic_action":
# 1. Action Critic: TD-learning on action space
# Action Critic: TD-learning on action space
# Extract critic-specific components
actions: Tensor = batch[ACTION]
rewards: Tensor = batch["reward"]
@@ -223,7 +229,7 @@ class DSRLPolicy(PreTrainedPolicy):
return {"loss_critic": loss_critic}
if model == "critic_noise":
# 2. Noise Critic: Distillation from action critic
# Noise Critic: Distillation from action critic
loss_critic_noise = self.compute_loss_critic_noise(
observations=observations,
observation_features=observation_features,
@@ -231,7 +237,7 @@ class DSRLPolicy(PreTrainedPolicy):
return {"loss_critic_noise": loss_critic_noise}
if model == "noise_actor":
# 3. Noise Actor: Maximize Q-values in noise space
# Noise Actor: Maximize Q-values in noise space
loss_noise_actor = self.compute_loss_noise_actor(
observations=observations,
observation_features=observation_features,
@@ -283,14 +289,16 @@ class DSRLPolicy(PreTrainedPolicy):
"""
with torch.no_grad():
# 1. Sample noise from noise actor: w' ~ πW(s')
# Sample noise from noise actor: w' ~ πW(s')
next_noise, next_log_probs, _ = self.noise_actor(next_observations, next_observation_features)
next_noise = next_noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
# 2. Generate next actions
# Generate next actions
# a' = πW_dp(s', w')
next_action_preds = self.action_policy(next_observations, next_noise)
next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise)
next_action_preds = next_actions_chunk[:, 0, :]
# 3. Compute target Q-values: Q̄A(s', a')
# Compute target Q-values: Q̄A(s', a')
q_targets = self.action_critic_forward(
observations=next_observations,
actions=next_action_preds,
@@ -312,7 +320,7 @@ class DSRLPolicy(PreTrainedPolicy):
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
# compute predicted qs
q_preds = self.action_critic_forward(
observations=observations,
actions=actions,
@@ -320,7 +328,7 @@ class DSRLPolicy(PreTrainedPolicy):
observation_features=observation_features,
)
# 4- Calculate loss
# Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
@@ -358,14 +366,15 @@ class DSRLPolicy(PreTrainedPolicy):
batch_size = next(iter(observations.values())).shape[0]
action_dim = self.config.output_features[ACTION].shape[0]
# 1. Sample noise w ~ N(0, I)
# Sample noise w ~ N(0, I)
noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
with torch.no_grad():
# 2. Generate action using base policy: a = πW_dp(s, w)
actions = self.action_policy(observations, noise)
# Generate action using base policy: a = πW_dp(s, w)
actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise)
actions = actions_chunk[:, 0, :]
# 3. Get target Q-values from action critic: QA(s, a)
# Get target Q-values from action critic: QA(s, a)
q_targets = self.action_critic_forward(
observations=observations,
actions=actions,
@@ -375,14 +384,14 @@ class DSRLPolicy(PreTrainedPolicy):
# Average over ensemble critics
q_targets = q_targets.mean(dim=0, keepdim=True) # (1, batch_size)
# 4. Get predicted Q-values from noise critic: QW(s, w)
# Get predicted Q-values from noise critic: QW(s, w)
q_preds = self.noise_critic_forward(
observations=observations,
noise=noise,
observation_features=observation_features,
) # (batch_size, 1)
# 5. Compute MSE loss
# Compute MSE loss
loss = F.mse_loss(q_preds.squeeze(-1), q_targets.squeeze(0))
return loss
@@ -417,17 +426,17 @@ class DSRLPolicy(PreTrainedPolicy):
Returns:
Noise actor loss with entropy regularization
"""
# 1. Sample noise w ~ πW(s) from noise actor
# Sample noise w ~ πW(s) from noise actor
noise, log_probs, _ = self.noise_actor(observations, observation_features)
# 2. Evaluate QW(s, w) using noise critic
# Evaluate QW(s, w) using noise critic
q_values = self.noise_critic_forward(
observations=observations,
noise=noise,
observation_features=observation_features,
) # (batch_size, 1)
# 3. Compute loss: minimize (temperature * log_prob - Q_value)
# Compute loss: minimize (temperature * log_prob - Q_value)
# This is equivalent to maximizing (Q_value - temperature * log_prob)
noise_actor_loss = (self.temperature * log_probs - q_values.squeeze(-1)).mean()
@@ -437,7 +446,10 @@ class DSRLPolicy(PreTrainedPolicy):
"""Initialize the action policy."""
action_policy = get_policy_class(self.config.action_policy_name)
self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
if self.config.action_policy_weights is not None:
self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
else:
self.action_policy = action_policy(self.config)
self.action_policy.to(self.config.device)
self.action_policy.eval()
@@ -491,7 +503,7 @@ class DSRLPolicy(PreTrainedPolicy):
def _init_noise_actor(self):
"""Initialize noise actor network and default target entropy."""
self.noise_actor = Policy(
self.noise_actor = NoiseActorPolicy(
encoder=self.encoder_noise_actor,
network=MLP(
input_dim=self.encoder_noise_actor.output_dim,
@@ -704,11 +716,9 @@ class MLP(nn.Module):
total = len(hidden_dims)
for idx, out_dim in enumerate(hidden_dims):
# 1) linear transform
layers.append(nn.Linear(in_dim, out_dim))
is_last = idx == total - 1
# 2-4) optionally add dropout, normalization, and activation
if not is_last or activate_final:
if dropout_rate and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
@@ -805,7 +815,7 @@ class CriticEnsemble(nn.Module):
return q_values
class Policy(nn.Module):
class NoiseActorPolicy(nn.Module):
"""Noise Actor (πW) that maps states to noise distributions.
This is the noise actor πW: S → △W that outputs noise vectors in the latent-noise space.