From 1594ae60a732669f50f3a63251a5b841233548cc Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Wed, 29 Oct 2025 15:22:27 +0100
Subject: [PATCH] * Change Diffusion policy to use chunk_size notation instead
 of horizon to standardize the variable names across policies
 * Reshape noise after taking it as output of the network

---
 .../diffusion/configuration_diffusion.py     | 16 ++---
 .../policies/diffusion/modeling_diffusion.py | 20 +++---
 src/lerobot/policies/dsrl/modeling_dsrl.py   | 70 +++++++++++--------
 3 files changed, 58 insertions(+), 48 deletions(-)

diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py
index 54569434a..14ae008cc 100644
--- a/src/lerobot/policies/diffusion/configuration_diffusion.py
+++ b/src/lerobot/policies/diffusion/configuration_diffusion.py
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
-        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
+        chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
         n_action_steps: The number of action steps to run in the environment for one invocation of the
             policy. See `DiffusionPolicy.select_action` for more details.
         input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):
 
     # Inputs / output structure.
     n_obs_steps: int = 2
-    horizon: int = 16
+    chunk_size: int = 16
     n_action_steps: int = 8
 
     normalization_mapping: dict[str, NormalizationMode] = field(
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):
 
     # The original implementation doesn't sample frames for the last 7 steps,
    # which avoids excessive padding and leads to improved training results.
-    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1
+    drop_n_last_frames: int = 7  # chunk_size - n_action_steps - n_obs_steps + 1
 
     # Architecture / modeling.
     # Vision backbone.
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
                 f"Got {self.noise_scheduler_type}."
             )
 
-        # Check that the horizon size and U-Net downsampling is compatible.
+        # Check that the chunk size and U-Net downsampling are compatible.
         # U-Net downsamples by 2 with each stage.
         downsampling_factor = 2 ** len(self.down_dims)
-        if self.horizon % downsampling_factor != 0:
+        if self.chunk_size % downsampling_factor != 0:
             raise ValueError(
-                "The horizon should be an integer multiple of the downsampling factor (which is determined "
-                f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
+                "The chunk_size should be an integer multiple of the downsampling factor (which is determined "
+                f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
             )
 
     def get_optimizer_preset(self) -> AdamConfig:
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):
 
     @property
     def action_delta_indices(self) -> list:
-        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
+        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))
 
     @property
     def reward_delta_indices(self) -> None:
diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py
index 3ab6719cb..7cd2c12ee 100644
--- a/src/lerobot/policies/diffusion/modeling_diffusion.py
+++ b/src/lerobot/policies/diffusion/modeling_diffusion.py
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
         return actions
 
     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Select a single action given environment observations.
 
         This method handles caching a history of observations and an action trajectory generated by the
         underlying diffusion model. Here's how it works:
          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
            copied `n_obs_steps` times to fill the cache).
-         - The diffusion model generates `horizon` steps worth of actions.
+         - The diffusion model generates `chunk_size` steps worth of actions.
          - `n_action_steps` worth of actions are actually kept for execution, starting from the current
            step.
         Schematically this looks like:
            ----------------------------------------------------------------------------------------------
-           (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
+           (legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
           |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
           |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
           |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
           |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
            ----------------------------------------------------------------------------------------------
-        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
-        "horizon" may not the best name to describe what the variable actually means, because this period is
+        Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
+        this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
         # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
             noise
             if noise is not None
             else torch.randn(
-                size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
+                size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
                 dtype=dtype,
                 device=device,
                 generator=generator,
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
             AND/OR
             "observation.environment_state": (B, n_obs_steps, environment_dim)
 
-            "action": (B, horizon, action_dim)
-            "action_is_pad": (B, horizon)
+            "action": (B, chunk_size, action_dim)
+            "action_is_pad": (B, chunk_size)
         }
         """
         # Input validation.
         assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
         assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
         n_obs_steps = batch[OBS_STATE].shape[1]
-        horizon = batch[ACTION].shape[1]
-        assert horizon == self.config.horizon
+        chunk_size = batch[ACTION].shape[1]
+        assert chunk_size == self.config.chunk_size
         assert n_obs_steps == self.config.n_obs_steps
 
         # Encode image features and concatenate them all together along with the state vector.
diff --git a/src/lerobot/policies/dsrl/modeling_dsrl.py b/src/lerobot/policies/dsrl/modeling_dsrl.py
index c1c431947..e78d6996a 100644
--- a/src/lerobot/policies/dsrl/modeling_dsrl.py
+++ b/src/lerobot/policies/dsrl/modeling_dsrl.py
@@ -75,7 +75,7 @@ class DSRLPolicy(PreTrainedPolicy):
 
     def __init__(
         self,
-        config: DSRLConfig | None = None,
+        config: DSRLConfig,
     ):
         super().__init__(config)
         config.validate_features()
@@ -91,11 +91,16 @@ class DSRLPolicy(PreTrainedPolicy):
         self._init_noise_actor()
         self._init_temperature()
 
+        # Set chunk size for action policy
+        if not hasattr(self.action_policy.config, "chunk_size"):
+            raise ValueError("Action policy config does not have a chunk size")
+        self.chunk_size = self.action_policy.config.chunk_size
+
     def get_optim_params(self) -> dict:
         optim_params = {
             "noise_actor": [
                 p
-                for n, p in self.actor.named_parameters()
+                for n, p in self.noise_actor.named_parameters()
                 if not n.startswith("encoder") or not self.shared_encoder
             ],
             "critic_action": self.action_critic_ensemble.parameters(),
@@ -109,12 +114,12 @@ class DSRLPolicy(PreTrainedPolicy):
         pass
 
     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("DSRLPolicy does not support action chunking. It returns single actions!")
 
     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Select noise vector for inference/evaluation, pass it through the action
         policy to get the action.
 
@@ -126,8 +131,9 @@ class DSRLPolicy(PreTrainedPolicy):
             observations_features = self.noise_actor.encoder.get_cached_image_features(batch)
 
         noise, _, _ = self.noise_actor(batch, observations_features)
-
-        return self.action_policy(batch, noise)
+        noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
+        actions = self.action_policy.predict_action_chunk(batch, noise=noise)
+        return actions[:, 0, :]
 
     def action_critic_forward(
         self,
@@ -203,7 +209,7 @@ class DSRLPolicy(PreTrainedPolicy):
         observation_features: Tensor | None = batch.get("observation_feature")
 
         if model == "critic_action":
-            # 1. Action Critic: TD-learning on action space
+            # Action Critic: TD-learning on action space
             # Extract critic-specific components
             actions: Tensor = batch[ACTION]
             rewards: Tensor = batch["reward"]
@@ -223,7 +229,7 @@ class DSRLPolicy(PreTrainedPolicy):
             return {"loss_critic": loss_critic}
 
         if model == "critic_noise":
-            # 2. Noise Critic: Distillation from action critic
+            # Noise Critic: Distillation from action critic
             loss_critic_noise = self.compute_loss_critic_noise(
                 observations=observations,
                 observation_features=observation_features,
@@ -231,7 +237,7 @@ class DSRLPolicy(PreTrainedPolicy):
             return {"loss_critic_noise": loss_critic_noise}
 
         if model == "noise_actor":
-            # 3. Noise Actor: Maximize Q-values in noise space
+            # Noise Actor: Maximize Q-values in noise space
             loss_noise_actor = self.compute_loss_noise_actor(
                 observations=observations,
                 observation_features=observation_features,
@@ -283,14 +289,16 @@ class DSRLPolicy(PreTrainedPolicy):
 
         """
         with torch.no_grad():
-            # 1. Sample noise from noise actor: w' ~ πW(s')
+            # Sample noise from noise actor: w' ~ πW(s')
             next_noise, next_log_probs, _ = self.noise_actor(next_observations, next_observation_features)
+            next_noise = next_noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
 
-            # 2. Generate next actions
+            # Generate next actions
             # a' = πW_dp(s', w')
-            next_action_preds = self.action_policy(next_observations, next_noise)
+            next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise)
+            next_action_preds = next_actions_chunk[:, 0, :]
 
-            # 3. Compute target Q-values: Q̄A(s', a')
+            # Compute target Q-values: Q̄A(s', a')
             q_targets = self.action_critic_forward(
                 observations=next_observations,
                 actions=next_action_preds,
@@ -312,7 +320,7 @@ class DSRLPolicy(PreTrainedPolicy):
 
         td_target = rewards + (1 - done) * self.config.discount * min_q
 
-        # 3- compute predicted qs
+        # compute predicted qs
         q_preds = self.action_critic_forward(
             observations=observations,
             actions=actions,
@@ -320,7 +328,7 @@ class DSRLPolicy(PreTrainedPolicy):
             observation_features=observation_features,
         )
 
-        # 4- Calculate loss
+        # Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
         # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
@@ -358,14 +366,15 @@ class DSRLPolicy(PreTrainedPolicy):
         batch_size = next(iter(observations.values())).shape[0]
         action_dim = self.config.output_features[ACTION].shape[0]
 
-        # 1. Sample noise w ~ N(0, I)
+        # Sample noise w ~ N(0, I)
         noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
-
+        noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
         with torch.no_grad():
-            # 2. Generate action using base policy: a = πW_dp(s, w)
-            actions = self.action_policy(observations, noise)
+            # Generate action using base policy: a = πW_dp(s, w)
+            actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise)
+            actions = actions_chunk[:, 0, :]
 
-            # 3. Get target Q-values from action critic: QA(s, a)
+            # Get target Q-values from action critic: QA(s, a)
             q_targets = self.action_critic_forward(
                 observations=observations,
                 actions=actions,
@@ -375,14 +384,14 @@ class DSRLPolicy(PreTrainedPolicy):
         # Average over ensemble critics
         q_targets = q_targets.mean(dim=0, keepdim=True)  # (1, batch_size)
 
-        # 4. Get predicted Q-values from noise critic: QW(s, w)
+        # Get predicted Q-values from noise critic: QW(s, w)
         q_preds = self.noise_critic_forward(
             observations=observations,
             noise=noise,
             observation_features=observation_features,
         )  # (batch_size, 1)
 
-        # 5. Compute MSE loss
+        # Compute MSE loss
         loss = F.mse_loss(q_preds.squeeze(-1), q_targets.squeeze(0))
 
         return loss
@@ -417,17 +426,17 @@ class DSRLPolicy(PreTrainedPolicy):
         Returns:
             Noise actor loss with entropy regularization
         """
-        # 1. Sample noise w ~ πW(s) from noise actor
+        # Sample noise w ~ πW(s) from noise actor
         noise, log_probs, _ = self.noise_actor(observations, observation_features)
 
-        # 2. Evaluate QW(s, w) using noise critic
+        # Evaluate QW(s, w) using noise critic
         q_values = self.noise_critic_forward(
             observations=observations,
             noise=noise,
             observation_features=observation_features,
         )  # (batch_size, 1)
 
-        # 3. Compute loss: minimize (temperature * log_prob - Q_value)
+        # Compute loss: minimize (temperature * log_prob - Q_value)
         # This is equivalent to maximizing (Q_value - temperature * log_prob)
         noise_actor_loss = (self.temperature * log_probs - q_values.squeeze(-1)).mean()
         return noise_actor_loss
@@ -437,7 +446,10 @@ class DSRLPolicy(PreTrainedPolicy):
     def _init_action_policy(self):
         """Initialize the action policy."""
         action_policy = get_policy_class(self.config.action_policy_name)
-        self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
+        if self.config.action_policy_weights is not None:
+            self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
+        else:
+            self.action_policy = action_policy(self.config)
         self.action_policy.to(self.config.device)
         self.action_policy.eval()
 
@@ -491,7 +503,7 @@ class DSRLPolicy(PreTrainedPolicy):
 
     def _init_noise_actor(self):
         """Initialize noise actor network and default target entropy."""
-        self.noise_actor = Policy(
+        self.noise_actor = NoiseActorPolicy(
             encoder=self.encoder_noise_actor,
             network=MLP(
                 input_dim=self.encoder_noise_actor.output_dim,
@@ -704,11 +716,9 @@ class MLP(nn.Module):
         total = len(hidden_dims)
 
         for idx, out_dim in enumerate(hidden_dims):
-            # 1) linear transform
             layers.append(nn.Linear(in_dim, out_dim))
 
             is_last = idx == total - 1
-            # 2-4) optionally add dropout, normalization, and activation
             if not is_last or activate_final:
                 if dropout_rate and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
@@ -805,7 +815,7 @@ class CriticEnsemble(nn.Module):
         return q_values
 
 
-class Policy(nn.Module):
+class NoiseActorPolicy(nn.Module):
     """Noise Actor (πW) that maps states to noise distributions.
 
     This is the noise actor πW: S → △W that outputs noise vectors in the latent-noise space.
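
Illustrative usage (not part of the patch): a minimal sketch of how a caller can drive the renamed chunk_size interface the way DSRLPolicy.select_action now does. It assumes `policy` is a loaded DiffusionPolicy (so `policy.config.chunk_size` and `policy.config.action_feature` exist) and `batch` is an observation batch the policy accepts; the helper name `select_first_action` is hypothetical and device placement of the noise tensor is left out.

import torch

def select_first_action(policy, batch: dict[str, torch.Tensor]) -> torch.Tensor:
    # Sample one noise vector per sample, then broadcast it across the chunk
    # dimension, mirroring the reshape DSRLPolicy performs before calling the
    # action policy.
    batch_size = next(iter(batch.values())).shape[0]
    action_dim = policy.config.action_feature.shape[0]
    noise = torch.randn(batch_size, action_dim)
    noise = noise.unsqueeze(1).repeat(1, policy.config.chunk_size, 1)

    # The policy returns a (batch_size, chunk_size, action_dim) chunk; keep only
    # the first step, as DSRLPolicy.select_action does after this patch.
    actions = policy.predict_action_chunk(batch, noise=noise)
    return actions[:, 0, :]

Broadcasting a single per-sample noise vector keeps the space the DSRL noise actor works in at action_dim, while the diffusion sampler still receives the (batch_size, chunk_size, action_dim) noise tensor it expects.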