mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
* Change Diffusion policy to use chunk_size notation instead of horizon to standardize the variable names across policies
* reshape noise after taking it as output of the network
This commit is contained in:
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
Args:
|
Args:
|
||||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||||
current step and additional steps going back).
|
current step and additional steps going back).
|
||||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||||
See `DiffusionPolicy.select_action` for more details.
|
See `DiffusionPolicy.select_action` for more details.
|
||||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||||
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Inputs / output structure.
|
# Inputs / output structure.
|
||||||
n_obs_steps: int = 2
|
n_obs_steps: int = 2
|
||||||
horizon: int = 16
|
chunk_size: int = 16
|
||||||
n_action_steps: int = 8
|
n_action_steps: int = 8
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# The original implementation doesn't sample frames for the last 7 steps,
|
# The original implementation doesn't sample frames for the last 7 steps,
|
||||||
# which avoids excessive padding and leads to improved training results.
|
# which avoids excessive padding and leads to improved training results.
|
||||||
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
drop_n_last_frames: int = 7 # chunk_size - n_action_steps - n_obs_steps + 1
|
||||||
|
|
||||||
# Architecture / modeling.
|
# Architecture / modeling.
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
f"Got {self.noise_scheduler_type}."
|
f"Got {self.noise_scheduler_type}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the horizon size and U-Net downsampling is compatible.
|
# Check that the chunk size and U-Net downsampling is compatible.
|
||||||
# U-Net downsamples by 2 with each stage.
|
# U-Net downsamples by 2 with each stage.
|
||||||
downsampling_factor = 2 ** len(self.down_dims)
|
downsampling_factor = 2 ** len(self.down_dims)
|
||||||
if self.horizon % downsampling_factor != 0:
|
if self.chunk_size % downsampling_factor != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
"The chunk_size should be an integer multiple of the downsampling factor (which is determined "
|
||||||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> AdamConfig:
|
def get_optimizer_preset(self) -> AdamConfig:
|
||||||
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list:
|
def action_delta_indices(self) -> list:
|
||||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reward_delta_indices(self) -> None:
|
def reward_delta_indices(self) -> None:
|
||||||
|
|||||||
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
This method handles caching a history of observations and an action trajectory generated by the
|
This method handles caching a history of observations and an action trajectory generated by the
|
||||||
underlying diffusion model. Here's how it works:
|
underlying diffusion model. Here's how it works:
|
||||||
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
||||||
copied `n_obs_steps` times to fill the cache).
|
copied `n_obs_steps` times to fill the cache).
|
||||||
- The diffusion model generates `horizon` steps worth of actions.
|
- The diffusion model generates `chunk_size` steps worth of actions.
|
||||||
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||||
Schematically this looks like:
|
Schematically this looks like:
|
||||||
----------------------------------------------------------------------------------------------
|
----------------------------------------------------------------------------------------------
|
||||||
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
(legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
|
||||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+c |
|
||||||
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
||||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||||
----------------------------------------------------------------------------------------------
|
----------------------------------------------------------------------------------------------
|
||||||
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
|
||||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
this period is
|
||||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
"""
|
"""
|
||||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||||
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
|
|||||||
noise
|
noise
|
||||||
if noise is not None
|
if noise is not None
|
||||||
else torch.randn(
|
else torch.randn(
|
||||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
|
|||||||
AND/OR
|
AND/OR
|
||||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||||
|
|
||||||
"action": (B, horizon, action_dim)
|
"action": (B, chunk_size, action_dim)
|
||||||
"action_is_pad": (B, horizon)
|
"action_is_pad": (B, chunk_size)
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
# Input validation.
|
# Input validation.
|
||||||
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
||||||
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
||||||
n_obs_steps = batch[OBS_STATE].shape[1]
|
n_obs_steps = batch[OBS_STATE].shape[1]
|
||||||
horizon = batch[ACTION].shape[1]
|
chunk_size = batch[ACTION].shape[1]
|
||||||
assert horizon == self.config.horizon
|
assert chunk_size == self.config.chunk_size
|
||||||
assert n_obs_steps == self.config.n_obs_steps
|
assert n_obs_steps == self.config.n_obs_steps
|
||||||
|
|
||||||
# Encode image features and concatenate them all together along with the state vector.
|
# Encode image features and concatenate them all together along with the state vector.
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: DSRLConfig | None = None,
|
config: DSRLConfig,
|
||||||
):
|
):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
config.validate_features()
|
config.validate_features()
|
||||||
@@ -91,11 +91,16 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
self._init_noise_actor()
|
self._init_noise_actor()
|
||||||
self._init_temperature()
|
self._init_temperature()
|
||||||
|
|
||||||
|
# Set chunk size for action policy
|
||||||
|
if not hasattr(self.action_policy.config, "chunk_size"):
|
||||||
|
raise ValueError("Action policy config does not have a chunk size")
|
||||||
|
self.chunk_size = self.action_policy.config.chunk_size
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
optim_params = {
|
optim_params = {
|
||||||
"noise_actor": [
|
"noise_actor": [
|
||||||
p
|
p
|
||||||
for n, p in self.actor.named_parameters()
|
for n, p in self.noise_actor.named_parameters()
|
||||||
if not n.startswith("encoder") or not self.shared_encoder
|
if not n.startswith("encoder") or not self.shared_encoder
|
||||||
],
|
],
|
||||||
"critic_action": self.action_critic_ensemble.parameters(),
|
"critic_action": self.action_critic_ensemble.parameters(),
|
||||||
@@ -109,12 +114,12 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
raise NotImplementedError("DSRLPolicy does not support action chunking. It returns single actions!")
|
raise NotImplementedError("DSRLPolicy does not support action chunking. It returns single actions!")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
|
||||||
"""Select noise vector for inference/evaluation,
|
"""Select noise vector for inference/evaluation,
|
||||||
pass it through the action policy to get the action.
|
pass it through the action policy to get the action.
|
||||||
|
|
||||||
@@ -126,8 +131,9 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
observations_features = self.noise_actor.encoder.get_cached_image_features(batch)
|
observations_features = self.noise_actor.encoder.get_cached_image_features(batch)
|
||||||
|
|
||||||
noise, _, _ = self.noise_actor(batch, observations_features)
|
noise, _, _ = self.noise_actor(batch, observations_features)
|
||||||
|
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
||||||
return self.action_policy(batch, noise)
|
actions = self.action_policy.predict_action_chunk(batch, noise=noise)
|
||||||
|
return actions[:, 0, :]
|
||||||
|
|
||||||
def action_critic_forward(
|
def action_critic_forward(
|
||||||
self,
|
self,
|
||||||
@@ -203,7 +209,7 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
observation_features: Tensor | None = batch.get("observation_feature")
|
observation_features: Tensor | None = batch.get("observation_feature")
|
||||||
|
|
||||||
if model == "critic_action":
|
if model == "critic_action":
|
||||||
# 1. Action Critic: TD-learning on action space
|
# Action Critic: TD-learning on action space
|
||||||
# Extract critic-specific components
|
# Extract critic-specific components
|
||||||
actions: Tensor = batch[ACTION]
|
actions: Tensor = batch[ACTION]
|
||||||
rewards: Tensor = batch["reward"]
|
rewards: Tensor = batch["reward"]
|
||||||
@@ -223,7 +229,7 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
return {"loss_critic": loss_critic}
|
return {"loss_critic": loss_critic}
|
||||||
|
|
||||||
if model == "critic_noise":
|
if model == "critic_noise":
|
||||||
# 2. Noise Critic: Distillation from action critic
|
# Noise Critic: Distillation from action critic
|
||||||
loss_critic_noise = self.compute_loss_critic_noise(
|
loss_critic_noise = self.compute_loss_critic_noise(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
observation_features=observation_features,
|
observation_features=observation_features,
|
||||||
@@ -231,7 +237,7 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
return {"loss_critic_noise": loss_critic_noise}
|
return {"loss_critic_noise": loss_critic_noise}
|
||||||
|
|
||||||
if model == "noise_actor":
|
if model == "noise_actor":
|
||||||
# 3. Noise Actor: Maximize Q-values in noise space
|
# Noise Actor: Maximize Q-values in noise space
|
||||||
loss_noise_actor = self.compute_loss_noise_actor(
|
loss_noise_actor = self.compute_loss_noise_actor(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
observation_features=observation_features,
|
observation_features=observation_features,
|
||||||
@@ -283,14 +289,16 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# 1. Sample noise from noise actor: w' ~ πW(s')
|
# Sample noise from noise actor: w' ~ πW(s')
|
||||||
next_noise, next_log_probs, _ = self.noise_actor(next_observations, next_observation_features)
|
next_noise, next_log_probs, _ = self.noise_actor(next_observations, next_observation_features)
|
||||||
|
next_noise = next_noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
||||||
|
|
||||||
# 2. Generate next actions
|
# Generate next actions
|
||||||
# a' = πW_dp(s', w')
|
# a' = πW_dp(s', w')
|
||||||
next_action_preds = self.action_policy(next_observations, next_noise)
|
next_actions_chunk = self.action_policy.predict_action_chunk(next_observations, next_noise)
|
||||||
|
next_action_preds = next_actions_chunk[:, 0, :]
|
||||||
|
|
||||||
# 3. Compute target Q-values: Q̄A(s', a')
|
# Compute target Q-values: Q̄A(s', a')
|
||||||
q_targets = self.action_critic_forward(
|
q_targets = self.action_critic_forward(
|
||||||
observations=next_observations,
|
observations=next_observations,
|
||||||
actions=next_action_preds,
|
actions=next_action_preds,
|
||||||
@@ -312,7 +320,7 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||||
|
|
||||||
# 3- compute predicted qs
|
# compute predicted qs
|
||||||
q_preds = self.action_critic_forward(
|
q_preds = self.action_critic_forward(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
actions=actions,
|
actions=actions,
|
||||||
@@ -320,7 +328,7 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
observation_features=observation_features,
|
observation_features=observation_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4- Calculate loss
|
# Calculate loss
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||||
@@ -358,14 +366,15 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
batch_size = next(iter(observations.values())).shape[0]
|
batch_size = next(iter(observations.values())).shape[0]
|
||||||
action_dim = self.config.output_features[ACTION].shape[0]
|
action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|
||||||
# 1. Sample noise w ~ N(0, I)
|
# Sample noise w ~ N(0, I)
|
||||||
noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
|
noise = torch.randn(batch_size, action_dim, device=get_device_from_parameters(self))
|
||||||
|
noise = noise.unsqueeze(1).repeat(1, self.chunk_size, 1)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# 2. Generate action using base policy: a = πW_dp(s, w)
|
# Generate action using base policy: a = πW_dp(s, w)
|
||||||
actions = self.action_policy(observations, noise)
|
actions_chunk = self.action_policy.predict_action_chunk(observations, noise=noise)
|
||||||
|
actions = actions_chunk[:, 0, :]
|
||||||
|
|
||||||
# 3. Get target Q-values from action critic: QA(s, a)
|
# Get target Q-values from action critic: QA(s, a)
|
||||||
q_targets = self.action_critic_forward(
|
q_targets = self.action_critic_forward(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
actions=actions,
|
actions=actions,
|
||||||
@@ -375,14 +384,14 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
# Average over ensemble critics
|
# Average over ensemble critics
|
||||||
q_targets = q_targets.mean(dim=0, keepdim=True) # (1, batch_size)
|
q_targets = q_targets.mean(dim=0, keepdim=True) # (1, batch_size)
|
||||||
|
|
||||||
# 4. Get predicted Q-values from noise critic: QW(s, w)
|
# Get predicted Q-values from noise critic: QW(s, w)
|
||||||
q_preds = self.noise_critic_forward(
|
q_preds = self.noise_critic_forward(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
observation_features=observation_features,
|
observation_features=observation_features,
|
||||||
) # (batch_size, 1)
|
) # (batch_size, 1)
|
||||||
|
|
||||||
# 5. Compute MSE loss
|
# Compute MSE loss
|
||||||
loss = F.mse_loss(q_preds.squeeze(-1), q_targets.squeeze(0))
|
loss = F.mse_loss(q_preds.squeeze(-1), q_targets.squeeze(0))
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
@@ -417,17 +426,17 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
Noise actor loss with entropy regularization
|
Noise actor loss with entropy regularization
|
||||||
"""
|
"""
|
||||||
# 1. Sample noise w ~ πW(s) from noise actor
|
# Sample noise w ~ πW(s) from noise actor
|
||||||
noise, log_probs, _ = self.noise_actor(observations, observation_features)
|
noise, log_probs, _ = self.noise_actor(observations, observation_features)
|
||||||
|
|
||||||
# 2. Evaluate QW(s, w) using noise critic
|
# Evaluate QW(s, w) using noise critic
|
||||||
q_values = self.noise_critic_forward(
|
q_values = self.noise_critic_forward(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
observation_features=observation_features,
|
observation_features=observation_features,
|
||||||
) # (batch_size, 1)
|
) # (batch_size, 1)
|
||||||
|
|
||||||
# 3. Compute loss: minimize (temperature * log_prob - Q_value)
|
# Compute loss: minimize (temperature * log_prob - Q_value)
|
||||||
# This is equivalent to maximizing (Q_value - temperature * log_prob)
|
# This is equivalent to maximizing (Q_value - temperature * log_prob)
|
||||||
noise_actor_loss = (self.temperature * log_probs - q_values.squeeze(-1)).mean()
|
noise_actor_loss = (self.temperature * log_probs - q_values.squeeze(-1)).mean()
|
||||||
|
|
||||||
@@ -437,7 +446,10 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
"""Initialize the action policy."""
|
"""Initialize the action policy."""
|
||||||
|
|
||||||
action_policy = get_policy_class(self.config.action_policy_name)
|
action_policy = get_policy_class(self.config.action_policy_name)
|
||||||
self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
|
if self.config.action_policy_weights is not None:
|
||||||
|
self.action_policy = action_policy.from_pretrained(self.config.action_policy_weights)
|
||||||
|
else:
|
||||||
|
self.action_policy = action_policy(self.config)
|
||||||
self.action_policy.to(self.config.device)
|
self.action_policy.to(self.config.device)
|
||||||
self.action_policy.eval()
|
self.action_policy.eval()
|
||||||
|
|
||||||
@@ -491,7 +503,7 @@ class DSRLPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def _init_noise_actor(self):
|
def _init_noise_actor(self):
|
||||||
"""Initialize noise actor network and default target entropy."""
|
"""Initialize noise actor network and default target entropy."""
|
||||||
self.noise_actor = Policy(
|
self.noise_actor = NoiseActorPolicy(
|
||||||
encoder=self.encoder_noise_actor,
|
encoder=self.encoder_noise_actor,
|
||||||
network=MLP(
|
network=MLP(
|
||||||
input_dim=self.encoder_noise_actor.output_dim,
|
input_dim=self.encoder_noise_actor.output_dim,
|
||||||
@@ -704,11 +716,9 @@ class MLP(nn.Module):
|
|||||||
total = len(hidden_dims)
|
total = len(hidden_dims)
|
||||||
|
|
||||||
for idx, out_dim in enumerate(hidden_dims):
|
for idx, out_dim in enumerate(hidden_dims):
|
||||||
# 1) linear transform
|
|
||||||
layers.append(nn.Linear(in_dim, out_dim))
|
layers.append(nn.Linear(in_dim, out_dim))
|
||||||
|
|
||||||
is_last = idx == total - 1
|
is_last = idx == total - 1
|
||||||
# 2-4) optionally add dropout, normalization, and activation
|
|
||||||
if not is_last or activate_final:
|
if not is_last or activate_final:
|
||||||
if dropout_rate and dropout_rate > 0:
|
if dropout_rate and dropout_rate > 0:
|
||||||
layers.append(nn.Dropout(p=dropout_rate))
|
layers.append(nn.Dropout(p=dropout_rate))
|
||||||
@@ -805,7 +815,7 @@ class CriticEnsemble(nn.Module):
|
|||||||
return q_values
|
return q_values
|
||||||
|
|
||||||
|
|
||||||
class Policy(nn.Module):
|
class NoiseActorPolicy(nn.Module):
|
||||||
"""Noise Actor (πW) that maps states to noise distributions.
|
"""Noise Actor (πW) that maps states to noise distributions.
|
||||||
|
|
||||||
This is the noise actor πW: S → △W that outputs noise vectors in the latent-noise space.
|
This is the noise actor πW: S → △W that outputs noise vectors in the latent-noise space.
|
||||||
|
|||||||
Reference in New Issue
Block a user