diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py index d367a01ce..1b4bdff04 100644 --- a/examples/tutorial/rl/hilserl_example.py +++ b/examples/tutorial/rl/hilserl_example.py @@ -4,7 +4,6 @@ from pathlib import Path from queue import Empty, Full import torch -import torch.optim as optim from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -12,6 +11,7 @@ from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.policies.sac.reward_model.modeling_classifier import Classifier +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.gym_manipulator import make_robot_env from lerobot.robots.so_follower import SO100FollowerConfig @@ -40,8 +40,9 @@ def run_learner( policy_learner.train() policy_learner.to(device) - # Create Adam optimizer from scratch - simple and clean - optimizer = optim.Adam(policy_learner.parameters(), lr=lr) + algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config) + algorithm = SACAlgorithm(policy=policy_learner, config=algo_config) + algorithm.make_optimizers_and_scheduler() print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}") print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}") @@ -83,24 +84,26 @@ def run_learner( else: batch[key] = online_batch[key] - loss, _ = policy_learner.forward(batch) + def batch_iter(b=batch): + while True: + yield b - optimizer.zero_grad() - loss.backward() - optimizer.step() + stats = algorithm.update(batch_iter()) training_step += 1 if training_step % LOG_EVERY == 0: + log_dict = stats.to_log_dict() print( - f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, " + f"[LEARNER] Training step {training_step}, " + f"critic_loss: {log_dict.get('loss_critic', float('nan')):.4f}, " f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}" ) # Send updated parameters to actor every 10 training steps if training_step % SEND_EVERY == 0: try: - state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()} - parameters_queue.put_nowait(state_dict) + weights = algorithm.get_weights() + parameters_queue.put_nowait(weights) print("[LEARNER] Sent updated parameters to actor") except Full: # Missed write due to the queue not being consumed (should happen rarely) @@ -144,15 +147,15 @@ def run_actor( while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set(): try: - new_params = parameters_queue.get_nowait() - policy_actor.load_state_dict(new_params) + new_weights = parameters_queue.get_nowait() + policy_actor.load_actor_weights(new_weights, device=device) print("[ACTOR] Updated policy parameters from learner") except Empty: # No new updated parameters available from learner, waiting pass - # Get action from policy + # Get action from policy (returns full action: continuous + discrete) policy_obs = make_policy_obs(obs, device=device) - action_tensor = policy_actor.select_action(policy_obs) # predicts a single action + action_tensor = policy_actor.select_action(policy_obs) action = action_tensor.squeeze(0).cpu().numpy() # Step environment diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 8b8aedb26..ccc2070cc 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -214,3 +214,11 @@ class
TrainRLServerPipelineConfig(TrainPipelineConfig): # NOTE: In RL, we don't need an offline dataset # TODO: Make `TrainPipelineConfig.dataset` optional dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional + + # Algorithm name registered in RLAlgorithmConfig registry + algorithm: str = "sac" + + # Data mixer strategy name. Currently supports "online_offline" + mixer: str = "online_offline" + # Fraction sampled from online replay when using OnlineOfflineMixer + online_ratio: float = 0.5 diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index d5dd71a48..a9659a5ba 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -15,16 +15,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from collections.abc import Callable from dataclasses import asdict -from typing import Literal +from typing import Any -import einops import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F # noqa: N812 from torch import Tensor from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution @@ -39,6 +36,8 @@ DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension class SACPolicy( PreTrainedPolicy, ): + """SAC policy.""" + config_class = SACConfig name = "sac" @@ -53,9 +52,8 @@ class SACPolicy( # Determine action dimension and initialize all components continuous_action_dim = config.output_features[ACTION].shape[0] self._init_encoders() - self._init_critics(continuous_action_dim) self._init_actor(continuous_action_dim) - self._init_temperature() + self.discrete_critic = None def get_optim_params(self) -> dict: optim_params = { @@ -64,11 +62,7 @@ class SACPolicy( for n, p in self.actor.named_parameters() if not n.startswith("encoder") or not self.shared_encoder ], - "critic": self.critic_ensemble.parameters(), - "temperature": self.log_alpha, } - if self.config.num_discrete_actions is not None: - optim_params["discrete_critic"] = self.discrete_critic.parameters() return optim_params def reset(self): @@ -91,304 +85,49 @@ class SACPolicy( actions, _, _ = self.actor(batch, observations_features) if self.config.num_discrete_actions is not None: - discrete_action_value = self.discrete_critic(batch, observations_features) - discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) + if self.discrete_critic is not None: + discrete_action_value = self.discrete_critic(batch, observations_features) + discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) + else: + discrete_action = torch.ones( + (*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype + ) actions = torch.cat([actions, discrete_action], dim=-1) return actions - def critic_forward( - self, - observations: dict[str, Tensor], - actions: Tensor, - use_target: bool = False, - observation_features: Tensor | None = None, - ) -> Tensor: - """Forward pass through a critic network ensemble + def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]: + """Actor forward pass: sample actions and return log-probabilities. 
Args: - observations: Dictionary of observations - actions: Action tensor - use_target: If True, use target critics, otherwise use ensemble critics + batch: A flat observation dict, or a training dict containing + ``"state"`` (observations) and optionally ``"observation_feature"`` + (pre-computed encoder features). Returns: - Tensor of Q-values from all critics + Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors. """ + observations = batch.get("state", batch) + observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None + actions, log_probs, means = self.actor(observations, observation_features) + return {"action": actions, "log_prob": log_probs, "action_mean": means} - critics = self.critic_target if use_target else self.critic_ensemble - q_values = critics(observations, actions, observation_features) - return q_values + def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None: + from lerobot.utils.transition import move_state_dict_to_device - def discrete_critic_forward( - self, observations, use_target=False, observation_features=None - ) -> torch.Tensor: - """Forward pass through a discrete critic network + actor_sd = move_state_dict_to_device(state_dicts["policy"], device=device) + self.actor.load_state_dict(actor_sd) - Args: - observations: Dictionary of observations - use_target: If True, use target critics, otherwise use ensemble critics - observation_features: Optional pre-computed observation features to avoid recomputing encoder output - - Returns: - Tensor of Q-values from the discrete critic network - """ - discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic - q_values = discrete_critic(observations, observation_features) - return q_values - - def forward( - self, - batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic", - ) -> dict[str, Tensor]: - """Compute the loss for the given model - - Args: - batch: Dictionary containing: - - action: Action tensor - - reward: Reward tensor - - state: Observations tensor dict - - next_state: Next observations tensor dict - - done: Done mask tensor - - observation_feature: Optional pre-computed observation features - - next_observation_feature: Optional pre-computed next observation features - model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature") - - Returns: - The computed loss tensor - """ - # Extract common components from batch - actions: Tensor = batch[ACTION] - observations: dict[str, Tensor] = batch["state"] - observation_features: Tensor = batch.get("observation_feature") - - if model == "critic": - # Extract critic-specific components - rewards: Tensor = batch["reward"] - next_observations: dict[str, Tensor] = batch["next_state"] - done: Tensor = batch["done"] - next_observation_features: Tensor = batch.get("next_observation_feature") - - loss_critic = self.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) - - return {"loss_critic": loss_critic} - - if model == "discrete_critic" and self.config.num_discrete_actions is not None: - # Extract critic-specific components - rewards: Tensor = batch["reward"] - next_observations: dict[str, Tensor] = batch["next_state"] - done: Tensor = batch["done"] - 
next_observation_features: Tensor = batch.get("next_observation_feature") - complementary_info = batch.get("complementary_info") - loss_discrete_critic = self.compute_loss_discrete_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - complementary_info=complementary_info, - ) - return {"loss_discrete_critic": loss_discrete_critic} - if model == "actor": - return { - "loss_actor": self.compute_loss_actor( - observations=observations, - observation_features=observation_features, - ) - } - - if model == "temperature": - return { - "loss_temperature": self.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - ) - } - - raise ValueError(f"Unknown model type: {model}") - - def update_target_networks(self): - """Update target networks with exponential moving average""" - for target_param, param in zip( - self.critic_target.parameters(), - self.critic_ensemble.parameters(), - strict=True, - ): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) - if self.config.num_discrete_actions is not None: - for target_param, param in zip( - self.discrete_critic_target.parameters(), - self.discrete_critic.parameters(), - strict=True, - ): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) - - @property - def temperature(self) -> float: - """Return the current temperature value, always in sync with log_alpha.""" - return self.log_alpha.exp().item() - - def compute_loss_critic( - self, - observations, - actions, - rewards, - next_observations, - done, - observation_features: Tensor | None = None, - next_observation_features: Tensor | None = None, - ) -> Tensor: - with torch.no_grad(): - next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) - - # 2- compute q targets - q_targets = self.critic_forward( - observations=next_observations, - actions=next_action_preds, - use_target=True, - observation_features=next_observation_features, - ) - - # subsample critics to prevent overfitting if use high UTD (update to date) - # TODO: Get indices before forward pass to avoid unnecessary computation - if self.config.num_subsample_critics is not None: - indices = torch.randperm(self.config.num_critics) - indices = indices[: self.config.num_subsample_critics] - q_targets = q_targets[indices] - - # critics subsample size - min_q, _ = q_targets.min(dim=0) # Get values from min operation - if self.config.use_backup_entropy: - min_q = min_q - (self.temperature * next_log_probs) - - td_target = rewards + (1 - done) * self.config.discount * min_q - - # 3- compute predicted qs - if self.config.num_discrete_actions is not None: - # NOTE: We only want to keep the continuous action part - # In the buffer we have the full action space (continuous + discrete) - # We need to split them before concatenating them in the critic forward - actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] - q_preds = self.critic_forward( - observations=observations, - actions=actions, - use_target=False, - observation_features=observation_features, - ) - - # 4- Calculate loss - # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. 
- td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) - # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up - critics_loss = ( - F.mse_loss( - input=q_preds, - target=td_target_duplicate, - reduction="none", - ).mean(dim=1) - ).sum() - return critics_loss - - def compute_loss_discrete_critic( - self, - observations, - actions, - rewards, - next_observations, - done, - observation_features=None, - next_observation_features=None, - complementary_info=None, - ): - # NOTE: We only want to keep the discrete action part - # In the buffer we have the full action space (continuous + discrete) - # We need to split them before concatenating them in the critic forward - actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() - actions_discrete = torch.round(actions_discrete) - actions_discrete = actions_discrete.long() - - discrete_penalties: Tensor | None = None - if complementary_info is not None: - discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty") - - with torch.no_grad(): - # For DQN, select actions using online network, evaluate with target network - next_discrete_qs = self.discrete_critic_forward( - next_observations, use_target=False, observation_features=next_observation_features - ) - best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) - - # Get target Q-values from target network - target_next_discrete_qs = self.discrete_critic_forward( - observations=next_observations, - use_target=True, - observation_features=next_observation_features, - ) - - # Use gather to select Q-values for best actions - target_next_discrete_q = torch.gather( - target_next_discrete_qs, dim=1, index=best_next_discrete_action - ).squeeze(-1) - - # Compute target Q-value with Bellman equation - rewards_discrete = rewards - if discrete_penalties is not None: - rewards_discrete = rewards + discrete_penalties - target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q - - # Get predicted Q-values for current observations - predicted_discrete_qs = self.discrete_critic_forward( - observations=observations, use_target=False, observation_features=observation_features - ) - - # Use gather to select Q-values for taken actions - predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1) - - # Compute MSE loss between predicted and target Q-values - discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q) - return discrete_critic_loss - - def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: - """Compute the temperature loss""" - # calculate temperature loss - with torch.no_grad(): - _, log_probs, _ = self.actor(observations, observation_features) - temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() - return temperature_loss - - def compute_loss_actor( - self, - observations, - observation_features: Tensor | None = None, - ) -> Tensor: - actions_pi, log_probs, _ = self.actor(observations, observation_features) - - q_preds = self.critic_forward( - observations=observations, - actions=actions_pi, - use_target=False, - observation_features=observation_features, - ) - min_q_preds = q_preds.min(dim=0)[0] - - actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() - return actor_loss + if "discrete_critic" in state_dicts: + dc_sd = 
move_state_dict_to_device(state_dicts["discrete_critic"], device=device) + if self.discrete_critic is None: + self.discrete_critic = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ).to(device) + self.discrete_critic.load_state_dict(dc_sd) def _init_encoders(self): """Initialize shared or separate encoders for actor and critic.""" @@ -398,51 +137,6 @@ class SACPolicy( self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config) ) - def _init_critics(self, continuous_action_dim): - """Build critic ensemble, targets, and optional discrete critic.""" - heads = [ - CriticHead( - input_dim=self.encoder_critic.output_dim + continuous_action_dim, - **asdict(self.config.critic_network_kwargs), - ) - for _ in range(self.config.num_critics) - ] - self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads) - target_heads = [ - CriticHead( - input_dim=self.encoder_critic.output_dim + continuous_action_dim, - **asdict(self.config.critic_network_kwargs), - ) - for _ in range(self.config.num_critics) - ] - self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads) - self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) - - if self.config.use_torch_compile: - self.critic_ensemble = torch.compile(self.critic_ensemble) - self.critic_target = torch.compile(self.critic_target) - - if self.config.num_discrete_actions is not None: - self._init_discrete_critics() - - def _init_discrete_critics(self): - """Build discrete discrete critic ensemble and target networks.""" - self.discrete_critic = DiscreteCritic( - encoder=self.encoder_critic, - input_dim=self.encoder_critic.output_dim, - output_dim=self.config.num_discrete_actions, - **asdict(self.config.discrete_critic_network_kwargs), - ) - self.discrete_critic_target = DiscreteCritic( - encoder=self.encoder_critic, - input_dim=self.encoder_critic.output_dim, - output_dim=self.config.num_discrete_actions, - **asdict(self.config.discrete_critic_network_kwargs), - ) - - # TODO: (maractingi, azouitine) Compile the discrete critic - self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict()) - def _init_actor(self, continuous_action_dim): """Initialize policy actor network and default target entropy.""" # NOTE: The actor select only the continuous action part @@ -459,11 +153,6 @@ class SACPolicy( dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) self.target_entropy = -np.prod(dim) / 2 - def _init_temperature(self) -> None: - """Set up temperature parameter (log_alpha).""" - temp_init = self.config.temperature_init - self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) - class SACObservationEncoder(nn.Module): """Encode image and/or state vector observations.""" @@ -676,84 +365,6 @@ class MLP(nn.Module): return self.net(x) -class CriticHead(nn.Module): - def __init__( - self, - input_dim: int, - hidden_dims: list[int], - activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), - activate_final: bool = False, - dropout_rate: float | None = None, - init_final: float | None = None, - final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, - ): - super().__init__() - self.net = MLP( - input_dim=input_dim, - hidden_dims=hidden_dims, - activations=activations, - activate_final=activate_final, - dropout_rate=dropout_rate, - 
final_activation=final_activation, - ) - self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) - if init_final is not None: - nn.init.uniform_(self.output_layer.weight, -init_final, init_final) - nn.init.uniform_(self.output_layer.bias, -init_final, init_final) - else: - orthogonal_init()(self.output_layer.weight) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.output_layer(self.net(x)) - - -class CriticEnsemble(nn.Module): - """ - CriticEnsemble wraps multiple CriticHead modules into an ensemble. - - Args: - encoder (SACObservationEncoder): encoder for observations. - ensemble (List[CriticHead]): list of critic heads. - init_final (float | None): optional initializer scale for final layers. - - Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. - """ - - def __init__( - self, - encoder: SACObservationEncoder, - ensemble: list[CriticHead], - init_final: float | None = None, - ): - super().__init__() - self.encoder = encoder - self.init_final = init_final - self.critics = nn.ModuleList(ensemble) - - def forward( - self, - observations: dict[str, torch.Tensor], - actions: torch.Tensor, - observation_features: torch.Tensor | None = None, - ) -> torch.Tensor: - device = get_device_from_parameters(self) - # Move each tensor in observations to device - observations = {k: v.to(device) for k, v in observations.items()} - - obs_enc = self.encoder(observations, cache=observation_features) - - inputs = torch.cat([obs_enc, actions], dim=-1) - - # Loop through critics and collect outputs - q_values = [] - for critic in self.critics: - q_values.append(critic(inputs)) - - # Stack outputs to match expected shape [num_critics, batch_size] - q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0) - return q_values - - class DiscreteCritic(nn.Module): def __init__( self, diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py index 879e3c1af..9b76b8037 100644 --- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py +++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py @@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig): latent_dim: int = 256 image_embedding_pooling_dim: int = 8 dropout_rate: float = 0.1 - model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading. 
+ model_name: str = "helper2424/resnet10" device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" num_cameras: int = 2 diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 8a7a1176a..7ded1b568 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -131,6 +131,15 @@ class _NormalizationMixin: if self.dtype is None: self.dtype = torch.float32 self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + self._reshape_visual_stats() + + def _reshape_visual_stats(self) -> None: + """Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting.""" + for key, feature in self.features.items(): + if feature.type == FeatureType.VISUAL and key in self._tensor_stats: + for stat_name, stat_tensor in self._tensor_stats[key].items(): + if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1: + self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1) def to( self, device: torch.device | str | None = None, dtype: torch.dtype | None = None @@ -149,6 +158,7 @@ class _NormalizationMixin: if dtype is not None: self.dtype = dtype self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + self._reshape_visual_stats() return self def state_dict(self) -> dict[str, Tensor]: @@ -198,6 +208,7 @@ class _NormalizationMixin: # Don't load from state_dict, keep the explicitly provided stats # But ensure _tensor_stats is properly initialized self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment] + self._reshape_visual_stats() return # Normal behavior: load stats from state_dict @@ -208,6 +219,7 @@ class _NormalizationMixin: self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( dtype=torch.float32, device=self.device ) + self._reshape_visual_stats() # Reconstruct the original stats dict from tensor stats for compatibility with to() method # and other functions that rely on self.stats diff --git a/src/lerobot/rl/__init__.py b/src/lerobot/rl/__init__.py new file mode 100644 index 000000000..19b2f1409 --- /dev/null +++ b/src/lerobot/rl/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
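A minimal sketch, assuming a placeholder 3x224x224 image and ImageNet-style mean/std values (none of which come from this patch), of the broadcasting issue the new _reshape_visual_stats helper in normalize_processor.py addresses: per-channel stats stored with shape [C] align against the trailing (width) axis of a [C, H, W] image, so normalization either fails or hits the wrong dimension, whereas [C, 1, 1] broadcasts per channel.

import torch

image = torch.rand(3, 224, 224)  # [C, H, W]
mean = torch.tensor([0.485, 0.456, 0.406])  # per-channel stats, shape [C]
std = torch.tensor([0.229, 0.224, 0.225])

# Shape [C] broadcasts against the last axis (W), not the channel axis.
try:
    _ = (image - mean) / std
except RuntimeError as err:
    print(f"[C] stats do not broadcast over [C, H, W]: {err}")

# Reshaped to [C, 1, 1] (what _reshape_visual_stats does), each channel is
# normalized independently over H and W.
mean, std = mean.reshape(-1, 1, 1), std.reshape(-1, 1, 1)
normalized = (image - mean) / std
print(normalized.shape)  # torch.Size([3, 224, 224])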
diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 18c0ca1ea..6d73891b6 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -60,8 +60,9 @@ from torch.multiprocessing import Event, Queue from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.policies.factory import make_policy -from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.policies.factory import make_policy, make_pre_post_processors +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import TransitionKey from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.queue import get_last_item_from_queue from lerobot.robots import so_follower # noqa: F401 @@ -76,13 +77,11 @@ from lerobot.transport.utils import ( send_bytes_in_chunks, transitions_to_bytes, ) -from lerobot.types import TransitionKey from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( Transition, - move_state_dict_to_device, move_transition_to_device, ) from lerobot.utils.utils import ( @@ -251,13 +250,18 @@ def act_with_policy( ### Instantiate the policy in both the actor and learner processes ### To avoid sending a SACPolicy object through the port, we create a policy instance ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters - policy: SACPolicy = make_policy( + policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, ) - policy = policy.eval() + policy = policy.to(device).eval() assert isinstance(policy, nn.Module) + preprocessor, _postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + dataset_stats=cfg.policy.dataset_stats, + ) + obs, info = online_env.reset() env_processor.reset() action_processor.reset() @@ -288,8 +292,8 @@ def act_with_policy( # Time policy inference and check if it meets FPS requirement with policy_timer: - # Extract observation from transition for policy - action = policy.select_action(batch=observation) + normalized_observation = preprocessor.process_observation(observation) + action = policy.select_action(batch=normalized_observation) policy_fps = policy_timer.fps_last log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) @@ -649,7 +653,7 @@ def interactions_stream( # Policy functions -def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): +def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device): bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) if bytes_state_dict is not None: logging.info("[ACTOR] Load new parameters from Learner.") @@ -664,18 +668,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) # - Send critic's encoder state when shared_encoder=True # - Skip encoder params entirely when freeze_vision_encoder=True # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) - - # Load actor state dict - actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device) - policy.actor.load_state_dict(actor_state_dict) - - # Load discrete critic if present - if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts: - discrete_critic_state_dict = move_state_dict_to_device( - state_dicts["discrete_critic"], device=device - ) 
- policy.discrete_critic.load_state_dict(discrete_critic_state_dict) - logging.info("[ACTOR] Loaded discrete critic parameters from Learner.") + policy.load_actor_weights(state_dicts, device=device) # Utilities functions diff --git a/src/lerobot/rl/algorithms/__init__.py b/src/lerobot/rl/algorithms/__init__.py new file mode 100644 index 000000000..fe4a51846 --- /dev/null +++ b/src/lerobot/rl/algorithms/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.rl.algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig + +__all__ = [ + "SACAlgorithm", + "SACAlgorithmConfig", +] diff --git a/src/lerobot/rl/algorithms/base.py b/src/lerobot/rl/algorithms/base.py new file mode 100644 index 000000000..b9f2c908c --- /dev/null +++ b/src/lerobot/rl/algorithms/base.py @@ -0,0 +1,106 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any + +import torch +from torch.optim import Optimizer + +from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats + +if TYPE_CHECKING: + from lerobot.rl.data_sources.data_mixer import DataMixer + +BatchType = dict[str, Any] + + +class RLAlgorithm(abc.ABC): + """Base for all RL algorithms.""" + + config_class: type[RLAlgorithmConfig] | None = None + name: str | None = None + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not getattr(cls, "config_class", None): + raise TypeError(f"Class {cls.__name__} must define 'config_class'") + if not getattr(cls, "name", None): + raise TypeError(f"Class {cls.__name__} must define 'name'") + + @abc.abstractmethod + def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + """One complete training step. + + The algorithm calls ``next(batch_iterator)`` as many times as it + needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches. + The iterator is owned by the trainer; the algorithm just consumes + from it. + """ + ... + + def configure_data_iterator( + self, + data_mixer: DataMixer, + batch_size: int, + *, + async_prefetch: bool = True, + queue_size: int = 2, + ) -> Iterator[BatchType]: + """Create the data iterator this algorithm needs. + + The default implementation uses the standard ``data_mixer.get_iterator()``. 
+ Algorithms that need specialised sampling should override this method. + """ + return data_mixer.get_iterator( + batch_size=batch_size, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]: + """Create, store, and return the optimizers needed for training. + + Called on the **learner** side after construction. Subclasses must + override this with algorithm-specific optimizer setup. + """ + return {} + + def get_optimizers(self) -> dict[str, Optimizer]: + """Return optimizers for checkpointing / external scheduling.""" + return {} + + @property + def optimization_step(self) -> int: + """Current learner optimization step. + + Part of the stable contract for checkpoint/resume. Algorithms can + either use this default storage or override for custom behavior. + """ + return getattr(self, "_optimization_step", 0) + + @optimization_step.setter + def optimization_step(self, value: int) -> None: + self._optimization_step = int(value) + + def get_weights(self) -> dict[str, Any]: + """Policy state-dict to push to actors.""" + return {} + + @abc.abstractmethod + def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: + """Load policy state-dict received from the learner.""" diff --git a/src/lerobot/rl/algorithms/configs.py b/src/lerobot/rl/algorithms/configs.py new file mode 100644 index 000000000..421f7ae09 --- /dev/null +++ b/src/lerobot/rl/algorithms/configs.py @@ -0,0 +1,65 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import draccus +import torch + +if TYPE_CHECKING: + from lerobot.rl.algorithms.base import RLAlgorithm + + +@dataclass +class TrainingStats: + """Returned by ``algorithm.update()`` for logging and checkpointing.""" + + losses: dict[str, float] = field(default_factory=dict) + grad_norms: dict[str, float] = field(default_factory=dict) + extra: dict[str, float] = field(default_factory=dict) + + def to_log_dict(self) -> dict[str, float]: + """Flatten all stats into a single dict for logging.""" + + d: dict[str, float] = {} + for name, val in self.losses.items(): + d[name] = val + for name, val in self.grad_norms.items(): + d[f"{name}_grad_norm"] = val + for name, val in self.extra.items(): + d[name] = val + return d + + +@dataclass +class RLAlgorithmConfig(draccus.ChoiceRegistry): + """Registry for algorithm configs.""" + + def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm: + """Construct the :class:`RLAlgorithm` for this config. + + Must be overridden by every registered config subclass. + """ + raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()") + + @classmethod + def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig: + """Build an algorithm config from a policy config. 
+ + Must be overridden by every registered config subclass. + """ + raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()") diff --git a/src/lerobot/rl/algorithms/factory.py b/src/lerobot/rl/algorithms/factory.py new file mode 100644 index 000000000..70622c5f2 --- /dev/null +++ b/src/lerobot/rl/algorithms/factory.py @@ -0,0 +1,35 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from lerobot.rl.algorithms.base import RLAlgorithm +from lerobot.rl.algorithms.configs import RLAlgorithmConfig + + +def make_algorithm( + policy: torch.nn.Module, + policy_cfg, + *, + algorithm_name: str, +) -> RLAlgorithm: + known = RLAlgorithmConfig.get_known_choices() + if algorithm_name not in known: + raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}") + + config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name) + algo_config = config_cls.from_policy_config(policy_cfg) + return algo_config.build_algorithm(policy) diff --git a/src/lerobot/rl/algorithms/sac/__init__.py b/src/lerobot/rl/algorithms/sac/__init__.py new file mode 100644 index 000000000..6e0e76a66 --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig +from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm + +__all__ = ["SACAlgorithm", "SACAlgorithmConfig"] diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py new file mode 100644 index 000000000..c2aac050b --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py @@ -0,0 +1,80 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch + +from lerobot.policies.sac.configuration_sac import CriticNetworkConfig +from lerobot.rl.algorithms.configs import RLAlgorithmConfig + +if TYPE_CHECKING: + from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm + + +@RLAlgorithmConfig.register_subclass("sac") +@dataclass +class SACAlgorithmConfig(RLAlgorithmConfig): + """SAC-specific hyper-parameters that control the update loop.""" + + utd_ratio: int = 1 + policy_update_freq: int = 1 + clip_grad_norm: float = 40.0 + actor_lr: float = 3e-4 + critic_lr: float = 3e-4 + temperature_lr: float = 3e-4 + discount: float = 0.99 + temperature_init: float = 1.0 + target_entropy: float | None = None + use_backup_entropy: bool = True + critic_target_update_weight: float = 0.005 + num_critics: int = 2 + num_subsample_critics: int | None = None + num_discrete_actions: int | None = None + shared_encoder: bool = True + critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + use_torch_compile: bool = True + + @classmethod + def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig: + """Build from an existing ``SACConfig`` (cfg.policy) for backwards compat.""" + return cls( + utd_ratio=policy_cfg.utd_ratio, + policy_update_freq=policy_cfg.policy_update_freq, + clip_grad_norm=policy_cfg.grad_clip_norm, + actor_lr=policy_cfg.actor_lr, + critic_lr=policy_cfg.critic_lr, + temperature_lr=policy_cfg.temperature_lr, + discount=policy_cfg.discount, + temperature_init=policy_cfg.temperature_init, + target_entropy=policy_cfg.target_entropy, + use_backup_entropy=policy_cfg.use_backup_entropy, + critic_target_update_weight=policy_cfg.critic_target_update_weight, + num_critics=policy_cfg.num_critics, + num_subsample_critics=policy_cfg.num_subsample_critics, + num_discrete_actions=policy_cfg.num_discrete_actions, + shared_encoder=policy_cfg.shared_encoder, + critic_network_kwargs=policy_cfg.critic_network_kwargs, + discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs, + use_torch_compile=policy_cfg.use_torch_compile, + ) + + def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: + from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm + + return SACAlgorithm(policy=policy, config=self) diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py new file mode 100644 index 000000000..0ce7a6875 --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -0,0 +1,600 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + from __future__ import annotations + + import math + from collections.abc import Callable, Iterator + from dataclasses import asdict + from typing import Any + + import einops + import torch + import torch.nn as nn + import torch.nn.functional as F # noqa: N812 + from torch import Tensor + from torch.optim import Optimizer + + from lerobot.policies.sac.modeling_sac import ( + DISCRETE_DIMENSION_INDEX, + MLP, + DiscreteCritic, + SACObservationEncoder, + SACPolicy, + orthogonal_init, + ) + from lerobot.policies.utils import get_device_from_parameters + from lerobot.rl.algorithms.base import BatchType, RLAlgorithm + from lerobot.rl.algorithms.configs import TrainingStats + from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig + from lerobot.utils.constants import ACTION + from lerobot.utils.transition import move_state_dict_to_device + + + class SACAlgorithm(RLAlgorithm): + """Soft Actor-Critic. Owns critics, targets, temperature, and loss computation.""" + + config_class = SACAlgorithmConfig + name = "sac" + + def __init__( + self, + policy: SACPolicy, + config: SACAlgorithmConfig, + ): + self.policy = policy + self.config = config + self.optimizers: dict[str, Optimizer] = {} + self._optimization_step: int = 0 + + self._init_critics() + self._init_temperature() + + self._device = torch.device(self.policy.config.device) + self._move_to_device() + + def _init_critics(self) -> None: + """Build the critic ensemble and its target network.""" + encoder = self.policy.encoder_critic + action_dim = self.policy.config.output_features[ACTION].shape[0] + + heads = [ + CriticHead( + input_dim=encoder.output_dim + action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads) + target_heads = [ + CriticHead( + input_dim=encoder.output_dim + action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads) + self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + + if self.config.use_torch_compile: + self.critic_ensemble = torch.compile(self.critic_ensemble) + self.critic_target = torch.compile(self.critic_target) + + self.discrete_critic = None + self.discrete_critic_target = None + if self.config.num_discrete_actions is not None: + self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder) + self.policy.discrete_critic = self.discrete_critic + + def _init_discrete_critics(self, encoder: SACObservationEncoder) -> tuple[DiscreteCritic, DiscreteCritic]: + """Build the discrete critic and its target network.""" + discrete_critic = DiscreteCritic( + encoder=encoder, + input_dim=encoder.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + discrete_critic_target = DiscreteCritic( + encoder=encoder, + input_dim=encoder.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + + # TODO: (maractingi, azouitine) Compile the discrete critic + discrete_critic_target.load_state_dict(discrete_critic.state_dict()) + return discrete_critic, discrete_critic_target + + def _init_temperature(self) -> None: + """Set up temperature parameter (log_alpha).""" + temp_init = self.config.temperature_init + self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + + def
_move_to_device(self) -> None: + self.policy.to(self._device) + self.critic_ensemble.to(self._device) + self.critic_target.to(self._device) + self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device)) + if self.discrete_critic is not None: + self.discrete_critic.to(self._device) + self.discrete_critic_target.to(self._device) + + @property + def temperature(self) -> float: + """Return the current temperature value, always in sync with log_alpha.""" + return self.log_alpha.exp().item() + + def _critic_forward( + self, + observations: dict[str, Tensor], + actions: Tensor, + use_target: bool = False, + observation_features: Tensor | None = None, + ) -> Tensor: + """Forward pass through a critic network ensemble + + Args: + observations: Dictionary of observations + actions: Action tensor + use_target: If True, use target critics, otherwise use ensemble critics + + Returns: + Tensor of Q-values from all critics + """ + + critics = self.critic_target if use_target else self.critic_ensemble + q_values = critics(observations, actions, observation_features) + return q_values + + def _discrete_critic_forward( + self, observations, use_target=False, observation_features=None + ) -> torch.Tensor: + """Forward pass through a discrete critic network + + Args: + observations: Dictionary of observations + use_target: If True, use target critics, otherwise use ensemble critics + observation_features: Optional pre-computed observation features to avoid recomputing encoder output + + Returns: + Tensor of Q-values from the discrete critic network + """ + discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic + q_values = discrete_critic(observations, observation_features) + return q_values + + def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + clip = self.config.clip_grad_norm + + for _ in range(self.config.utd_ratio - 1): + batch = next(batch_iterator) + fb = self._prepare_forward_batch(batch, include_complementary_info=True) + + loss_critic = self._compute_loss_critic(fb) + self.optimizers["critic"].zero_grad() + loss_critic.backward() + torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip) + self.optimizers["critic"].step() + + if self.config.num_discrete_actions is not None: + loss_dc = self._compute_loss_discrete_critic(fb) + self.optimizers["discrete_critic"].zero_grad() + loss_dc.backward() + torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip) + self.optimizers["discrete_critic"].step() + + self._update_target_networks() + + batch = next(batch_iterator) + fb = self._prepare_forward_batch(batch, include_complementary_info=False) + + loss_critic = self._compute_loss_critic(fb) + self.optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad = torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip).item() + self.optimizers["critic"].step() + + stats = TrainingStats( + losses={"loss_critic": loss_critic.item()}, + grad_norms={"critic": critic_grad}, + ) + + if self.config.num_discrete_actions is not None: + loss_dc = self._compute_loss_discrete_critic(fb) + self.optimizers["discrete_critic"].zero_grad() + loss_dc.backward() + dc_grad = torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip).item() + self.optimizers["discrete_critic"].step() + stats.losses["loss_discrete_critic"] = loss_dc.item() + stats.grad_norms["discrete_critic"] = dc_grad + + if self._optimization_step % self.config.policy_update_freq == 0: + for _ in 
range(self.config.policy_update_freq): + loss_actor = self._compute_loss_actor(fb) + self.optimizers["actor"].zero_grad() + loss_actor.backward() + actor_grad = torch.nn.utils.clip_grad_norm_( + self.policy.actor.parameters(), max_norm=clip + ).item() + self.optimizers["actor"].step() + + loss_temp = self._compute_loss_temperature(fb) + self.optimizers["temperature"].zero_grad() + loss_temp.backward() + temp_grad = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=clip).item() + self.optimizers["temperature"].step() + + stats.losses["loss_actor"] = loss_actor.item() + stats.losses["loss_temperature"] = loss_temp.item() + stats.grad_norms["actor"] = actor_grad + stats.grad_norms["temperature"] = temp_grad + stats.extra["temperature"] = self.temperature + + self._update_target_networks() + self._optimization_step += 1 + return stats + + def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + actions = batch[ACTION] + rewards = batch["reward"] + next_observations = batch["next_state"] + done = batch["done"] + observation_features = batch.get("observation_feature") + next_observation_features = batch.get("next_observation_feature") + + with torch.no_grad(): + next_action_preds, next_log_probs, _ = self.policy.actor( + next_observations, next_observation_features + ) + + # 2- compute q targets + q_targets = self._critic_forward( + observations=next_observations, + actions=next_action_preds, + use_target=True, + observation_features=next_observation_features, + ) + + # subsample critics to prevent overfitting if use high UTD (update to date) + # TODO: Get indices before forward pass to avoid unnecessary computation + if self.config.num_subsample_critics is not None: + indices = torch.randperm(self.config.num_critics) + indices = indices[: self.config.num_subsample_critics] + q_targets = q_targets[indices] + + # critics subsample size + min_q, _ = q_targets.min(dim=0) # Get values from min operation + if self.config.use_backup_entropy: + min_q = min_q - (self.temperature * next_log_probs) + + td_target = rewards + (1 - done) * self.config.discount * min_q + + # 3- compute predicted qs + if self.config.num_discrete_actions is not None: + # NOTE: We only want to keep the continuous action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] + q_preds = self._critic_forward( + observations=observations, + actions=actions, + use_target=False, + observation_features=observation_features, + ) + + # 4- Calculate loss + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. 
+ td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) + # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up + critics_loss = ( + F.mse_loss( + input=q_preds, + target=td_target_duplicate, + reduction="none", + ).mean(dim=1) + ).sum() + return critics_loss + + def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + actions = batch[ACTION] + rewards = batch["reward"] + next_observations = batch["next_state"] + done = batch["done"] + observation_features = batch.get("observation_feature") + next_observation_features = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") + + # NOTE: We only want to keep the discrete action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) + actions_discrete = actions_discrete.long() + + discrete_penalties: Tensor | None = None + if complementary_info is not None: + discrete_penalties = complementary_info.get("discrete_penalty") + + with torch.no_grad(): + # For DQN, select actions using online network, evaluate with target network + next_discrete_qs = self._discrete_critic_forward( + next_observations, use_target=False, observation_features=next_observation_features + ) + best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) + + # Get target Q-values from target network + target_next_discrete_qs = self._discrete_critic_forward( + observations=next_observations, + use_target=True, + observation_features=next_observation_features, + ) + + # Use gather to select Q-values for best actions + target_next_discrete_q = torch.gather( + target_next_discrete_qs, dim=1, index=best_next_discrete_action + ).squeeze(-1) + + # Compute target Q-value with Bellman equation + rewards_discrete = rewards + if discrete_penalties is not None: + rewards_discrete = rewards + discrete_penalties + target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q + + # Get predicted Q-values for current observations + predicted_discrete_qs = self._discrete_critic_forward( + observations=observations, use_target=False, observation_features=observation_features + ) + + # Use gather to select Q-values for taken actions + predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1) + + # Compute MSE loss between predicted and target Q-values + discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q) + return discrete_critic_loss + + def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + observation_features = batch.get("observation_feature") + + actions_pi, log_probs, _ = self.policy.actor(observations, observation_features) + + q_preds = self._critic_forward( + observations=observations, + actions=actions_pi, + use_target=False, + observation_features=observation_features, + ) + min_q_preds = q_preds.min(dim=0)[0] + + actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() + return actor_loss + + def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor: + """Compute the temperature loss""" + observations = batch["state"] + observation_features = batch.get("observation_feature") + + # calculate 
temperature loss
+        with torch.no_grad():
+            _, log_probs, _ = self.policy.actor(observations, observation_features)
+
+        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
+        return temperature_loss
+
+    def _update_target_networks(self) -> None:
+        """Update target networks with exponential moving average"""
+        for target_p, p in zip(
+            self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
+        ):
+            target_p.data.copy_(
+                p.data * self.config.critic_target_update_weight
+                + target_p.data * (1.0 - self.config.critic_target_update_weight)
+            )
+        if self.config.num_discrete_actions is not None:
+            for target_p, p in zip(
+                self.discrete_critic_target.parameters(),
+                self.discrete_critic.parameters(),
+                strict=True,
+            ):
+                target_p.data.copy_(
+                    p.data * self.config.critic_target_update_weight
+                    + target_p.data * (1.0 - self.config.critic_target_update_weight)
+                )
+
+    def _prepare_forward_batch(
+        self, batch: BatchType, *, include_complementary_info: bool = True
+    ) -> dict[str, Any]:
+        observations = batch["state"]
+        next_observations = batch["next_state"]
+        observation_features, next_observation_features = self.get_observation_features(
+            observations, next_observations
+        )
+        forward_batch: dict[str, Any] = {
+            ACTION: batch[ACTION],
+            "reward": batch["reward"],
+            "state": observations,
+            "next_state": next_observations,
+            "done": batch["done"],
+            "observation_feature": observation_features,
+            "next_observation_feature": next_observation_features,
+        }
+        if include_complementary_info and "complementary_info" in batch:
+            forward_batch["complementary_info"] = batch["complementary_info"]
+        return forward_batch
+
+    def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
+        """
+        Create and return Adam optimizers for the actor, critic, and temperature components.
+
+        This method sets up Adam optimizers for:
+        - The **actor network**, ensuring that only relevant parameters are optimized.
+        - The **critic ensemble**, which evaluates the value function.
+        - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
+        - The **discrete critic**, when `num_discrete_actions` is configured.
+
+        No learning-rate scheduler is created yet; the method name is kept so scheduling can be added later.
+
+        NOTE:
+        - If the encoder is shared, its parameters are excluded from the actor's optimization process.
+        - The log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
+
+        Returns:
+            dict[str, torch.optim.Optimizer]: A mapping from component names ("actor", "critic",
+            "temperature", and optionally "discrete_critic") to their Adam optimizers. The same
+            dictionary is also stored on `self.optimizers`.
+
+        """
+        actor_params = self.policy.get_optim_params()["actor"]
+        self.optimizers = {
+            "actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
+            "critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
+            "temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
+        }
+        if self.config.num_discrete_actions is not None:
+            self.optimizers["discrete_critic"] = torch.optim.Adam(
+                self.discrete_critic.parameters(), lr=self.config.critic_lr
+            )
+        return self.optimizers
+
+    def get_optimizers(self) -> dict[str, Optimizer]:
+        return self.optimizers
+
+    def get_weights(self) -> dict[str, Any]:
+        """Return the actor (and, if present, discrete-critic) state dicts on CPU for transfer to actors."""
+        state_dicts: dict[str, Any] = {
+            "policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
+        }
+        if self.config.num_discrete_actions is not None:
+            state_dicts["discrete_critic"] = move_state_dict_to_device(
+                self.discrete_critic.state_dict(), device="cpu"
+            )
+        return state_dicts
+
+    def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
+        actor_sd = move_state_dict_to_device(weights["policy"], device=device)
+        self.policy.actor.load_state_dict(actor_sd)
+        if "discrete_critic" in weights and self.config.num_discrete_actions is not None:
+            dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
+            self.discrete_critic.load_state_dict(dc_sd)
+
+    def get_observation_features(
+        self, observations: Tensor, next_observations: Tensor
+    ) -> tuple[Tensor | None, Tensor | None]:
+        """
+        Get observation features from the policy's vision encoder. This acts as a cache:
+        when the encoder is frozen, the features never change, so we save compute by
+        computing them once per batch here instead of inside every loss function.
+
+        Args:
+            observations: The current observations
+            next_observations: The next observations
+
+        Returns:
+            tuple: observation_features, next_observation_features
+        """
+
+        if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
+            return None, None
+
+        with torch.no_grad():
+            observation_features = self.policy.actor.encoder.get_cached_image_features(observations)
+            next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations)
+
+        return observation_features, next_observation_features
+
+
+class CriticHead(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dims: list[int],
+        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
+        activate_final: bool = False,
+        dropout_rate: float | None = None,
+        init_final: float | None = None,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
+    ):
+        super().__init__()
+        self.net = MLP(
+            input_dim=input_dim,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            activate_final=activate_final,
+            dropout_rate=dropout_rate,
+            final_activation=final_activation,
+        )
+        self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
+        if init_final is not None:
+            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
+            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
+        else:
+            orthogonal_init()(self.output_layer.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.output_layer(self.net(x))
+
+
+class CriticEnsemble(nn.Module):
+    """
+    CriticEnsemble wraps multiple CriticHead modules into an ensemble.
+ + Args: + encoder (SACObservationEncoder): encoder for observations. + ensemble (List[CriticHead]): list of critic heads. + init_final (float | None): optional initializer scale for final layers. + + Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. + """ + + def __init__( + self, + encoder: SACObservationEncoder, + ensemble: list[CriticHead], + init_final: float | None = None, + ): + super().__init__() + self.encoder = encoder + self.init_final = init_final + self.critics = nn.ModuleList(ensemble) + + def forward( + self, + observations: dict[str, torch.Tensor], + actions: torch.Tensor, + observation_features: torch.Tensor | None = None, + ) -> torch.Tensor: + device = get_device_from_parameters(self) + # Move each tensor in observations to device + observations = {k: v.to(device) for k, v in observations.items()} + + obs_enc = self.encoder(observations, cache=observation_features) + + inputs = torch.cat([obs_enc, actions], dim=-1) + + # Loop through critics and collect outputs + q_values = [] + for critic in self.critics: + q_values.append(critic(inputs)) + + # Stack outputs to match expected shape [num_critics, batch_size] + q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0) + return q_values diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 81aa29c48..48ea2f3ef 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -96,8 +96,8 @@ class ReplayBuffer: Args: capacity (int): Maximum number of transitions to store in the buffer. device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu"). - state_keys (List[str]): The list of keys that appear in `state` and `next_state`. - image_augmentation_function (Optional[Callable]): A function that takes a batch of images + state_keys (list[str]): The list of keys that appear in `state` and `next_state`. + image_augmentation_function (Callable | None): A function that takes a batch of images and returns a batch of augmented images. If None, a default augmentation function is used. use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer. storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored. @@ -638,7 +638,7 @@ class ReplayBuffer: If None, you must handle or define default keys. Returns: - transitions (List[Transition]): + transitions (list[Transition]): A list of Transition dictionaries with the same length as `dataset`. """ if state_keys is None: diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py index 4345fed3c..54b4a6f07 100644 --- a/src/lerobot/rl/crop_dataset_roi.py +++ b/src/lerobot/rl/crop_dataset_roi.py @@ -176,11 +176,11 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset( Args: original_dataset (LeRobotDataset): The source dataset. - crop_params_dict (Dict[str, Tuple[int, int, int, int]]): + crop_params_dict (dict[str, Tuple[int, int, int, int]]): A dictionary mapping observation keys to crop parameters (top, left, height, width). new_repo_id (str): Repository id for the new dataset. new_dataset_root (str): The root directory where the new dataset will be written. - resize_size (Tuple[int, int], optional): The target size (height, width) after cropping. + resize_size (tuple[int, int], optional): The target size (height, width) after cropping. Defaults to (128, 128). 
Returns:
diff --git a/src/lerobot/rl/data_sources/__init__.py b/src/lerobot/rl/data_sources/__init__.py
new file mode 100644
index 000000000..4ac97ec1b
--- /dev/null
+++ b/src/lerobot/rl/data_sources/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer
+
+__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
diff --git a/src/lerobot/rl/data_sources/data_mixer.py b/src/lerobot/rl/data_sources/data_mixer.py
new file mode 100644
index 000000000..01c9055be
--- /dev/null
+++ b/src/lerobot/rl/data_sources/data_mixer.py
@@ -0,0 +1,94 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import abc
+from typing import Any
+
+from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
+
+BatchType = dict[str, Any]
+
+
+class DataMixer(abc.ABC):
+    """Abstract interface for all data mixing strategies.
+
+    Subclasses must implement ``sample(batch_size)`` and may override
+    ``get_iterator`` for specialised iteration.
+    """
+
+    @abc.abstractmethod
+    def sample(self, batch_size: int) -> BatchType:
+        """Draw one batch of ``batch_size`` transitions."""
+        ...
+
+    def get_iterator(
+        self,
+        batch_size: int,
+        async_prefetch: bool = True,
+        queue_size: int = 2,
+    ):
+        """Infinite iterator that yields batches.
+
+        The default implementation repeatedly calls ``self.sample()``.
+        Subclasses with underlying buffer iterators (async prefetch)
+        should override this for better throughput.
+        """
+        while True:
+            yield self.sample(batch_size)
+
+
+class OnlineOfflineMixer(DataMixer):
+    """Mixes transitions from an online and an optional offline replay buffer.
+
+    When both buffers are present, each batch is built by sampling
+    ``max(1, int(batch_size * online_ratio))`` transitions from the online buffer and the
+    remainder from the offline buffer, then concatenating the two.
+
+    If no offline buffer is provided, every batch is drawn entirely from the online buffer.
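+
+    A minimal usage sketch (assuming ``online`` and ``offline`` are already-populated
+    ``ReplayBuffer`` instances)::
+
+        mixer = OnlineOfflineMixer(online_buffer=online, offline_buffer=offline, online_ratio=0.5)
+        batch = mixer.sample(256)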
+ """ + + def __init__( + self, + online_buffer: ReplayBuffer, + offline_buffer: ReplayBuffer | None = None, + online_ratio: float = 1.0, + ): + if not 0.0 <= online_ratio <= 1.0: + raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}") + self.online_buffer = online_buffer + self.offline_buffer = offline_buffer + self.online_ratio = online_ratio + + def sample(self, batch_size: int) -> BatchType: + if self.offline_buffer is None: + return self.online_buffer.sample(batch_size) + + n_online = max(1, int(batch_size * self.online_ratio)) + n_offline = batch_size - n_online + + online_batch = self.online_buffer.sample(n_online) + offline_batch = self.offline_buffer.sample(n_offline) + return concatenate_batch_transitions(online_batch, offline_batch) + + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """Yield batches from online/offline mixed sampling.""" + while True: + yield self.sample(batch_size) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 2853fbcb3..39cd70b0b 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -64,10 +64,13 @@ from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.factory import make_policy -from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.policies.factory import make_policy, make_pre_post_processors +from lerobot.rl.algorithms.base import RLAlgorithm +from lerobot.rl.algorithms.factory import make_algorithm +from lerobot.rl.buffer import ReplayBuffer +from lerobot.rl.data_sources import OnlineOfflineMixer from lerobot.rl.process import ProcessSignalHandler +from lerobot.rl.trainer import RLTrainer from lerobot.rl.wandb_utils import WandBLogger from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 @@ -94,7 +97,7 @@ from lerobot.utils.train_utils import ( save_checkpoint, update_last_checkpoint, ) -from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device +from lerobot.utils.transition import move_transition_to_device from lerobot.utils.utils import ( format_big_number, init_logging, @@ -264,8 +267,8 @@ def add_actor_information_and_train( - Transfers transitions from the actor to the replay buffer. - Logs received interaction messages. - Ensures training begins only when the replay buffer has a sufficient number of transitions. - - Samples batches from the replay buffer and performs multiple critic updates. - - Periodically updates the actor, critic, and temperature optimizers. + - Delegates training updates to an ``RLAlgorithm``. + - Periodically pushes updated weights to actors. - Logs training statistics, including loss values and optimization frequency. 
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions @@ -284,17 +287,13 @@ def add_actor_information_and_train( # of 7% device = get_safe_torch_device(try_device=cfg.policy.device, log=True) storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device) - clip_grad_norm_value = cfg.policy.grad_clip_norm online_step_before_learning = cfg.policy.online_step_before_learning - utd_ratio = cfg.policy.utd_ratio fps = cfg.env.fps log_freq = cfg.log_freq save_freq = cfg.save_freq - policy_update_freq = cfg.policy.policy_update_freq policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency saving_checkpoint = cfg.save_checkpoint online_steps = cfg.policy.online_steps - async_prefetch = cfg.policy.async_prefetch # Initialize logging for multiprocessing if not use_threads(cfg): @@ -306,7 +305,7 @@ def add_actor_information_and_train( logging.info("Initializing policy") - policy: SACPolicy = make_policy( + policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, ) @@ -315,15 +314,21 @@ def add_actor_information_and_train( policy.train() - push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + algorithm = make_algorithm( + policy=policy, + policy_cfg=cfg.policy, + algorithm_name=cfg.algorithm, + ) + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + dataset_stats=cfg.policy.dataset_stats, + ) + + # Push initial policy weights to actors + push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm) last_time_policy_pushed = time.time() - optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) - - # If we are resuming, we need to load the training state - resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) - log_training_info(cfg=cfg, policy=policy) replay_buffer = initialize_replay_buffer(cfg, device, storage_device) @@ -336,21 +341,35 @@ def add_actor_information_and_train( device=device, storage_device=storage_device, ) - batch_size: int = batch_size // 2 # We will sample from both replay buffer + + # DataMixer: online-only or online/offline 50-50 mix + data_mixer = OnlineOfflineMixer( + online_buffer=replay_buffer, + offline_buffer=offline_replay_buffer, + online_ratio=cfg.online_ratio, + ) + # RLTrainer owns the iterator, preprocessor, and creates optimizers. 
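+    # Each trainer.training_step() call pulls mixed batches from the DataMixer iterator and
+    # delegates the UTD loop plus the actor/temperature updates to the underlying algorithm.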
+ trainer = RLTrainer( + algorithm=algorithm, + data_mixer=data_mixer, + batch_size=batch_size, + preprocessor=preprocessor, + ) + + # If we are resuming, we need to load the training state + optimizers = algorithm.get_optimizers() + resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) logging.info("Starting learner thread") interaction_message = None optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 + algorithm.optimization_step = optimization_step interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 dataset_repo_id = None if cfg.dataset is not None: dataset_repo_id = cfg.dataset.repo_id - # Initialize iterators - online_iterator = None - offline_iterator = None - # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER while True: # Exit the training loop if shutdown is requested @@ -380,180 +399,20 @@ def add_actor_information_and_train( if len(replay_buffer) < online_step_before_learning: continue - if online_iterator is None: - online_iterator = replay_buffer.get_iterator( - batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 - ) - - if offline_replay_buffer is not None and offline_iterator is None: - offline_iterator = offline_replay_buffer.get_iterator( - batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 - ) - time_for_one_optimization_step = time.time() - for _ in range(utd_ratio - 1): - # Sample from the iterators - batch = next(online_iterator) - if dataset_repo_id is not None: - batch_offline = next(offline_iterator) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - - actions = batch[ACTION] - rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] - done = batch["done"] - check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) - - observation_features, next_observation_features = get_observation_features( - policy=policy, observations=observations, next_observations=next_observations - ) - - # Create a batch dictionary with all required elements for the forward method - forward_batch = { - ACTION: actions, - "reward": rewards, - "state": observations, - "next_state": next_observations, - "done": done, - "observation_feature": observation_features, - "next_observation_feature": next_observation_features, - "complementary_info": batch["complementary_info"], - } - - # Use the forward method for critic loss - critic_output = policy.forward(forward_batch, model="critic") - - # Main critic optimization - loss_critic = critic_output["loss_critic"] - optimizers["critic"].zero_grad() - loss_critic.backward() - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value - ) - optimizers["critic"].step() - - # Discrete critic optimization (if available) - if policy.config.num_discrete_actions is not None: - discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") - loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] - optimizers["discrete_critic"].zero_grad() - loss_discrete_critic.backward() - discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value - ) - optimizers["discrete_critic"].step() - - # Update target networks (main and discrete) - policy.update_target_networks() - - # Sample for the 
last update in the UTD ratio - batch = next(online_iterator) - - if dataset_repo_id is not None: - batch_offline = next(offline_iterator) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - - actions = batch[ACTION] - rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] - done = batch["done"] - - check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) - - observation_features, next_observation_features = get_observation_features( - policy=policy, observations=observations, next_observations=next_observations - ) - - # Create a batch dictionary with all required elements for the forward method - forward_batch = { - ACTION: actions, - "reward": rewards, - "state": observations, - "next_state": next_observations, - "done": done, - "observation_feature": observation_features, - "next_observation_feature": next_observation_features, - } - - critic_output = policy.forward(forward_batch, model="critic") - - loss_critic = critic_output["loss_critic"] - optimizers["critic"].zero_grad() - loss_critic.backward() - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["critic"].step() - - # Initialize training info dictionary - training_infos = { - "loss_critic": loss_critic.item(), - "critic_grad_norm": critic_grad_norm, - } - - # Discrete critic optimization (if available) - if policy.config.num_discrete_actions is not None: - discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") - loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] - optimizers["discrete_critic"].zero_grad() - loss_discrete_critic.backward() - discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["discrete_critic"].step() - - # Add discrete critic info to training info - training_infos["loss_discrete_critic"] = loss_discrete_critic.item() - training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm - - # Actor and temperature optimization (at specified frequency) - if optimization_step % policy_update_freq == 0: - for _ in range(policy_update_freq): - # Actor optimization - actor_output = policy.forward(forward_batch, model="actor") - loss_actor = actor_output["loss_actor"] - optimizers["actor"].zero_grad() - loss_actor.backward() - actor_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["actor"].step() - - # Add actor info to training info - training_infos["loss_actor"] = loss_actor.item() - training_infos["actor_grad_norm"] = actor_grad_norm - - # Temperature optimization - temperature_output = policy.forward(forward_batch, model="temperature") - loss_temperature = temperature_output["loss_temperature"] - optimizers["temperature"].zero_grad() - loss_temperature.backward() - temp_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=[policy.log_alpha], max_norm=clip_grad_norm_value - ).item() - optimizers["temperature"].step() - - # Add temperature info to training info - training_infos["loss_temperature"] = loss_temperature.item() - training_infos["temperature_grad_norm"] = temp_grad_norm - training_infos["temperature"] = policy.temperature + # One training step (trainer owns data_mixer iterator; algorithm owns UTD loop) + 
stats = trainer.training_step() # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: - push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm) last_time_policy_pushed = time.time() - # Update target networks (main and discrete) - policy.update_target_networks() + training_infos = stats.to_log_dict() # Log training metrics at specified intervals + optimization_step = algorithm.optimization_step if optimization_step % log_freq == 0: training_infos["replay_buffer_size"] = len(replay_buffer) if offline_replay_buffer is not None: @@ -581,7 +440,6 @@ def add_actor_information_and_train( custom_step_key="Optimization step", ) - optimization_step += 1 if optimization_step % log_freq == 0: logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") @@ -598,6 +456,8 @@ def add_actor_information_and_train( offline_replay_buffer=offline_replay_buffer, dataset_repo_id=dataset_repo_id, fps=fps, + preprocessor=preprocessor, + postprocessor=postprocessor, ) @@ -682,6 +542,8 @@ def save_training_checkpoint( offline_replay_buffer: ReplayBuffer | None = None, dataset_repo_id: str | None = None, fps: int = 30, + preprocessor=None, + postprocessor=None, ) -> None: """ Save training checkpoint and associated data. @@ -705,6 +567,8 @@ def save_training_checkpoint( offline_replay_buffer: Optional offline replay buffer to save dataset_repo_id: Repository ID for dataset fps: Frames per second for dataset + preprocessor: Optional preprocessor pipeline to save + postprocessor: Optional postprocessor pipeline to save """ logging.info(f"Checkpoint policy after step {optimization_step}") _num_digits = max(6, len(str(online_steps))) @@ -721,6 +585,8 @@ def save_training_checkpoint( policy=policy, optimizer=optimizers, scheduler=None, + preprocessor=preprocessor, + postprocessor=postprocessor, ) # Save interaction step manually @@ -758,58 +624,6 @@ def save_training_checkpoint( logging.info("Resume training") -def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module): - """ - Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. - - This function sets up Adam optimizers for: - - The **actor network**, ensuring that only relevant parameters are optimized. - - The **critic ensemble**, which evaluates the value function. - - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods. - - It also initializes a learning rate scheduler, though currently, it is set to `None`. - - NOTE: - - If the encoder is shared, its parameters are excluded from the actor's optimization process. - - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor. - - Args: - cfg: Configuration object containing hyperparameters. - policy (nn.Module): The policy model containing the actor, critic, and temperature components. - - Returns: - Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]: - A tuple containing: - - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers. - - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling. 
- - """ - optimizer_actor = torch.optim.Adam( - params=[ - p - for n, p in policy.actor.named_parameters() - if not policy.config.shared_encoder or not n.startswith("encoder") - ], - lr=cfg.policy.actor_lr, - ) - optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - - if cfg.policy.num_discrete_actions is not None: - optimizer_discrete_critic = torch.optim.Adam( - params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr - ) - optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) - lr_scheduler = None - optimizers = { - "actor": optimizer_actor, - "critic": optimizer_critic, - "temperature": optimizer_temperature, - } - if cfg.policy.num_discrete_actions is not None: - optimizers["discrete_critic"] = optimizer_discrete_critic - return optimizers, lr_scheduler - - # Training setup functions @@ -1014,33 +828,6 @@ def initialize_offline_replay_buffer( # Utilities/Helpers functions -def get_observation_features( - policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor -) -> tuple[torch.Tensor | None, torch.Tensor | None]: - """ - Get observation features from the policy encoder. It act as cache for the observation features. - when the encoder is frozen, the observation features are not updated. - We can save compute by caching the observation features. - - Args: - policy: The policy model - observations: The current observations - next_observations: The next observations - - Returns: - tuple: observation_features, next_observation_features - """ - - if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: - return None, None - - with torch.no_grad(): - observation_features = policy.actor.encoder.get_cached_image_features(observations) - next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations) - - return observation_features, next_observation_features - - def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: return cfg.policy.concurrency.learner == "threads" @@ -1091,19 +878,11 @@ def check_nan_in_transition( return nan_detected -def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): +def push_actor_policy_to_queue(parameters_queue: Queue, algorithm: RLAlgorithm) -> None: logging.debug("[LEARNER] Pushing actor policy to the queue") # Create a dictionary to hold all the state dicts - state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")} - - # Add discrete critic if it exists - if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None: - state_dicts["discrete_critic"] = move_state_dict_to_device( - policy.discrete_critic.state_dict(), device="cpu" - ) - logging.debug("[LEARNER] Including discrete critic in state dict push") - + state_dicts = algorithm.get_weights() state_bytes = state_to_bytes(state_dicts) parameters_queue.put(state_bytes) diff --git a/src/lerobot/rl/trainer.py b/src/lerobot/rl/trainer.py new file mode 100644 index 000000000..948ea9d5e --- /dev/null +++ b/src/lerobot/rl/trainer.py @@ -0,0 +1,103 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +from lerobot.rl.algorithms.base import BatchType, RLAlgorithm +from lerobot.rl.algorithms.configs import TrainingStats +from lerobot.rl.data_sources.data_mixer import DataMixer + + +class RLTrainer: + """Unified training step orchestrator. + + Holds the algorithm, a DataMixer, and an optional preprocessor. + """ + + def __init__( + self, + algorithm: RLAlgorithm, + data_mixer: DataMixer, + batch_size: int, + *, + preprocessor: Any | None = None, + ): + self.algorithm = algorithm + self.data_mixer = data_mixer + self.batch_size = batch_size + self._preprocessor = preprocessor + + self._iterator: Iterator[BatchType] | None = None + + self.algorithm.make_optimizers_and_scheduler() + + def _build_data_iterator(self) -> Iterator[BatchType]: + """Create a fresh algorithm-configured iterator (optionally preprocessed).""" + raw = self.algorithm.configure_data_iterator( + data_mixer=self.data_mixer, + batch_size=self.batch_size, + ) + if self._preprocessor is not None: + return _PreprocessedIterator(raw, self._preprocessor) + return raw + + def reset_data_iterator(self) -> None: + """Discard the current iterator so it will be rebuilt lazily next step.""" + self._iterator = None + + def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None: + """Swap the active data mixer, optionally resetting the iterator.""" + self.data_mixer = data_mixer + if reset: + self.reset_data_iterator() + + def training_step(self) -> TrainingStats: + """Run one training step (algorithm-agnostic).""" + if self._iterator is None: + self._iterator = self._build_data_iterator() + return self.algorithm.update(self._iterator) + + +def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType: + """Apply policy preprocessing to RL observations only. + + This mirrors the pre-refactor SAC learner behavior where actions are left + unchanged and only state/next_state observations are normalized. + """ + observations = batch["state"] + next_observations = batch["next_state"] + batch["state"] = preprocessor.process_observation(observations) + batch["next_state"] = preprocessor.process_observation(next_observations) + + return batch + + +class _PreprocessedIterator: + """Iterator wrapper that preprocesses each sampled RL batch.""" + + __slots__ = ("_raw", "_preprocessor") + + def __init__(self, raw_iterator: Iterator[BatchType], preprocessor: Any) -> None: + self._raw = raw_iterator + self._preprocessor = preprocessor + + def __iter__(self) -> _PreprocessedIterator: + return self + + def __next__(self) -> BatchType: + batch = next(self._raw) + return preprocess_rl_batch(self._preprocessor, batch) diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index 02f6aebb3..dbf32e95d 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -95,6 +95,7 @@ def save_checkpoint( optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. 
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. preprocessor: The preprocessor/pipeline to save. Defaults to None. + postprocessor: The postprocessor/pipeline to save. Defaults to None. """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py index a62ef3ebb..a572ea9e1 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature @@ -38,9 +37,6 @@ def test_classifier_output(): @require_package("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_binary_classifier_with_default_params(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -82,9 +78,6 @@ def test_binary_classifier_with_default_params(): @require_package("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_multiclass_classifier(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -124,9 +117,6 @@ def test_multiclass_classifier(): @require_package("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_default_device(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -139,9 +129,6 @@ def test_default_device(): @require_package("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_explicit_device_setup(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 11499ce30..784ec2fc8 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math - import pytest import torch from torch import Tensor, nn @@ -23,6 +21,7 @@ from torch import Tensor, nn from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import MLP, SACPolicy +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.utils.random_utils import seeded_context, set_seed @@ -138,41 +137,6 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i } -def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]: - """Create optimizers for the SAC policy.""" - optimizer_actor = torch.optim.Adam( - # Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient - params=[ - p - for n, p in policy.actor.named_parameters() - if not policy.config.shared_encoder or not n.startswith("encoder") - ], - lr=policy.config.actor_lr, - ) - optimizer_critic = torch.optim.Adam( - params=policy.critic_ensemble.parameters(), - lr=policy.config.critic_lr, - ) - optimizer_temperature = torch.optim.Adam( - params=[policy.log_alpha], - lr=policy.config.critic_lr, - ) - - optimizers = { - "actor": optimizer_actor, - "critic": optimizer_critic, - "temperature": optimizer_temperature, - } - - if has_discrete_action: - optimizers["discrete_critic"] = torch.optim.Adam( - params=policy.discrete_critic.parameters(), - lr=policy.config.critic_lr, - ) - - return optimizers - - def create_default_config( state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False ) -> SACConfig: @@ -212,7 +176,6 @@ def create_config_with_visual_input( "std": torch.randn(3, 1, 1), } - # Let make tests a little bit faster config.state_encoder_hidden_dim = 32 config.latent_dim = 32 @@ -220,75 +183,112 @@ def create_config_with_visual_input( return config -@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) -def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int): - batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) - config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) - +def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]: + """Helper to create policy + algorithm pair for tests that need critics.""" policy = SACPolicy(config=config) policy.train() + algo_config = SACAlgorithmConfig.from_policy_config(config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm.make_optimizers_and_scheduler() + return algorithm, policy - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - assert temperature_loss.item() is not None - assert temperature_loss.shape == () - - temperature_loss.backward() - optimizers["temperature"].step() +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def 
test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + policy = SACPolicy(config=config) policy.eval() + with torch.no_grad(): observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) + # squeeze(0) removes batch dim when batch_size==1 + assert selected_action.shape[-1] == action_dim + + +def test_sac_policy_select_action_with_discrete(): + """select_action should return continuous + discrete actions.""" + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.num_discrete_actions = 3 + policy = SACPolicy(config=config) + policy.eval() + + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=1, state_dim=10) + # Squeeze to unbatched (single observation) + observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()} + selected_action = policy.select_action(observation_batch) + assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete @pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) -def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int): - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) +def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) policy = SACPolicy(config=config) + policy.eval() + + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + with torch.no_grad(): + output = policy.forward(batch) + assert "action" in output + assert "log_prob" in output + assert "action_mean" in output + assert output["action"].shape == (batch_size, action_dim) + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + algorithm, policy = _make_algorithm(config) + + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + forward_batch = algorithm._prepare_forward_batch(batch) + + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() + + actor_loss = algorithm._compute_loss_actor(forward_batch) + assert actor_loss.item() is not None + assert actor_loss.shape == () + algorithm.optimizers["actor"].zero_grad() + actor_loss.backward() + algorithm.optimizers["actor"].step() + + temp_loss = algorithm._compute_loss_temperature(forward_batch) + assert temp_loss.item() is not None + assert temp_loss.shape == () + algorithm.optimizers["temperature"].zero_grad() + temp_loss.backward() + algorithm.optimizers["temperature"].step() + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + algorithm, policy = _make_algorithm(config) batch = 
create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - policy.train() + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] + actor_loss = algorithm._compute_loss_actor(forward_batch) assert actor_loss.item() is not None assert actor_loss.shape == () - + algorithm.optimizers["actor"].zero_grad() actor_loss.backward() - optimizers["actor"].step() - - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - assert temperature_loss.item() is not None - assert temperature_loss.shape == () - - temperature_loss.backward() - optimizers["temperature"].step() + algorithm.optimizers["actor"].step() policy.eval() with torch.no_grad(): @@ -296,210 +296,181 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di batch_size=batch_size, state_dim=state_dim ) selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) + assert selected_action.shape[-1] == action_dim -# Let's check best candidates for pretrained encoders @pytest.mark.parametrize( "batch_size,state_dim,action_dim,vision_encoder_name", [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], ) @pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_sac_policy_with_pretrained_encoder( batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str ): config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) config.vision_encoder_name = vision_encoder_name - policy = SACPolicy(config=config) - policy.train() + algorithm, policy = _make_algorithm(config) batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - optimizers = make_optimizers(policy) + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] + actor_loss = algorithm._compute_loss_actor(forward_batch) assert actor_loss.item() is not None assert actor_loss.shape == () -def test_sac_policy_with_shared_encoder(): +def test_sac_training_with_shared_encoder(): batch_size = 2 action_dim = 10 state_dim = 10 config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) config.shared_encoder = True - policy = SACPolicy(config=config) - policy.train() + algorithm, 
policy = _make_algorithm(config) batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - policy.train() + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None + actor_loss = algorithm._compute_loss_actor(forward_batch) assert actor_loss.shape == () - + algorithm.optimizers["actor"].zero_grad() actor_loss.backward() - optimizers["actor"].step() + algorithm.optimizers["actor"].step() -def test_sac_policy_with_discrete_critic(): +def test_sac_training_with_discrete_critic(): batch_size = 2 continuous_action_dim = 9 - full_action_dim = continuous_action_dim + 1 # the last action is discrete + full_action_dim = continuous_action_dim + 1 state_dim = 10 config = create_config_with_visual_input( state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True ) + config.num_discrete_actions = 5 - num_discrete_actions = 5 - config.num_discrete_actions = num_discrete_actions - - policy = SACPolicy(config=config) - policy.train() + algorithm, policy = _make_algorithm(config) batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - policy.train() + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() - optimizers = make_optimizers(policy, has_discrete_action=True) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"] - assert discrete_critic_loss.item() is not None + discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch) assert discrete_critic_loss.shape == () + algorithm.optimizers["discrete_critic"].zero_grad() discrete_critic_loss.backward() - optimizers["discrete_critic"].step() + algorithm.optimizers["discrete_critic"].step() - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None + actor_loss = algorithm._compute_loss_actor(forward_batch) assert actor_loss.shape == () - + algorithm.optimizers["actor"].zero_grad() actor_loss.backward() - optimizers["actor"].step() + algorithm.optimizers["actor"].step() policy.eval() with torch.no_grad(): observation_batch = create_observation_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim ) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, full_action_dim) - - discrete_actions = selected_action[:, -1].long() - discrete_action_values = set(discrete_actions.tolist()) - - assert all(action in range(num_discrete_actions) for action in discrete_action_values), ( - f"Discrete action 
{discrete_action_values} is not in range({num_discrete_actions})" - ) + # Policy.select_action now handles both continuous + discrete + selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()}) + assert selected_action.shape[-1] == continuous_action_dim + 1 -def test_sac_policy_with_default_entropy(): +def test_sac_algorithm_target_entropy(): config = create_default_config(continuous_action_dim=10, state_dim=10) - policy = SACPolicy(config=config) - assert policy.target_entropy == -5.0 + _, policy = _make_algorithm(config) + algo_config = SACAlgorithmConfig.from_policy_config(config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + assert algorithm.target_entropy == -5.0 -def test_sac_policy_default_target_entropy_with_discrete_action(): +def test_sac_algorithm_target_entropy_with_discrete_action(): config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) + config.num_discrete_actions = 5 + algo_config = SACAlgorithmConfig.from_policy_config(config) policy = SACPolicy(config=config) - assert policy.target_entropy == -3.0 + algorithm = SACAlgorithm(policy=policy, config=algo_config) + assert algorithm.target_entropy == -3.5 -def test_sac_policy_with_predefined_entropy(): - config = create_default_config(state_dim=10, continuous_action_dim=6) - config.target_entropy = -3.5 +def test_sac_algorithm_temperature(): + import math - policy = SACPolicy(config=config) - assert policy.target_entropy == pytest.approx(-3.5) - - -def test_sac_policy_update_temperature(): - """Test that temperature property is always in sync with log_alpha.""" config = create_default_config(continuous_action_dim=10, state_dim=10) + algo_config = SACAlgorithmConfig.from_policy_config(config) policy = SACPolicy(config=config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) - assert policy.temperature == pytest.approx(1.0) - policy.log_alpha.data = torch.tensor([math.log(0.1)]) - # Temperature property automatically reflects log_alpha changes - assert policy.temperature == pytest.approx(0.1) + assert algorithm.temperature == pytest.approx(1.0) + algorithm.log_alpha.data = torch.tensor([math.log(0.1)]) + assert algorithm.temperature == pytest.approx(0.1) -def test_sac_policy_update_target_network(): +def test_sac_algorithm_update_target_network(): config = create_default_config(state_dim=10, continuous_action_dim=6) config.critic_target_update_weight = 1.0 - + algo_config = SACAlgorithmConfig.from_policy_config(config) policy = SACPolicy(config=config) - policy.train() + algorithm = SACAlgorithm(policy=policy, config=algo_config) - for p in policy.critic_ensemble.parameters(): + for p in algorithm.critic_ensemble.parameters(): p.data = torch.ones_like(p.data) - policy.update_target_networks() - for p in policy.critic_target.parameters(): - assert torch.allclose(p.data, torch.ones_like(p.data)), ( - f"Target network {p.data} is not equal to {torch.ones_like(p.data)}" - ) + algorithm._update_target_networks() + for p in algorithm.critic_target.parameters(): + assert torch.allclose(p.data, torch.ones_like(p.data)) @pytest.mark.parametrize("num_critics", [1, 3]) -def test_sac_policy_with_critics_number_of_heads(num_critics: int): +def test_sac_algorithm_with_critics_number_of_heads(num_critics: int): batch_size = 2 action_dim = 10 state_dim = 10 config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) config.num_critics = num_critics - policy = SACPolicy(config=config) - 
policy.train() + algorithm, policy = _make_algorithm(config) - assert len(policy.critic_ensemble.critics) == num_critics + assert len(algorithm.critic_ensemble.critics) == num_critics batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() def test_sac_policy_save_and_load(tmp_path): + """Test that the policy can be saved and loaded from pretrained.""" root = tmp_path / "test_sac_save_and_load" state_dim = 10 @@ -513,34 +484,41 @@ def test_sac_policy_save_and_load(tmp_path): loaded_policy = SACPolicy.from_pretrained(root, config=config) loaded_policy.eval() - batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10) + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) with torch.no_grad(): with seeded_context(12): - # Collect policy values before saving - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) actions = policy.select_action(observation_batch) with seeded_context(12): - # Collect policy values after loading - loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"] - loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"] - loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"] - loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) loaded_actions = loaded_policy.select_action(loaded_observation_batch) - assert policy.state_dict().keys() == loaded_policy.state_dict().keys() - for k in policy.state_dict(): - assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) - - # Compare values before and after saving and loading - # They should be the same - assert torch.allclose(cirtic_loss, loaded_cirtic_loss) - assert torch.allclose(actor_loss, loaded_actor_loss) - assert torch.allclose(temperature_loss, loaded_temperature_loss) assert torch.allclose(actions, loaded_actions) + + +def test_sac_policy_save_and_load_with_discrete_critic(tmp_path): + """Discrete critic should be saved/loaded as part of the policy.""" + root = tmp_path / "test_sac_save_and_load_discrete" + + state_dim = 10 + action_dim = 6 + + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + config.num_discrete_actions = 3 + policy = SACPolicy(config=config) + policy.eval() + policy.save_pretrained(root) + + loaded_policy = SACPolicy.from_pretrained(root, config=config) + loaded_policy.eval() + + assert loaded_policy.discrete_critic is not None + dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")] + assert 
len(dc_keys) > 0 + + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index e13862d82..95d82db82 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -23,8 +23,9 @@ import torch from torch.multiprocessing import Event, Queue from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR from lerobot.utils.transition import Transition from tests.utils import require_package @@ -296,3 +297,171 @@ def test_end_to_end_parameters_flow(cfg, data_size): assert received_params.keys() == input_params.keys() for key in input_params: assert torch.allclose(received_params[key], input_params[key]) + + +# --------------------------------------------------------------------------- +# Regression test: learner algorithm integration (no gRPC required) +# --------------------------------------------------------------------------- + + +def test_learner_algorithm_wiring(): + """Verify that make_algorithm constructs an SACAlgorithm from config, + make_optimizers_and_scheduler() creates the right optimizers, update() works, and + get_weights() output is serializable.""" + from lerobot.policies.sac.modeling_sac import SACPolicy + from lerobot.rl.algorithms.factory import make_algorithm + from lerobot.rl.algorithms.sac import SACAlgorithm + from lerobot.transport.utils import state_to_bytes + + state_dim = 10 + action_dim = 6 + + sac_cfg = SACConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + use_torch_compile=False, + ) + sac_cfg.validate_features() + + policy = SACPolicy(config=sac_cfg) + policy.train() + + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) + + optimizers = algorithm.make_optimizers_and_scheduler() + assert "actor" in optimizers + assert "critic" in optimizers + assert "temperature" in optimizers + + batch_size = 4 + + def batch_iterator(): + while True: + yield { + ACTION: torch.randn(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": {OBS_STATE: torch.randn(batch_size, state_dim)}, + "next_state": {OBS_STATE: torch.randn(batch_size, state_dim)}, + "done": torch.zeros(batch_size), + "complementary_info": {}, + } + + stats = algorithm.update(batch_iterator()) + assert "critic" in stats.losses + + # get_weights -> state_to_bytes round-trip + weights = algorithm.get_weights() + assert len(weights) > 0 + serialized = state_to_bytes(weights) + assert isinstance(serialized, bytes) + assert len(serialized) > 0 + + # RLTrainer with DataMixer + from lerobot.rl.buffer import ReplayBuffer + from lerobot.rl.data_sources import OnlineOfflineMixer + from lerobot.rl.trainer import RLTrainer + + replay_buffer = ReplayBuffer( + capacity=50, + device="cpu", + state_keys=[OBS_STATE], + storage_device="cpu", + use_drq=False, + ) + for _ in range(50): + replay_buffer.add( + state={OBS_STATE: torch.randn(state_dim)}, + 
action=torch.randn(action_dim), + reward=1.0, + next_state={OBS_STATE: torch.randn(state_dim)}, + done=False, + truncated=False, + ) + data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None) + trainer = RLTrainer( + algorithm=algorithm, + data_mixer=data_mixer, + batch_size=batch_size, + ) + trainer_stats = trainer.training_step() + assert "critic" in trainer_stats.losses + + +def test_initial_and_periodic_weight_push_consistency(): + """Both initial and periodic weight pushes should use algorithm.get_weights() + and produce identical structures.""" + from lerobot.policies.sac.modeling_sac import SACPolicy + from lerobot.rl.algorithms.factory import make_algorithm + from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes + + state_dim = 10 + action_dim = 6 + sac_cfg = SACConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + use_torch_compile=False, + ) + sac_cfg.validate_features() + + policy = SACPolicy(config=sac_cfg) + policy.train() + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + algorithm.make_optimizers_and_scheduler() + + # Simulate initial push (same code path the learner now uses) + initial_weights = algorithm.get_weights() + initial_bytes = state_to_bytes(initial_weights) + + # Simulate periodic push + periodic_weights = algorithm.get_weights() + periodic_bytes = state_to_bytes(periodic_weights) + + initial_decoded = bytes_to_state_dict(initial_bytes) + periodic_decoded = bytes_to_state_dict(periodic_bytes) + + assert initial_decoded.keys() == periodic_decoded.keys() + + +def test_actor_side_algorithm_select_action_and_load_weights(): + """Simulate actor: create algorithm without optimizers, select_action, load_weights.""" + from lerobot.policies.sac.modeling_sac import SACPolicy + from lerobot.rl.algorithms.factory import make_algorithm + from lerobot.rl.algorithms.sac import SACAlgorithm + + state_dim = 10 + action_dim = 6 + sac_cfg = SACConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + use_torch_compile=False, + ) + sac_cfg.validate_features() + + # Actor side: no optimizers + policy = SACPolicy(config=sac_cfg) + policy.eval() + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + # select_action should work + obs = {OBS_STATE: torch.randn(state_dim)} + action = policy.select_action(obs) + assert action.shape == (action_dim,) + + # Simulate receiving weights from learner + fake_weights = algorithm.get_weights() + algorithm.load_weights(fake_weights, device="cpu") diff --git a/tests/rl/test_data_mixer.py b/tests/rl/test_data_mixer.py new file mode 100644 index 000000000..90e9e492f --- /dev/null +++ b/tests/rl/test_data_mixer.py @@ -0,0 +1,85 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer).""" + +import torch + +from lerobot.rl.buffer import ReplayBuffer +from lerobot.rl.data_sources import OnlineOfflineMixer +from lerobot.utils.constants import OBS_STATE + + +def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer: + buf = ReplayBuffer( + capacity=capacity, + device="cpu", + state_keys=[OBS_STATE], + storage_device="cpu", + use_drq=False, + ) + for i in range(capacity): + buf.add( + state={OBS_STATE: torch.randn(state_dim)}, + action=torch.randn(2), + reward=1.0, + next_state={OBS_STATE: torch.randn(state_dim)}, + done=bool(i % 10 == 9), + truncated=False, + ) + return buf + + +def test_online_only_mixer_sample(): + """OnlineOfflineMixer with no offline buffer returns online-only batches.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5) + batch = mixer.sample(batch_size=8) + assert batch["state"][OBS_STATE].shape[0] == 8 + assert batch["action"].shape[0] == 8 + assert batch["reward"].shape[0] == 8 + + +def test_online_only_mixer_ratio_one(): + """OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0) + batch = mixer.sample(batch_size=10) + assert batch["state"][OBS_STATE].shape[0] == 10 + + +def test_online_offline_mixer_sample(): + """OnlineOfflineMixer with two buffers returns concatenated batches.""" + online = _make_buffer(capacity=50) + offline = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer( + online_buffer=online, + offline_buffer=offline, + online_ratio=0.5, + ) + batch = mixer.sample(batch_size=10) + assert batch["state"][OBS_STATE].shape[0] == 10 + assert batch["action"].shape[0] == 10 + # 5 from online, 5 from offline (approx) + assert batch["reward"].shape[0] == 10 + + +def test_online_offline_mixer_iterator(): + """get_iterator yields batches of the requested size.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None) + it = mixer.get_iterator(batch_size=4, async_prefetch=False) + batch1 = next(it) + batch2 = next(it) + assert batch1["state"][OBS_STATE].shape[0] == 4 + assert batch2["state"][OBS_STATE].shape[0] == 4 diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py new file mode 100644 index 000000000..325c9727b --- /dev/null +++ b/tests/rl/test_sac_algorithm.py @@ -0,0 +1,477 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the RL algorithm abstraction and SACAlgorithm implementation.""" + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats +from lerobot.rl.algorithms.factory import make_algorithm +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.utils.random_utils import set_seed + +# --------------------------------------------------------------------------- +# Helpers (reuse patterns from tests/policies/test_sac_policy.py) +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def set_random_seed(): + set_seed(42) + + +def _make_sac_config( + state_dim: int = 10, + action_dim: int = 6, + num_discrete_actions: int | None = None, + utd_ratio: int = 1, + policy_update_freq: int = 1, + with_images: bool = False, +) -> SACConfig: + config = SACConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + utd_ratio=utd_ratio, + policy_update_freq=policy_update_freq, + num_discrete_actions=num_discrete_actions, + use_torch_compile=False, + ) + if with_images: + config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats[OBS_IMAGE] = { + "mean": torch.randn(3, 1, 1).tolist(), + "std": torch.randn(3, 1, 1).abs().tolist(), + } + config.latent_dim = 32 + config.state_encoder_hidden_dim = 32 + config.validate_features() + return config + + +def _make_algorithm( + state_dim: int = 10, + action_dim: int = 6, + utd_ratio: int = 1, + policy_update_freq: int = 1, + num_discrete_actions: int | None = None, + with_images: bool = False, +) -> tuple[SACAlgorithm, SACPolicy]: + sac_cfg = _make_sac_config( + state_dim=state_dim, + action_dim=action_dim, + utd_ratio=utd_ratio, + policy_update_freq=policy_update_freq, + num_discrete_actions=num_discrete_actions, + with_images=with_images, + ) + policy = SACPolicy(config=sac_cfg) + policy.train() + algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm.make_optimizers_and_scheduler() + return algorithm, policy + + +def _make_batch( + batch_size: int = 4, + state_dim: int = 10, + action_dim: int = 6, + with_images: bool = False, +) -> dict: + obs = {OBS_STATE: torch.randn(batch_size, state_dim)} + next_obs = {OBS_STATE: torch.randn(batch_size, state_dim)} + if with_images: + obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84) + next_obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84) + return { + ACTION: torch.randn(batch_size, action_dim), + "reward": 
torch.randn(batch_size), + "state": obs, + "next_state": next_obs, + "done": torch.zeros(batch_size), + "complementary_info": {}, + } + + +def _batch_iterator(**batch_kwargs): + """Infinite iterator that yields fresh batches (mirrors a real DataMixer iterator).""" + while True: + yield _make_batch(**batch_kwargs) + + +# =========================================================================== +# Registry / config tests +# =========================================================================== + + +def test_sac_algorithm_config_registered(): + """SACAlgorithmConfig should be discoverable through the registry.""" + assert "sac" in RLAlgorithmConfig.get_known_choices() + cls = RLAlgorithmConfig.get_choice_class("sac") + assert cls is SACAlgorithmConfig + + +def test_sac_algorithm_config_from_policy_config(): + """from_policy_config should copy relevant fields.""" + sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2) + algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg) + assert algo_cfg.utd_ratio == 4 + assert algo_cfg.policy_update_freq == 2 + assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm + + +# =========================================================================== +# TrainingStats tests +# =========================================================================== + + +def test_training_stats_defaults(): + stats = TrainingStats() + assert stats.losses == {} + assert stats.grad_norms == {} + assert stats.extra == {} + + +# =========================================================================== +# get_weights +# =========================================================================== + + +def test_get_weights_returns_policy_state_dict(): + algorithm, policy = _make_algorithm() + weights = algorithm.get_weights() + for key in policy.state_dict(): + assert key in weights + assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu()) + + +def test_get_weights_includes_discrete_critic_when_present(): + algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6) + weights = algorithm.get_weights() + dc_keys = [k for k in weights if k.startswith("discrete_critic.")] + assert len(dc_keys) > 0 + + +def test_get_weights_excludes_discrete_critic_when_absent(): + algorithm, _ = _make_algorithm() + weights = algorithm.get_weights() + dc_keys = [k for k in weights if k.startswith("discrete_critic.")] + assert len(dc_keys) == 0 + + +def test_get_weights_are_on_cpu(): + algorithm, _ = _make_algorithm() + weights = algorithm.get_weights() + for key, tensor in weights.items(): + assert tensor.device == torch.device("cpu"), f"{key} is not on CPU" + + +# =========================================================================== +# select_action (lives on the policy, not the algorithm) +# =========================================================================== + + +def test_select_action_returns_correct_shape(): + action_dim = 6 + _, policy = _make_algorithm(state_dim=10, action_dim=action_dim) + policy.eval() + obs = {OBS_STATE: torch.randn(10)} + action = policy.select_action(obs) + assert action.shape == (action_dim,) + + +def test_select_action_with_discrete_critic(): + continuous_dim = 5 + _, policy = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3) + policy.eval() + obs = {OBS_STATE: torch.randn(10)} + action = policy.select_action(obs) + assert action.shape == (continuous_dim + 1,) + + +# =========================================================================== +# update (single batch, utd_ratio=1) 
+# =========================================================================== + + +def test_update_returns_training_stats(): + algorithm, _ = _make_algorithm() + stats = algorithm.update(_batch_iterator()) + assert isinstance(stats, TrainingStats) + assert "critic" in stats.losses + assert isinstance(stats.losses["critic"], float) + + +def test_update_populates_actor_and_temperature_losses(): + """With policy_update_freq=1 and step 0, actor/temperature should be updated.""" + algorithm, _ = _make_algorithm(policy_update_freq=1) + stats = algorithm.update(_batch_iterator()) + assert "actor" in stats.losses + assert "temperature" in stats.losses + assert "temperature" in stats.extra + + +@pytest.mark.parametrize("policy_update_freq", [2, 3]) +def test_update_skips_actor_at_non_update_steps(policy_update_freq): + """Actor/temperature should only update when optimization_step % freq == 0.""" + algorithm, _ = _make_algorithm(policy_update_freq=policy_update_freq) + it = _batch_iterator() + + # Step 0: should update actor + stats_0 = algorithm.update(it) + assert "actor" in stats_0.losses + + # Step 1: should NOT update actor + stats_1 = algorithm.update(it) + assert "actor" not in stats_1.losses + + +def test_update_increments_optimization_step(): + algorithm, _ = _make_algorithm() + it = _batch_iterator() + assert algorithm.optimization_step == 0 + algorithm.update(it) + assert algorithm.optimization_step == 1 + algorithm.update(it) + assert algorithm.optimization_step == 2 + + +def test_update_with_discrete_critic(): + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete + assert "discrete_critic" in stats.losses + assert "discrete_critic" in stats.grad_norms + + +# =========================================================================== +# update with UTD ratio > 1 +# =========================================================================== + + +@pytest.mark.parametrize("utd_ratio", [2, 4]) +def test_update_with_utd_ratio(utd_ratio): + algorithm, _ = _make_algorithm(utd_ratio=utd_ratio) + stats = algorithm.update(_batch_iterator()) + assert isinstance(stats, TrainingStats) + assert "critic" in stats.losses + assert algorithm.optimization_step == 1 + + +def test_update_utd_ratio_pulls_utd_batches(): + """next(batch_iterator) should be called exactly utd_ratio times.""" + utd_ratio = 3 + algorithm, _ = _make_algorithm(utd_ratio=utd_ratio) + + call_count = 0 + + def counting_iterator(): + nonlocal call_count + while True: + call_count += 1 + yield _make_batch() + + algorithm.update(counting_iterator()) + assert call_count == utd_ratio + + +def test_update_utd_ratio_3_critic_warmup_changes_weights(): + """With utd_ratio=3, critic weights should change after update (3 critic steps).""" + algorithm, policy = _make_algorithm(utd_ratio=3) + + critic_params_before = {n: p.clone() for n, p in algorithm.critic_ensemble.named_parameters()} + + algorithm.update(_batch_iterator()) + + changed = False + for n, p in algorithm.critic_ensemble.named_parameters(): + if not torch.equal(p, critic_params_before[n]): + changed = True + break + assert changed, "Critic weights should have changed after UTD update" + + +# =========================================================================== +# get_observation_features +# =========================================================================== + + +def test_get_observation_features_returns_none_without_frozen_encoder(): + algorithm, _ = 
_make_algorithm(with_images=False) + obs = {OBS_STATE: torch.randn(4, 10)} + next_obs = {OBS_STATE: torch.randn(4, 10)} + feat, next_feat = algorithm.get_observation_features(obs, next_obs) + assert feat is None + assert next_feat is None + + +# =========================================================================== +# optimization_step setter +# =========================================================================== + + +def test_optimization_step_can_be_set_for_resume(): + algorithm, _ = _make_algorithm() + algorithm.optimization_step = 100 + assert algorithm.optimization_step == 100 + + +# =========================================================================== +# make_algorithm factory +# =========================================================================== + + +def test_make_algorithm_returns_sac_for_sac_policy(): + sac_cfg = _make_sac_config() + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + +def test_make_optimizers_creates_expected_keys(): + """make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers.""" + sac_cfg = _make_sac_config() + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + optimizers = algorithm.make_optimizers_and_scheduler() + assert "actor" in optimizers + assert "critic" in optimizers + assert "temperature" in optimizers + assert all(isinstance(v, torch.optim.Adam) for v in optimizers.values()) + assert algorithm.get_optimizers() is optimizers + + +def test_actor_side_no_optimizers(): + """Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called.""" + sac_cfg = _make_sac_config() + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + +def test_make_algorithm_copies_config_fields(): + sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3) + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert algorithm.config.utd_ratio == 5 + assert algorithm.config.policy_update_freq == 3 + + +def test_make_algorithm_raises_for_unknown_type(): + class FakeConfig: + type = "unknown_algo" + + with pytest.raises(ValueError, match="No RLAlgorithmConfig"): + make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo") + + +# =========================================================================== +# load_weights (round-trip with get_weights) +# =========================================================================== + + +def test_load_weights_round_trip(): + """get_weights -> load_weights should restore identical parameters on a fresh policy.""" + algo_src, _ = _make_algorithm(state_dim=10, action_dim=6) + algo_src.update(_batch_iterator()) + + sac_cfg = _make_sac_config(state_dim=10, action_dim=6) + policy_dst = SACPolicy(config=sac_cfg) + algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config) + + weights = algo_src.get_weights() + algo_dst.load_weights(weights, device="cpu") + + for key in weights: + assert torch.equal( + algo_dst.policy.state_dict()[key].cpu(), + weights[key].cpu(), + ), f"Policy param '{key}' mismatch after load_weights" + + +def test_load_weights_round_trip_with_discrete_critic(): + algo_src, _ = 
_make_algorithm(num_discrete_actions=3, action_dim=6) + algo_src.update(_batch_iterator(action_dim=7)) + + sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) + policy_dst = SACPolicy(config=sac_cfg) + algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config) + + weights = algo_src.get_weights() + algo_dst.load_weights(weights, device="cpu") + + dc_keys = [k for k in weights if k.startswith("discrete_critic.")] + assert len(dc_keys) > 0 + for key in dc_keys: + assert torch.equal( + algo_dst.policy.state_dict()[key].cpu(), + weights[key].cpu(), + ), f"Discrete critic param '{key}' mismatch after load_weights" + + +def test_load_weights_ignores_missing_discrete_critic(): + """load_weights should not fail when weights lack discrete_critic on a non-discrete policy.""" + algorithm, _ = _make_algorithm() + weights = algorithm.get_weights() + algorithm.load_weights(weights, device="cpu") + + +# =========================================================================== +# TrainingStats generic losses dict +# =========================================================================== + + +def test_training_stats_generic_losses(): + stats = TrainingStats( + losses={"loss_bc": 0.5, "loss_q": 1.2}, + extra={"temperature": 0.1}, + ) + assert stats.losses["loss_bc"] == 0.5 + assert stats.losses["loss_q"] == 1.2 + assert stats.extra["temperature"] == 0.1 + + +# =========================================================================== +# Registry-driven build_algorithm +# =========================================================================== + + +def test_build_algorithm_via_config(): + """SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm.""" + sac_cfg = _make_sac_config(utd_ratio=2) + algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) + policy = SACPolicy(config=sac_cfg) + + algorithm = algo_config.build_algorithm(policy) + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.config.utd_ratio == 2 + + +def test_make_algorithm_uses_build_algorithm(): + """make_algorithm should delegate to config.build_algorithm (no hardcoded if/else).""" + sac_cfg = _make_sac_config() + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py new file mode 100644 index 000000000..47eaf6ad3 --- /dev/null +++ b/tests/rl/test_trainer.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from torch import Tensor + +from lerobot.rl.algorithms.base import RLAlgorithm +from lerobot.rl.algorithms.configs import TrainingStats +from lerobot.rl.trainer import RLTrainer +from lerobot.utils.constants import ACTION, OBS_STATE + + +class _DummyRLAlgorithmConfig: + """Dummy config for testing.""" + + +class _DummyRLAlgorithm(RLAlgorithm): + config_class = _DummyRLAlgorithmConfig + name = "dummy_rl_algorithm" + + def __init__(self): + self.configure_calls = 0 + self.update_calls = 0 + + def select_action(self, observation: dict[str, Tensor]) -> Tensor: + return torch.zeros(1) + + def configure_data_iterator( + self, + data_mixer, + batch_size: int, + *, + async_prefetch: bool = True, + queue_size: int = 2, + ): + self.configure_calls += 1 + return data_mixer.get_iterator( + batch_size=batch_size, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + def make_optimizers_and_scheduler(self): + return {} + + def update(self, batch_iterator): + self.update_calls += 1 + _ = next(batch_iterator) + return TrainingStats(losses={"dummy": 1.0}) + + def load_weights(self, weights, device="cpu") -> None: + _ = (weights, device) + + +class _SimpleMixer: + def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2): + _ = (async_prefetch, queue_size) + while True: + yield { + "state": {OBS_STATE: torch.randn(batch_size, 3)}, + ACTION: torch.randn(batch_size, 2), + "reward": torch.randn(batch_size), + "next_state": {OBS_STATE: torch.randn(batch_size, 3)}, + "done": torch.zeros(batch_size), + "truncated": torch.zeros(batch_size), + "complementary_info": None, + } + + +def test_trainer_lazy_iterator_lifecycle_and_reset(): + algo = _DummyRLAlgorithm() + mixer = _SimpleMixer() + trainer = RLTrainer(algorithm=algo, data_mixer=mixer, batch_size=4) + + # First call builds iterator once. + trainer.training_step() + assert algo.configure_calls == 1 + assert algo.update_calls == 1 + + # Second call reuses existing iterator. + trainer.training_step() + assert algo.configure_calls == 1 + assert algo.update_calls == 2 + + # Explicit reset forces lazy rebuild on next step. + trainer.reset_data_iterator() + trainer.training_step() + assert algo.configure_calls == 2 + assert algo.update_calls == 3 + + +def test_trainer_set_data_mixer_resets_by_default(): + algo = _DummyRLAlgorithm() + mixer_a = _SimpleMixer() + mixer_b = _SimpleMixer() + trainer = RLTrainer(algorithm=algo, data_mixer=mixer_a, batch_size=2) + + trainer.training_step() + assert algo.configure_calls == 1 + + trainer.set_data_mixer(mixer_b, reset=True) + trainer.training_step() + assert algo.configure_calls == 2 + + +def test_algorithm_optimization_step_contract_defaults(): + algo = _DummyRLAlgorithm() + assert algo.optimization_step == 0 + algo.optimization_step = 11 + assert algo.optimization_step == 11
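
For orientation, the snippet below strings together the learner-side pieces these tests exercise (SACPolicy, make_algorithm, ReplayBuffer, OnlineOfflineMixer, RLTrainer). It is a minimal sketch that assumes only the constructor arguments and method names appearing in the tests above; random transitions stand in for real rollouts, and nothing here should be read as the full public API.

import torch

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.rl.trainer import RLTrainer
from lerobot.utils.constants import ACTION, OBS_STATE

state_dim, action_dim, batch_size = 10, 6, 4

# Build a state-only SAC policy, mirroring the test helpers.
cfg = SACConfig(
    input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
    output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
    dataset_stats={
        OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
        ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
    },
    use_torch_compile=False,
)
cfg.validate_features()
policy = SACPolicy(config=cfg)
policy.train()

# Learner side: the algorithm owns losses, optimizers and target networks.
algorithm = make_algorithm(policy=policy, policy_cfg=cfg, algorithm_name="sac")
algorithm.make_optimizers_and_scheduler()

# Online replay buffer filled with random transitions (stand-in for rollouts).
buffer = ReplayBuffer(
    capacity=50, device="cpu", state_keys=[OBS_STATE], storage_device="cpu", use_drq=False
)
for _ in range(50):
    buffer.add(
        state={OBS_STATE: torch.randn(state_dim)},
        action=torch.randn(action_dim),
        reward=1.0,
        next_state={OBS_STATE: torch.randn(state_dim)},
        done=False,
        truncated=False,
    )

# DataMixer + RLTrainer replace the hand-rolled sampling/optimizer loop;
# with no offline buffer, online_ratio=1.0 behaves as online-only sampling.
mixer = OnlineOfflineMixer(online_buffer=buffer, offline_buffer=None, online_ratio=1.0)
trainer = RLTrainer(algorithm=algorithm, data_mixer=mixer, batch_size=batch_size)

stats = trainer.training_step()
print(stats.losses["critic"])  # critic loss is always populated

# Parameters flow to the actor through get_weights()/load_weights();
# the actor-side policy then serves actions via policy.select_action(obs).
weights = algorithm.get_weights()
algorithm.load_weights(weights, device="cpu")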