From 8d50be9faa3a18e02c401605aca27c80c63eabd3 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 2 Mar 2026 11:51:43 +0100 Subject: [PATCH] =?UTF-8?q?refactor:=20RL=20stack=20refactoring=20?= =?UTF-8?q?=E2=80=94=20RLAlgorithm,=20RLTrainer,=20DataMixer,=20and=20SAC?= =?UTF-8?q?=20restructuring?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add RLAlgorithm base class and RLAlgorithmConfig with draccus.ChoiceRegistry - Add RLTrainer for unified training orchestration with iterator pattern - Add DataMixer and OnlineOfflineMixer for online/offline data mixing - Restructure SAC algorithm with batch iterator and factory pattern - Add observation normalization pre/post processors - Add comprehensive tests for all new components --- src/lerobot/configs/train.py | 12 + src/lerobot/processor/normalize_processor.py | 12 + src/lerobot/rl/__init__.py | 13 + src/lerobot/rl/actor.py | 76 ++- src/lerobot/rl/algorithms/__init__.py | 67 +++ src/lerobot/rl/algorithms/base.py | 163 +++++++ src/lerobot/rl/algorithms/sac.py | 262 ++++++++++ src/lerobot/rl/data_sources/__init__.py | 17 + src/lerobot/rl/data_sources/data_mixer.py | 94 ++++ src/lerobot/rl/learner.py | 372 ++++----------- src/lerobot/rl/trainer.py | 132 ++++++ tests/rl/test_actor_learner.py | 174 ++++++- tests/rl/test_data_mixer.py | 85 ++++ tests/rl/test_sac_algorithm.py | 474 +++++++++++++++++++ tests/rl/test_trainer.py | 115 +++++ 15 files changed, 1762 insertions(+), 306 deletions(-) create mode 100644 src/lerobot/rl/__init__.py create mode 100644 src/lerobot/rl/algorithms/__init__.py create mode 100644 src/lerobot/rl/algorithms/base.py create mode 100644 src/lerobot/rl/algorithms/sac.py create mode 100644 src/lerobot/rl/data_sources/__init__.py create mode 100644 src/lerobot/rl/data_sources/data_mixer.py create mode 100644 src/lerobot/rl/trainer.py create mode 100644 tests/rl/test_data_mixer.py create mode 100644 tests/rl/test_sac_algorithm.py create mode 100644 tests/rl/test_trainer.py diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 7a5eee77d..5c94d9ffc 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -211,3 +211,15 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig): # NOTE: In RL, we don't need an offline dataset # TODO: Make `TrainPipelineConfig.dataset` optional dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional + + # Algorithm name registered in RLAlgorithmConfig registry + algorithm: str = "sac" + + # Data mixer strategy name. Currently supports "online_offline" + mixer: str = "online_offline" + # Fraction sampled from online replay when using OnlineOfflineMixer + online_ratio: float = 0.5 + + # RL trainer iterator + async_prefetch: bool = True + queue_size: int = 2 diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 4769b91ac..1111f7fc9 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -131,6 +131,15 @@ class _NormalizationMixin: if self.dtype is None: self.dtype = torch.float32 self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + self._reshape_visual_stats() + + def _reshape_visual_stats(self) -> None: + """Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting.""" + for key, feature in self.features.items(): + if feature.type == FeatureType.VISUAL and key in self._tensor_stats: + for stat_name, stat_tensor in self._tensor_stats[key].items(): + if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1: + self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1) def to( self, device: torch.device | str | None = None, dtype: torch.dtype | None = None @@ -149,6 +158,7 @@ class _NormalizationMixin: if dtype is not None: self.dtype = dtype self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + self._reshape_visual_stats() return self def state_dict(self) -> dict[str, Tensor]: @@ -198,6 +208,7 @@ class _NormalizationMixin: # Don't load from state_dict, keep the explicitly provided stats # But ensure _tensor_stats is properly initialized self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment] + self._reshape_visual_stats() return # Normal behavior: load stats from state_dict @@ -208,6 +219,7 @@ class _NormalizationMixin: self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( dtype=torch.float32, device=self.device ) + self._reshape_visual_stats() # Reconstruct the original stats dict from tensor stats for compatibility with to() method # and other functions that rely on self.stats diff --git a/src/lerobot/rl/__init__.py b/src/lerobot/rl/__init__.py new file mode 100644 index 000000000..19b2f1409 --- /dev/null +++ b/src/lerobot/rl/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 7427633d2..03d48e775 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -60,9 +60,9 @@ from torch.multiprocessing import Event, Queue from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.policies.factory import make_policy -from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.processor import TransitionKey +from lerobot.rl.algorithms import RLAlgorithm, make_algorithm from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.queue import get_last_item_from_queue from lerobot.robots import so_follower # noqa: F401 @@ -81,7 +81,6 @@ from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( Transition, - move_state_dict_to_device, move_transition_to_device, ) from lerobot.utils.utils import ( @@ -251,13 +250,45 @@ def act_with_policy( ### Instantiate the policy in both the actor and learner processes ### To avoid sending a SACPolicy object through the port, we create a policy instance ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters - policy: SACPolicy = make_policy( + policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, ) policy = policy.eval() assert isinstance(policy, nn.Module) + algorithm = make_algorithm(policy=policy, policy_cfg=cfg.policy, algorithm_name=cfg.algorithm) + + # Build policy pre/post processors for observation normalization and action unnormalization + processor_kwargs = {} + postprocessor_kwargs = {} + if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: + processor_kwargs["dataset_stats"] = cfg.policy.dataset_stats + + if cfg.policy.pretrained_path is not None: + processor_kwargs["preprocessor_overrides"] = { + "device_processor": {"device": device.type}, + "normalizer_processor": { + "stats": cfg.policy.dataset_stats, + "features": {**policy.config.input_features, **policy.config.output_features}, + "norm_map": policy.config.normalization_mapping, + }, + } + postprocessor_kwargs["postprocessor_overrides"] = { + "unnormalizer_processor": { + "stats": cfg.policy.dataset_stats, + "features": policy.config.output_features, + "norm_map": policy.config.normalization_mapping, + }, + } + + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + **processor_kwargs, + **postprocessor_kwargs, + ) + obs, info = online_env.reset() env_processor.reset() action_processor.reset() @@ -286,12 +317,27 @@ def act_with_policy( k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features } + # Preprocess observation (normalization, batch dim, device) + batch = {**observation} + batch = preprocessor(batch) + observation_for_inference = {k: v for k, v in batch.items() if k.startswith("observation.")} + # Time policy inference and check if it meets FPS requirement with policy_timer: - # Extract observation from transition for policy - action = policy.select_action(batch=observation) + action = algorithm.select_action(observation_for_inference) policy_fps = policy_timer.fps_last + # Postprocess action (unnormalization, move to cpu). + # Actions may include extra dimensions (e.g. discrete gripper) that are + # appended after the continuous action and should not be unnormalized. + expected_action_dim = cfg.policy.output_features["action"].shape[0] + if action.shape[-1] > expected_action_dim: + extra = action[..., expected_action_dim:] + action = postprocessor(action[..., :expected_action_dim]) + action = torch.cat([action, extra.cpu()], dim=-1) + else: + action = postprocessor(action) + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) # Use the new step function @@ -351,7 +397,7 @@ def act_with_policy( if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") - update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) + update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device) if len(list_transition_to_send_to_learner) > 0: push_transitions_to_transport_queue( @@ -649,12 +695,12 @@ def interactions_stream( # Policy functions -def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): +def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device): + """Load the latest weights from the learner via the algorithm's ``load_weights`` API.""" bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) if bytes_state_dict is not None: logging.info("[ACTOR] Load new parameters from Learner.") state_dicts = bytes_to_state_dict(bytes_state_dict) - # TODO: check encoder parameter synchronization possible issues: # 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict # instead of the updated encoder params from critic (which is optimized separately) @@ -664,18 +710,8 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) # - Send critic's encoder state when shared_encoder=True # - Skip encoder params entirely when freeze_vision_encoder=True # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) - # Load actor state dict - actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device) - policy.actor.load_state_dict(actor_state_dict) - - # Load discrete critic if present - if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts: - discrete_critic_state_dict = move_state_dict_to_device( - state_dicts["discrete_critic"], device=device - ) - policy.discrete_critic.load_state_dict(discrete_critic_state_dict) - logging.info("[ACTOR] Loaded discrete critic parameters from Learner.") + algorithm.load_weights(state_dicts, device=device) # Utilities functions diff --git a/src/lerobot/rl/algorithms/__init__.py b/src/lerobot/rl/algorithms/__init__.py new file mode 100644 index 000000000..271d983c4 --- /dev/null +++ b/src/lerobot/rl/algorithms/__init__.py @@ -0,0 +1,67 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from lerobot.rl.algorithms.base import ( + RLAlgorithm, + RLAlgorithmConfig, + TrainingStats, +) +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig + + +def make_algorithm( + policy: torch.nn.Module, + policy_cfg, + *, + algorithm_name: str, +) -> RLAlgorithm: + """Construct an :class:`RLAlgorithm` from a policy and its config. + + Algorithm selection is explicit via ``algorithm_name`` (from + ``cfg.algorithm``). + + This is fully registry-driven — adding a new algorithm only requires + registering an ``RLAlgorithmConfig`` subclass; no changes here. + + The returned algorithm has **no optimizers** yet. On the learner side, + call ``algorithm.make_optimizers()`` afterwards to create them. On the + actor side (inference-only), leave them empty. + + Args: + policy: Instantiated policy (e.g. ``SACPolicy``). + policy_cfg: The policy's ``PreTrainedConfig`` with the hyper-parameters + expected by the algorithm config's ``from_policy_config`` class-method. + algorithm_name: Algorithm registry key to instantiate. + """ + known = RLAlgorithmConfig.get_known_choices() + if algorithm_name not in known: + raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}") + + config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name) + algo_config = config_cls.from_policy_config(policy_cfg) + return algo_config.build_algorithm(policy) + + +__all__ = [ + "RLAlgorithm", + "RLAlgorithmConfig", + "TrainingStats", + "SACAlgorithm", + "SACAlgorithmConfig", + "make_algorithm", +] diff --git a/src/lerobot/rl/algorithms/base.py b/src/lerobot/rl/algorithms/base.py new file mode 100644 index 000000000..839d2288f --- /dev/null +++ b/src/lerobot/rl/algorithms/base.py @@ -0,0 +1,163 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base classes for RL algorithms. + +Defines the abstract interface that every algorithm must implement, a registry +for algorithm configs, and a dataclass for training statistics. +""" + +from __future__ import annotations + +import abc +from collections.abc import Iterator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import draccus +import torch +from torch import Tensor +from torch.optim import Optimizer + +if TYPE_CHECKING: + from lerobot.rl.data_sources.data_mixer import DataMixer + +BatchType = dict[str, Any] + + +@dataclass +class TrainingStats: + """Returned by ``algorithm.update()`` for logging and checkpointing.""" + + # Generic containers for all algorithms + losses: dict[str, float] = field(default_factory=dict) + grad_norms: dict[str, float] = field(default_factory=dict) + extra: dict[str, float] = field(default_factory=dict) + + def to_log_dict(self) -> dict[str, float]: + """Flatten all stats into a single dict for logging.""" + + d: dict[str, float] = {} + for name, val in self.losses.items(): + d[name] = val + for name, val in self.grad_norms.items(): + d[f"{name}_grad_norm"] = val + for name, val in self.extra.items(): + d[name] = val + return d + + +@dataclass +class RLAlgorithmConfig(draccus.ChoiceRegistry): + """Registry for algorithm configs.""" + + def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm: + """Construct the :class:`RLAlgorithm` for this config. + + Must be overridden by every registered config subclass. + """ + raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()") + + @classmethod + def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig: + """Build an algorithm config from a policy config. + + Must be overridden by every registered config subclass. + """ + raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()") + + +class RLAlgorithm(abc.ABC): + """Base for all RL algorithms.""" + + @abc.abstractmethod + def select_action(self, observation: dict[str, Tensor]) -> Tensor: + """Select action(s) for rollout. + + Single-step policies (e.g. SAC) return shape ``(action_dim,)``; + chunking policies (e.g. VLA) return ``(chunk_size, action_dim)``. + """ + ... + + @abc.abstractmethod + def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + """One complete training step. + + The algorithm calls ``next(batch_iterator)`` as many times as it + needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches. + The iterator is owned by the trainer; the algorithm just consumes + from it. + """ + ... + + def configure_data_iterator( + self, + data_mixer: DataMixer, + batch_size: int, + *, + async_prefetch: bool = True, + queue_size: int = 2, + ) -> Iterator[BatchType]: + """Create the data iterator this algorithm needs. + + The default implementation uses the standard ``data_mixer.get_iterator()``. + Algorithms that need specialised sampling should override this method. + """ + return data_mixer.get_iterator( + batch_size=batch_size, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + def make_optimizers(self) -> dict[str, Optimizer]: + """Create, store, and return the optimizers needed for training. + + Called on the **learner** side after construction. Subclasses must + override this with algorithm-specific optimizer setup. + """ + return {} + + def get_optimizers(self) -> dict[str, Optimizer]: + """Return optimizers for checkpointing / external scheduling.""" + return {} + + @property + def optimization_step(self) -> int: + """Current learner optimization step. + + Part of the stable contract for checkpoint/resume. Algorithms can + either use this default storage or override for custom behavior. + """ + return getattr(self, "_optimization_step", 0) + + @optimization_step.setter + def optimization_step(self, value: int) -> None: + self._optimization_step = int(value) + + def get_weights(self) -> dict[str, Any]: + """State-dict(s) to push to actors.""" + return {} + + @abc.abstractmethod + def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: + """Load state-dict(s) received from the learner (inverse of ``get_weights``).""" + + @torch.no_grad() + def get_observation_features( + self, observations: Tensor, next_observations: Tensor + ) -> tuple[Tensor | None, Tensor | None]: + """Pre-compute observation features (e.g. frozen encoder cache). + + Returns ``(None, None)`` when caching is not applicable. + """ + return None, None diff --git a/src/lerobot/rl/algorithms/sac.py b/src/lerobot/rl/algorithms/sac.py new file mode 100644 index 000000000..c16ae48f6 --- /dev/null +++ b/src/lerobot/rl/algorithms/sac.py @@ -0,0 +1,262 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SAC (Soft Actor-Critic) algorithm. + +This module encapsulates all SAC-specific training logic (critic, actor, +temperature, and discrete-critic updates) behind the ``RLAlgorithm`` interface. +""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any + +import torch +from torch import Tensor +from torch.optim import Optimizer + +from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.rl.algorithms.base import ( + BatchType, + RLAlgorithm, + RLAlgorithmConfig, + TrainingStats, +) +from lerobot.utils.constants import ACTION +from lerobot.utils.transition import move_state_dict_to_device + + +@RLAlgorithmConfig.register_subclass("sac") +@dataclass +class SACAlgorithmConfig(RLAlgorithmConfig): + """SAC-specific hyper-parameters that control the update loop.""" + + utd_ratio: int = 1 + policy_update_freq: int = 1 + clip_grad_norm: float = 40.0 + actor_lr: float = 3e-4 + critic_lr: float = 3e-4 + + @classmethod + def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig: + """Build from an existing ``SACConfig`` (cfg.policy) for backwards compat.""" + return cls( + utd_ratio=policy_cfg.utd_ratio, + policy_update_freq=policy_cfg.policy_update_freq, + clip_grad_norm=policy_cfg.grad_clip_norm, + actor_lr=policy_cfg.actor_lr, + critic_lr=policy_cfg.critic_lr, + ) + + def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: + return SACAlgorithm(policy=policy, config=self) + + +class SACAlgorithm(RLAlgorithm): + """Soft Actor-Critic with optional discrete-critic head. + + Owns the ``SACPolicy`` and its optimizers. + """ + + def __init__( + self, + policy: SACPolicy, + config: SACAlgorithmConfig, + ): + self.policy = policy + self.config = config + self.optimizers: dict[str, Optimizer] = {} + self._optimization_step: int = 0 + + @torch.no_grad() + def select_action(self, observation: dict[str, Tensor]) -> Tensor: + return self.policy.select_action(observation).squeeze(0) + + def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + """Run one full SAC update with UTD critic warm-up. + + Pulls ``utd_ratio`` batches from ``batch_iterator``. The first + ``utd_ratio - 1`` batches are used for critic-only warm-up steps; + the last batch drives the full update (critic + actor + temperature). + """ + for _ in range(self.config.utd_ratio - 1): + batch = next(batch_iterator) + forward_batch = self._prepare_forward_batch(batch) + + critic_output = self.policy.forward(forward_batch, model="critic") + loss_critic = critic_output["loss_critic"] + self.optimizers["critic"].zero_grad() + loss_critic.backward() + torch.nn.utils.clip_grad_norm_( + self.policy.critic_ensemble.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["critic"].step() + + if self.policy.config.num_discrete_actions is not None: + discrete_critic_output = self.policy.forward(forward_batch, model="discrete_critic") + loss_discrete = discrete_critic_output["loss_discrete_critic"] + self.optimizers["discrete_critic"].zero_grad() + loss_discrete.backward() + torch.nn.utils.clip_grad_norm_( + self.policy.discrete_critic.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["discrete_critic"].step() + self.policy.update_target_networks() + + batch = next(batch_iterator) + forward_batch = self._prepare_forward_batch(batch) + + critic_output = self.policy.forward(forward_batch, model="critic") + loss_critic = critic_output["loss_critic"] + self.optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + self.policy.critic_ensemble.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["critic"].step() + + critic_loss_val = loss_critic.item() + stats = TrainingStats( + losses={"critic": critic_loss_val}, + grad_norms={"critic": critic_grad_norm}, + ) + + if self.policy.config.num_discrete_actions is not None: + discrete_critic_output = self.policy.forward(forward_batch, model="discrete_critic") + loss_discrete = discrete_critic_output["loss_discrete_critic"] + self.optimizers["discrete_critic"].zero_grad() + loss_discrete.backward() + dc_grad = torch.nn.utils.clip_grad_norm_( + self.policy.discrete_critic.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["discrete_critic"].step() + dc_loss_val = loss_discrete.item() + stats.losses["discrete_critic"] = dc_loss_val + stats.grad_norms["discrete_critic"] = dc_grad + + if self._optimization_step % self.config.policy_update_freq == 0: + for _ in range(self.config.policy_update_freq): + actor_output = self.policy.forward(forward_batch, model="actor") + actor_loss = actor_output["loss_actor"] + self.optimizers["actor"].zero_grad() + actor_loss.backward() + actor_grad = torch.nn.utils.clip_grad_norm_( + self.policy.actor.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["actor"].step() + + temperature_output = self.policy.forward(forward_batch, model="temperature") + temp_loss = temperature_output["loss_temperature"] + self.optimizers["temperature"].zero_grad() + temp_loss.backward() + temp_grad = torch.nn.utils.clip_grad_norm_( + [self.policy.log_alpha], + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["temperature"].step() + + actor_loss_val = actor_loss.item() + temp_loss_val = temp_loss.item() + stats.losses["actor"] = actor_loss_val + stats.losses["temperature"] = temp_loss_val + stats.grad_norms["actor"] = actor_grad + stats.grad_norms["temperature"] = temp_grad + stats.extra["temperature"] = self.policy.temperature + + self.policy.update_target_networks() + + self._optimization_step += 1 + return stats + + def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]: + """Build the dict expected by ``SACPolicy.forward()`` from a batch.""" + observations = batch["state"] + next_observations = batch["next_state"] + + observation_features, next_observation_features = self.get_observation_features( + observations, next_observations + ) + forward_batch: dict[str, Any] = { + ACTION: batch[ACTION], + "reward": batch["reward"], + "state": observations, + "next_state": next_observations, + "done": batch["done"], + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + } + if "complementary_info" in batch: + forward_batch["complementary_info"] = batch["complementary_info"] + return forward_batch + + def make_optimizers(self) -> dict[str, Optimizer]: + """Create Adam optimizers for the SAC components and store them.""" + actor_params = [ + p + for n, p in self.policy.actor.named_parameters() + if not self.policy.config.shared_encoder or not n.startswith("encoder") + ] + self.optimizers = { + "actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr), + "critic": torch.optim.Adam(self.policy.critic_ensemble.parameters(), lr=self.config.critic_lr), + "temperature": torch.optim.Adam([self.policy.log_alpha], lr=self.config.critic_lr), + } + if self.policy.config.num_discrete_actions is not None: + self.optimizers["discrete_critic"] = torch.optim.Adam( + self.policy.discrete_critic.parameters(), lr=self.config.critic_lr + ) + return self.optimizers + + def get_optimizers(self) -> dict[str, Optimizer]: + return self.optimizers + + def get_weights(self) -> dict[str, Any]: + """State-dicts to push to the actor process.""" + out: dict[str, Any] = { + "policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"), + } + if hasattr(self.policy, "discrete_critic") and self.policy.discrete_critic is not None: + out["discrete_critic"] = move_state_dict_to_device( + self.policy.discrete_critic.state_dict(), device="cpu" + ) + return out + + def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: + """Load state-dict(s) received from the learner (inverse of ``get_weights``).""" + if "policy" in weights: + actor_state = move_state_dict_to_device(weights["policy"], device=device) + self.policy.actor.load_state_dict(actor_state) + if ( + "discrete_critic" in weights + and hasattr(self.policy, "discrete_critic") + and self.policy.discrete_critic is not None + ): + dc_state = move_state_dict_to_device(weights["discrete_critic"], device=device) + self.policy.discrete_critic.load_state_dict(dc_state) + + @torch.no_grad() + def get_observation_features( + self, observations: Tensor, next_observations: Tensor + ) -> tuple[Tensor | None, Tensor | None]: + if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder: + return None, None + observation_features = self.policy.actor.encoder.get_cached_image_features(observations) + next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations) + return observation_features, next_observation_features diff --git a/src/lerobot/rl/data_sources/__init__.py b/src/lerobot/rl/data_sources/__init__.py new file mode 100644 index 000000000..4ac97ec1b --- /dev/null +++ b/src/lerobot/rl/data_sources/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer + +__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"] diff --git a/src/lerobot/rl/data_sources/data_mixer.py b/src/lerobot/rl/data_sources/data_mixer.py new file mode 100644 index 000000000..01c9055be --- /dev/null +++ b/src/lerobot/rl/data_sources/data_mixer.py @@ -0,0 +1,94 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +from typing import Any + +from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions + +BatchType = dict[str, Any] + + +class DataMixer(abc.ABC): + """Abstract interface for all data mixing strategies. + + Subclasses must implement ``sample(batch_size)`` and may override + ``get_iterator`` for specialised iteration. + """ + + @abc.abstractmethod + def sample(self, batch_size: int) -> BatchType: + """Draw one batch of ``batch_size`` transitions.""" + ... + + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """Infinite iterator that yields batches. + + The default implementation repeatedly calls ``self.sample()``. + Subclasses with underlying buffer iterators (async prefetch) + should override this for better throughput. + """ + while True: + yield self.sample(batch_size) + + +class OnlineOfflineMixer(DataMixer): + """Mixes transitions from an online and an optional offline replay buffer. + + When both buffers are present, each batch is constructed by sampling + ``ceil(batch_size * online_ratio)`` from the online buffer and the + remainder from the offline buffer, then concatenating. + + This mixer assumes both online and offline buffers are present. + """ + + def __init__( + self, + online_buffer: ReplayBuffer, + offline_buffer: ReplayBuffer | None = None, + online_ratio: float = 1.0, + ): + if not 0.0 <= online_ratio <= 1.0: + raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}") + self.online_buffer = online_buffer + self.offline_buffer = offline_buffer + self.online_ratio = online_ratio + + def sample(self, batch_size: int) -> BatchType: + if self.offline_buffer is None: + return self.online_buffer.sample(batch_size) + + n_online = max(1, int(batch_size * self.online_ratio)) + n_offline = batch_size - n_online + + online_batch = self.online_buffer.sample(n_online) + offline_batch = self.offline_buffer.sample(n_offline) + return concatenate_batch_transitions(online_batch, offline_batch) + + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """Yield batches from online/offline mixed sampling.""" + while True: + yield self.sample(batch_size) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index ee09ac9ac..ea19bb2fa 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -64,10 +64,12 @@ from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.factory import make_policy -from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.policies.factory import make_policy, make_pre_post_processors +from lerobot.rl.algorithms import make_algorithm +from lerobot.rl.buffer import ReplayBuffer +from lerobot.rl.data_sources import OnlineOfflineMixer from lerobot.rl.process import ProcessSignalHandler +from lerobot.rl.trainer import RLTrainer from lerobot.rl.wandb_utils import WandBLogger from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 @@ -93,7 +95,7 @@ from lerobot.utils.train_utils import ( save_checkpoint, update_last_checkpoint, ) -from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device +from lerobot.utils.transition import move_transition_to_device from lerobot.utils.utils import ( format_big_number, get_safe_torch_device, @@ -264,8 +266,8 @@ def add_actor_information_and_train( - Transfers transitions from the actor to the replay buffer. - Logs received interaction messages. - Ensures training begins only when the replay buffer has a sufficient number of transitions. - - Samples batches from the replay buffer and performs multiple critic updates. - - Periodically updates the actor, critic, and temperature optimizers. + - Delegates training updates to an ``RLAlgorithm`` (currently ``SACAlgorithm``). + - Periodically pushes updated weights to actors. - Logs training statistics, including loss values and optimization frequency. NOTE: This function doesn't have a single responsibility, it should be split into multiple functions @@ -284,17 +286,15 @@ def add_actor_information_and_train( # of 7% device = get_safe_torch_device(try_device=cfg.policy.device, log=True) storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device) - clip_grad_norm_value = cfg.policy.grad_clip_norm online_step_before_learning = cfg.policy.online_step_before_learning - utd_ratio = cfg.policy.utd_ratio fps = cfg.env.fps log_freq = cfg.log_freq save_freq = cfg.save_freq - policy_update_freq = cfg.policy.policy_update_freq policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency saving_checkpoint = cfg.save_checkpoint online_steps = cfg.policy.online_steps - async_prefetch = cfg.policy.async_prefetch + async_prefetch = cfg.async_prefetch + queue_size = cfg.queue_size # Initialize logging for multiprocessing if not use_threads(cfg): @@ -306,7 +306,7 @@ def add_actor_information_and_train( logging.info("Initializing policy") - policy: SACPolicy = make_policy( + policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, ) @@ -315,19 +315,51 @@ def add_actor_information_and_train( policy.train() - push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + algorithm = make_algorithm( + policy=policy, + policy_cfg=cfg.policy, + algorithm_name=cfg.algorithm, + ) + # Build policy preprocessor for batch normalization during training + processor_kwargs = {} + postprocessor_kwargs = {} + if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: + processor_kwargs["dataset_stats"] = cfg.policy.dataset_stats + + if cfg.policy.pretrained_path is not None: + processor_kwargs["preprocessor_overrides"] = { + "device_processor": {"device": device.type}, + "normalizer_processor": { + "stats": cfg.policy.dataset_stats, + "features": {**policy.config.input_features, **policy.config.output_features}, + "norm_map": policy.config.normalization_mapping, + }, + } + postprocessor_kwargs["postprocessor_overrides"] = { + "unnormalizer_processor": { + "stats": cfg.policy.dataset_stats, + "features": policy.config.output_features, + "norm_map": policy.config.normalization_mapping, + }, + } + + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + **processor_kwargs, + **postprocessor_kwargs, + ) + + # Push initial policy weights to actors (same path as periodic push) + state_bytes = state_to_bytes(algorithm.get_weights()) + parameters_queue.put(state_bytes) last_time_policy_pushed = time.time() - optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) - - # If we are resuming, we need to load the training state - resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) - log_training_info(cfg=cfg, policy=policy) replay_buffer = initialize_replay_buffer(cfg, device, storage_device) - batch_size = cfg.batch_size + total_batch_size = cfg.batch_size offline_replay_buffer = None if cfg.dataset is not None: @@ -336,21 +368,38 @@ def add_actor_information_and_train( device=device, storage_device=storage_device, ) - batch_size: int = batch_size // 2 # We will sample from both replay buffer + + # DataMixer: online-only or online/offline 50-50 mix + data_mixer = OnlineOfflineMixer( + online_buffer=replay_buffer, + offline_buffer=offline_replay_buffer, + online_ratio=cfg.online_ratio, + ) + # RLTrainer owns the iterator, preprocessor, and creates optimizers. + trainer = RLTrainer( + algorithm=algorithm, + data_mixer=data_mixer, + batch_size=total_batch_size, + preprocessor=preprocessor, + action_dim=cfg.policy.output_features["action"].shape[0], + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + # If we are resuming, we need to load the training state + optimizers = algorithm.get_optimizers() + resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) logging.info("Starting learner thread") interaction_message = None optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 + algorithm.optimization_step = optimization_step interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 dataset_repo_id = None if cfg.dataset is not None: dataset_repo_id = cfg.dataset.repo_id - # Initialize iterators - online_iterator = None - offline_iterator = None - # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER while True: # Exit the training loop if shutdown is requested @@ -380,180 +429,22 @@ def add_actor_information_and_train( if len(replay_buffer) < online_step_before_learning: continue - if online_iterator is None: - online_iterator = replay_buffer.get_iterator( - batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 - ) - - if offline_replay_buffer is not None and offline_iterator is None: - offline_iterator = offline_replay_buffer.get_iterator( - batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 - ) - time_for_one_optimization_step = time.time() - for _ in range(utd_ratio - 1): - # Sample from the iterators - batch = next(online_iterator) - if dataset_repo_id is not None: - batch_offline = next(offline_iterator) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - - actions = batch[ACTION] - rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] - done = batch["done"] - check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) - - observation_features, next_observation_features = get_observation_features( - policy=policy, observations=observations, next_observations=next_observations - ) - - # Create a batch dictionary with all required elements for the forward method - forward_batch = { - ACTION: actions, - "reward": rewards, - "state": observations, - "next_state": next_observations, - "done": done, - "observation_feature": observation_features, - "next_observation_feature": next_observation_features, - "complementary_info": batch["complementary_info"], - } - - # Use the forward method for critic loss - critic_output = policy.forward(forward_batch, model="critic") - - # Main critic optimization - loss_critic = critic_output["loss_critic"] - optimizers["critic"].zero_grad() - loss_critic.backward() - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value - ) - optimizers["critic"].step() - - # Discrete critic optimization (if available) - if policy.config.num_discrete_actions is not None: - discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") - loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] - optimizers["discrete_critic"].zero_grad() - loss_discrete_critic.backward() - discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value - ) - optimizers["discrete_critic"].step() - - # Update target networks (main and discrete) - policy.update_target_networks() - - # Sample for the last update in the UTD ratio - batch = next(online_iterator) - - if dataset_repo_id is not None: - batch_offline = next(offline_iterator) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - - actions = batch[ACTION] - rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] - done = batch["done"] - - check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) - - observation_features, next_observation_features = get_observation_features( - policy=policy, observations=observations, next_observations=next_observations - ) - - # Create a batch dictionary with all required elements for the forward method - forward_batch = { - ACTION: actions, - "reward": rewards, - "state": observations, - "next_state": next_observations, - "done": done, - "observation_feature": observation_features, - "next_observation_feature": next_observation_features, - } - - critic_output = policy.forward(forward_batch, model="critic") - - loss_critic = critic_output["loss_critic"] - optimizers["critic"].zero_grad() - loss_critic.backward() - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["critic"].step() - - # Initialize training info dictionary - training_infos = { - "loss_critic": loss_critic.item(), - "critic_grad_norm": critic_grad_norm, - } - - # Discrete critic optimization (if available) - if policy.config.num_discrete_actions is not None: - discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") - loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] - optimizers["discrete_critic"].zero_grad() - loss_discrete_critic.backward() - discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["discrete_critic"].step() - - # Add discrete critic info to training info - training_infos["loss_discrete_critic"] = loss_discrete_critic.item() - training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm - - # Actor and temperature optimization (at specified frequency) - if optimization_step % policy_update_freq == 0: - for _ in range(policy_update_freq): - # Actor optimization - actor_output = policy.forward(forward_batch, model="actor") - loss_actor = actor_output["loss_actor"] - optimizers["actor"].zero_grad() - loss_actor.backward() - actor_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["actor"].step() - - # Add actor info to training info - training_infos["loss_actor"] = loss_actor.item() - training_infos["actor_grad_norm"] = actor_grad_norm - - # Temperature optimization - temperature_output = policy.forward(forward_batch, model="temperature") - loss_temperature = temperature_output["loss_temperature"] - optimizers["temperature"].zero_grad() - loss_temperature.backward() - temp_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=[policy.log_alpha], max_norm=clip_grad_norm_value - ).item() - optimizers["temperature"].step() - - # Add temperature info to training info - training_infos["loss_temperature"] = loss_temperature.item() - training_infos["temperature_grad_norm"] = temp_grad_norm - training_infos["temperature"] = policy.temperature + # One training step (trainer owns data_mixer iterator; algorithm owns UTD loop) + stats = trainer.training_step() # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: - push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + state_dicts = algorithm.get_weights() + state_bytes = state_to_bytes(state_dicts) + parameters_queue.put(state_bytes) last_time_policy_pushed = time.time() - # Update target networks (main and discrete) - policy.update_target_networks() + training_infos = stats.to_log_dict() # Log training metrics at specified intervals + optimization_step = algorithm.optimization_step if optimization_step % log_freq == 0: training_infos["replay_buffer_size"] = len(replay_buffer) if offline_replay_buffer is not None: @@ -581,7 +472,6 @@ def add_actor_information_and_train( custom_step_key="Optimization step", ) - optimization_step += 1 if optimization_step % log_freq == 0: logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") @@ -598,6 +488,8 @@ def add_actor_information_and_train( offline_replay_buffer=offline_replay_buffer, dataset_repo_id=dataset_repo_id, fps=fps, + preprocessor=preprocessor, + postprocessor=postprocessor, ) @@ -682,6 +574,8 @@ def save_training_checkpoint( offline_replay_buffer: ReplayBuffer | None = None, dataset_repo_id: str | None = None, fps: int = 30, + preprocessor=None, + postprocessor=None, ) -> None: """ Save training checkpoint and associated data. @@ -705,6 +599,8 @@ def save_training_checkpoint( offline_replay_buffer: Optional offline replay buffer to save dataset_repo_id: Repository ID for dataset fps: Frames per second for dataset + preprocessor: Optional preprocessor pipeline to save + postprocessor: Optional postprocessor pipeline to save """ logging.info(f"Checkpoint policy after step {optimization_step}") _num_digits = max(6, len(str(online_steps))) @@ -721,6 +617,8 @@ def save_training_checkpoint( policy=policy, optimizer=optimizers, scheduler=None, + preprocessor=preprocessor, + postprocessor=postprocessor, ) # Save interaction step manually @@ -758,58 +656,6 @@ def save_training_checkpoint( logging.info("Resume training") -def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module): - """ - Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. - - This function sets up Adam optimizers for: - - The **actor network**, ensuring that only relevant parameters are optimized. - - The **critic ensemble**, which evaluates the value function. - - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods. - - It also initializes a learning rate scheduler, though currently, it is set to `None`. - - NOTE: - - If the encoder is shared, its parameters are excluded from the actor's optimization process. - - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor. - - Args: - cfg: Configuration object containing hyperparameters. - policy (nn.Module): The policy model containing the actor, critic, and temperature components. - - Returns: - Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]: - A tuple containing: - - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers. - - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling. - - """ - optimizer_actor = torch.optim.Adam( - params=[ - p - for n, p in policy.actor.named_parameters() - if not policy.config.shared_encoder or not n.startswith("encoder") - ], - lr=cfg.policy.actor_lr, - ) - optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - - if cfg.policy.num_discrete_actions is not None: - optimizer_discrete_critic = torch.optim.Adam( - params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr - ) - optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) - lr_scheduler = None - optimizers = { - "actor": optimizer_actor, - "critic": optimizer_critic, - "temperature": optimizer_temperature, - } - if cfg.policy.num_discrete_actions is not None: - optimizers["discrete_critic"] = optimizer_discrete_critic - return optimizers, lr_scheduler - - # Training setup functions @@ -1014,33 +860,6 @@ def initialize_offline_replay_buffer( # Utilities/Helpers functions -def get_observation_features( - policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor -) -> tuple[torch.Tensor | None, torch.Tensor | None]: - """ - Get observation features from the policy encoder. It act as cache for the observation features. - when the encoder is frozen, the observation features are not updated. - We can save compute by caching the observation features. - - Args: - policy: The policy model - observations: The current observations - next_observations: The next observations - - Returns: - tuple: observation_features, next_observation_features - """ - - if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: - return None, None - - with torch.no_grad(): - observation_features = policy.actor.encoder.get_cached_image_features(observations) - next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations) - - return observation_features, next_observation_features - - def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: return cfg.policy.concurrency.learner == "threads" @@ -1091,23 +910,6 @@ def check_nan_in_transition( return nan_detected -def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): - logging.debug("[LEARNER] Pushing actor policy to the queue") - - # Create a dictionary to hold all the state dicts - state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")} - - # Add discrete critic if it exists - if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None: - state_dicts["discrete_critic"] = move_state_dict_to_device( - policy.discrete_critic.state_dict(), device="cpu" - ) - logging.debug("[LEARNER] Including discrete critic in state dict push") - - state_bytes = state_to_bytes(state_dicts) - parameters_queue.put(state_bytes) - - def process_interaction_message( message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None ): diff --git a/src/lerobot/rl/trainer.py b/src/lerobot/rl/trainer.py new file mode 100644 index 000000000..ba6f84cda --- /dev/null +++ b/src/lerobot/rl/trainer.py @@ -0,0 +1,132 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +import torch + +from lerobot.rl.algorithms.base import ( + BatchType, + RLAlgorithm, + TrainingStats, +) +from lerobot.rl.data_sources.data_mixer import DataMixer +from lerobot.utils.constants import ACTION + + +def preprocess_rl_batch(preprocessor: Any, batch: BatchType, *, action_dim: int | None = None) -> BatchType: + """Apply a policy preprocessor to an RL batch.""" + observations = batch["state"] + next_observations = batch["next_state"] + actions = batch[ACTION] + + extra_action = None + if action_dim is not None and actions.shape[-1] > action_dim: + extra_action = actions[..., action_dim:] + actions = actions[..., :action_dim] + + obs_action = {**observations, ACTION: actions} + obs_action = preprocessor(obs_action) + batch["state"] = {k: v for k, v in obs_action.items() if k.startswith("observation.")} + batch[ACTION] = obs_action[ACTION] + + if extra_action is not None: + batch[ACTION] = torch.cat([batch[ACTION], extra_action], dim=-1) + + next_obs = {**next_observations} + next_obs = preprocessor(next_obs) + batch["next_state"] = {k: v for k, v in next_obs.items() if k.startswith("observation.")} + + return batch + + +class _PreprocessedIterator: + """Iterator wrapper that preprocesses each sampled RL batch.""" + + __slots__ = ("_raw", "_preprocessor", "_action_dim") + + def __init__( + self, raw_iterator: Iterator[BatchType], preprocessor: Any, action_dim: int | None = None + ) -> None: + self._raw = raw_iterator + self._preprocessor = preprocessor + self._action_dim = action_dim + + def __iter__(self) -> _PreprocessedIterator: + return self + + def __next__(self) -> BatchType: + batch = next(self._raw) + return preprocess_rl_batch(self._preprocessor, batch, action_dim=self._action_dim) + + +class RLTrainer: + """Unified training step orchestrator. + + Holds the algorithm, a DataMixer, and an optional preprocessor. + """ + + def __init__( + self, + algorithm: RLAlgorithm, + data_mixer: DataMixer, + batch_size: int, + *, + preprocessor: Any | None = None, + action_dim: int | None = None, + async_prefetch: bool = True, + queue_size: int = 2, + ): + self.algorithm = algorithm + self.data_mixer = data_mixer + self.batch_size = batch_size + self._preprocessor = preprocessor + self._action_dim = action_dim + self.async_prefetch = async_prefetch + self.queue_size = queue_size + + self._iterator: Iterator[BatchType] | None = None + + self.algorithm.make_optimizers() + + def _build_data_iterator(self) -> Iterator[BatchType]: + """Create a fresh algorithm-configured iterator (optionally preprocessed).""" + raw = self.algorithm.configure_data_iterator( + data_mixer=self.data_mixer, + batch_size=self.batch_size, + async_prefetch=self.async_prefetch, + queue_size=self.queue_size, + ) + if self._preprocessor is not None: + return _PreprocessedIterator(raw, self._preprocessor, self._action_dim) + return raw + + def reset_data_iterator(self) -> None: + """Discard the current iterator so it will be rebuilt lazily next step.""" + self._iterator = None + + def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None: + """Swap the active data mixer, optionally resetting the iterator.""" + self.data_mixer = data_mixer + if reset: + self.reset_data_iterator() + + def training_step(self) -> TrainingStats: + """Run one training step (algorithm-agnostic).""" + if self._iterator is None: + self._iterator = self._build_data_iterator() + return self.algorithm.update(self._iterator) diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index e13862d82..7c4dd25e7 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -23,8 +23,9 @@ import torch from torch.multiprocessing import Event, Queue from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR from lerobot.utils.transition import Transition from tests.utils import require_package @@ -296,3 +297,174 @@ def test_end_to_end_parameters_flow(cfg, data_size): assert received_params.keys() == input_params.keys() for key in input_params: assert torch.allclose(received_params[key], input_params[key]) + + +# --------------------------------------------------------------------------- +# Regression test: learner algorithm integration (no gRPC required) +# --------------------------------------------------------------------------- + + +def test_learner_algorithm_wiring(): + """Verify that make_algorithm constructs an SACAlgorithm from config, + make_optimizers() creates the right optimizers, update() works, and + get_weights() output is serializable.""" + from lerobot.policies.sac.modeling_sac import SACPolicy + from lerobot.rl.algorithms import make_algorithm + from lerobot.rl.algorithms.sac import SACAlgorithm + from lerobot.transport.utils import state_to_bytes + + state_dim = 10 + action_dim = 6 + + sac_cfg = SACConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + use_torch_compile=False, + ) + sac_cfg.validate_features() + + policy = SACPolicy(config=sac_cfg) + policy.train() + + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) + + optimizers = algorithm.make_optimizers() + assert "actor" in optimizers + assert "critic" in optimizers + assert "temperature" in optimizers + + batch_size = 4 + + def batch_iterator(): + while True: + yield { + ACTION: torch.randn(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": {OBS_STATE: torch.randn(batch_size, state_dim)}, + "next_state": {OBS_STATE: torch.randn(batch_size, state_dim)}, + "done": torch.zeros(batch_size), + "complementary_info": {}, + } + + stats = algorithm.update(batch_iterator()) + assert "critic" in stats.losses + + # get_weights -> state_to_bytes round-trip + weights = algorithm.get_weights() + assert "policy" in weights + serialized = state_to_bytes(weights) + assert isinstance(serialized, bytes) + assert len(serialized) > 0 + + # RLTrainer with DataMixer + from lerobot.rl.buffer import ReplayBuffer + from lerobot.rl.data_sources import OnlineOfflineMixer + from lerobot.rl.trainer import RLTrainer + + replay_buffer = ReplayBuffer( + capacity=50, + device="cpu", + state_keys=[OBS_STATE], + storage_device="cpu", + use_drq=False, + ) + for _ in range(50): + replay_buffer.add( + state={OBS_STATE: torch.randn(state_dim)}, + action=torch.randn(action_dim), + reward=1.0, + next_state={OBS_STATE: torch.randn(state_dim)}, + done=False, + truncated=False, + ) + data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None) + trainer = RLTrainer( + algorithm=algorithm, + data_mixer=data_mixer, + batch_size=batch_size, + async_prefetch=False, + ) + trainer_stats = trainer.training_step() + assert "critic" in trainer_stats.losses + + +def test_initial_and_periodic_weight_push_consistency(): + """Both initial and periodic weight pushes should use algorithm.get_weights() + and produce identical structures.""" + from lerobot.policies.sac.modeling_sac import SACPolicy + from lerobot.rl.algorithms import make_algorithm + from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes + + state_dim = 10 + action_dim = 6 + sac_cfg = SACConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + use_torch_compile=False, + ) + sac_cfg.validate_features() + + policy = SACPolicy(config=sac_cfg) + policy.train() + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + algorithm.make_optimizers() + + # Simulate initial push (same code path the learner now uses) + initial_weights = algorithm.get_weights() + initial_bytes = state_to_bytes(initial_weights) + + # Simulate periodic push + periodic_weights = algorithm.get_weights() + periodic_bytes = state_to_bytes(periodic_weights) + + initial_decoded = bytes_to_state_dict(initial_bytes) + periodic_decoded = bytes_to_state_dict(periodic_bytes) + + assert initial_decoded.keys() == periodic_decoded.keys() + for key in initial_decoded: + assert initial_decoded[key].keys() == periodic_decoded[key].keys() + + +def test_actor_side_algorithm_select_action_and_load_weights(): + """Simulate actor: create algorithm without optimizers, select_action, load_weights.""" + from lerobot.policies.sac.modeling_sac import SACPolicy + from lerobot.rl.algorithms import make_algorithm + from lerobot.rl.algorithms.sac import SACAlgorithm + + state_dim = 10 + action_dim = 6 + sac_cfg = SACConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + use_torch_compile=False, + ) + sac_cfg.validate_features() + + # Actor side: no optimizers + policy = SACPolicy(config=sac_cfg) + policy.eval() + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + # select_action should work + obs = {OBS_STATE: torch.randn(state_dim)} + action = algorithm.select_action(obs) + assert action.shape == (action_dim,) + + # Simulate receiving weights from learner + fake_weights = algorithm.get_weights() + algorithm.load_weights(fake_weights, device="cpu") diff --git a/tests/rl/test_data_mixer.py b/tests/rl/test_data_mixer.py new file mode 100644 index 000000000..90e9e492f --- /dev/null +++ b/tests/rl/test_data_mixer.py @@ -0,0 +1,85 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer).""" + +import torch + +from lerobot.rl.buffer import ReplayBuffer +from lerobot.rl.data_sources import OnlineOfflineMixer +from lerobot.utils.constants import OBS_STATE + + +def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer: + buf = ReplayBuffer( + capacity=capacity, + device="cpu", + state_keys=[OBS_STATE], + storage_device="cpu", + use_drq=False, + ) + for i in range(capacity): + buf.add( + state={OBS_STATE: torch.randn(state_dim)}, + action=torch.randn(2), + reward=1.0, + next_state={OBS_STATE: torch.randn(state_dim)}, + done=bool(i % 10 == 9), + truncated=False, + ) + return buf + + +def test_online_only_mixer_sample(): + """OnlineOfflineMixer with no offline buffer returns online-only batches.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5) + batch = mixer.sample(batch_size=8) + assert batch["state"][OBS_STATE].shape[0] == 8 + assert batch["action"].shape[0] == 8 + assert batch["reward"].shape[0] == 8 + + +def test_online_only_mixer_ratio_one(): + """OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0) + batch = mixer.sample(batch_size=10) + assert batch["state"][OBS_STATE].shape[0] == 10 + + +def test_online_offline_mixer_sample(): + """OnlineOfflineMixer with two buffers returns concatenated batches.""" + online = _make_buffer(capacity=50) + offline = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer( + online_buffer=online, + offline_buffer=offline, + online_ratio=0.5, + ) + batch = mixer.sample(batch_size=10) + assert batch["state"][OBS_STATE].shape[0] == 10 + assert batch["action"].shape[0] == 10 + # 5 from online, 5 from offline (approx) + assert batch["reward"].shape[0] == 10 + + +def test_online_offline_mixer_iterator(): + """get_iterator yields batches of the requested size.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None) + it = mixer.get_iterator(batch_size=4, async_prefetch=False) + batch1 = next(it) + batch2 = next(it) + assert batch1["state"][OBS_STATE].shape[0] == 4 + assert batch2["state"][OBS_STATE].shape[0] == 4 diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py new file mode 100644 index 000000000..63a5eb520 --- /dev/null +++ b/tests/rl/test_sac_algorithm.py @@ -0,0 +1,474 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the RL algorithm abstraction and SACAlgorithm implementation.""" + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.rl.algorithms import make_algorithm +from lerobot.rl.algorithms.base import RLAlgorithmConfig, TrainingStats +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.utils.random_utils import set_seed + +# --------------------------------------------------------------------------- +# Helpers (reuse patterns from tests/policies/test_sac_policy.py) +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def set_random_seed(): + set_seed(42) + + +def _make_sac_config( + state_dim: int = 10, + action_dim: int = 6, + num_discrete_actions: int | None = None, + utd_ratio: int = 1, + policy_update_freq: int = 1, + with_images: bool = False, +) -> SACConfig: + config = SACConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + utd_ratio=utd_ratio, + policy_update_freq=policy_update_freq, + num_discrete_actions=num_discrete_actions, + use_torch_compile=False, + ) + if with_images: + config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats[OBS_IMAGE] = { + "mean": torch.randn(3, 1, 1).tolist(), + "std": torch.randn(3, 1, 1).abs().tolist(), + } + config.latent_dim = 32 + config.state_encoder_hidden_dim = 32 + config.validate_features() + return config + + +def _make_algorithm( + state_dim: int = 10, + action_dim: int = 6, + utd_ratio: int = 1, + policy_update_freq: int = 1, + num_discrete_actions: int | None = None, + with_images: bool = False, +) -> tuple[SACAlgorithm, SACPolicy]: + sac_cfg = _make_sac_config( + state_dim=state_dim, + action_dim=action_dim, + utd_ratio=utd_ratio, + policy_update_freq=policy_update_freq, + num_discrete_actions=num_discrete_actions, + with_images=with_images, + ) + policy = SACPolicy(config=sac_cfg) + policy.train() + algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm.make_optimizers() + return algorithm, policy + + +def _make_batch( + batch_size: int = 4, + state_dim: int = 10, + action_dim: int = 6, + with_images: bool = False, +) -> dict: + obs = {OBS_STATE: torch.randn(batch_size, state_dim)} + next_obs = {OBS_STATE: torch.randn(batch_size, state_dim)} + if with_images: + obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84) + next_obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84) + return { + ACTION: torch.randn(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": obs, + "next_state": next_obs, + "done": torch.zeros(batch_size), + "complementary_info": {}, + } + + +def _batch_iterator(**batch_kwargs): + """Infinite iterator that yields fresh batches (mirrors a real DataMixer iterator).""" + while True: + yield _make_batch(**batch_kwargs) + + +# =========================================================================== +# Registry / config tests +# =========================================================================== + + +def test_sac_algorithm_config_registered(): + """SACAlgorithmConfig should be discoverable through the registry.""" + assert "sac" in RLAlgorithmConfig.get_known_choices() + cls = RLAlgorithmConfig.get_choice_class("sac") + assert cls is SACAlgorithmConfig + + +def test_sac_algorithm_config_from_policy_config(): + """from_policy_config should copy relevant fields.""" + sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2) + algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg) + assert algo_cfg.utd_ratio == 4 + assert algo_cfg.policy_update_freq == 2 + assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm + + +# =========================================================================== +# TrainingStats tests +# =========================================================================== + + +def test_training_stats_defaults(): + stats = TrainingStats() + assert stats.losses == {} + assert stats.grad_norms == {} + assert stats.extra == {} + + +# =========================================================================== +# get_weights +# =========================================================================== + + +def test_get_weights_returns_actor_state_dict(): + algorithm, policy = _make_algorithm() + weights = algorithm.get_weights() + assert "policy" in weights + for key in policy.actor.state_dict(): + assert key in weights["policy"] + assert torch.equal(weights["policy"][key].cpu(), policy.actor.state_dict()[key].cpu()) + + +def test_get_weights_includes_discrete_critic_when_present(): + algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6) + weights = algorithm.get_weights() + assert "discrete_critic" in weights + for key in policy.discrete_critic.state_dict(): + assert key in weights["discrete_critic"] + + +def test_get_weights_excludes_discrete_critic_when_absent(): + algorithm, _ = _make_algorithm() + weights = algorithm.get_weights() + assert "discrete_critic" not in weights + + +def test_get_weights_are_on_cpu(): + algorithm, _ = _make_algorithm() + weights = algorithm.get_weights() + for key, tensor in weights["policy"].items(): + assert tensor.device == torch.device("cpu"), f"{key} is not on CPU" + + +# =========================================================================== +# select_action +# =========================================================================== + + +def test_select_action_returns_correct_shape(): + action_dim = 6 + algorithm, _ = _make_algorithm(state_dim=10, action_dim=action_dim) + obs = {OBS_STATE: torch.randn(10)} + action = algorithm.select_action(obs) + assert action.shape == (action_dim,) + + +def test_select_action_with_discrete_critic(): + continuous_dim = 5 + algorithm, _ = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3) + obs = {OBS_STATE: torch.randn(10)} + action = algorithm.select_action(obs) + assert action.shape == (continuous_dim + 1,) + + +# =========================================================================== +# update (single batch, utd_ratio=1) +# =========================================================================== + + +def test_update_returns_training_stats(): + algorithm, _ = _make_algorithm() + stats = algorithm.update(_batch_iterator()) + assert isinstance(stats, TrainingStats) + assert "critic" in stats.losses + assert isinstance(stats.losses["critic"], float) + + +def test_update_populates_actor_and_temperature_losses(): + """With policy_update_freq=1 and step 0, actor/temperature should be updated.""" + algorithm, _ = _make_algorithm(policy_update_freq=1) + stats = algorithm.update(_batch_iterator()) + assert "actor" in stats.losses + assert "temperature" in stats.losses + assert "temperature" in stats.extra + + +@pytest.mark.parametrize("policy_update_freq", [2, 3]) +def test_update_skips_actor_at_non_update_steps(policy_update_freq): + """Actor/temperature should only update when optimization_step % freq == 0.""" + algorithm, _ = _make_algorithm(policy_update_freq=policy_update_freq) + it = _batch_iterator() + + # Step 0: should update actor + stats_0 = algorithm.update(it) + assert "actor" in stats_0.losses + + # Step 1: should NOT update actor + stats_1 = algorithm.update(it) + assert "actor" not in stats_1.losses + + +def test_update_increments_optimization_step(): + algorithm, _ = _make_algorithm() + it = _batch_iterator() + assert algorithm.optimization_step == 0 + algorithm.update(it) + assert algorithm.optimization_step == 1 + algorithm.update(it) + assert algorithm.optimization_step == 2 + + +def test_update_with_discrete_critic(): + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete + assert "discrete_critic" in stats.losses + assert "discrete_critic" in stats.grad_norms + + +# =========================================================================== +# update with UTD ratio > 1 +# =========================================================================== + + +@pytest.mark.parametrize("utd_ratio", [2, 4]) +def test_update_with_utd_ratio(utd_ratio): + algorithm, _ = _make_algorithm(utd_ratio=utd_ratio) + stats = algorithm.update(_batch_iterator()) + assert isinstance(stats, TrainingStats) + assert "critic" in stats.losses + assert algorithm.optimization_step == 1 + + +def test_update_utd_ratio_pulls_utd_batches(): + """next(batch_iterator) should be called exactly utd_ratio times.""" + utd_ratio = 3 + algorithm, _ = _make_algorithm(utd_ratio=utd_ratio) + + call_count = 0 + + def counting_iterator(): + nonlocal call_count + while True: + call_count += 1 + yield _make_batch() + + algorithm.update(counting_iterator()) + assert call_count == utd_ratio + + +def test_update_utd_ratio_3_critic_warmup_changes_weights(): + """With utd_ratio=3, critic weights should change after update (3 critic steps).""" + algorithm, policy = _make_algorithm(utd_ratio=3) + + critic_params_before = {n: p.clone() for n, p in policy.critic_ensemble.named_parameters()} + + algorithm.update(_batch_iterator()) + + changed = False + for n, p in policy.critic_ensemble.named_parameters(): + if not torch.equal(p, critic_params_before[n]): + changed = True + break + assert changed, "Critic weights should have changed after UTD update" + + +# =========================================================================== +# get_observation_features +# =========================================================================== + + +def test_get_observation_features_returns_none_without_frozen_encoder(): + algorithm, _ = _make_algorithm(with_images=False) + obs = {OBS_STATE: torch.randn(4, 10)} + next_obs = {OBS_STATE: torch.randn(4, 10)} + feat, next_feat = algorithm.get_observation_features(obs, next_obs) + assert feat is None + assert next_feat is None + + +# =========================================================================== +# optimization_step setter +# =========================================================================== + + +def test_optimization_step_can_be_set_for_resume(): + algorithm, _ = _make_algorithm() + algorithm.optimization_step = 100 + assert algorithm.optimization_step == 100 + + +# =========================================================================== +# make_algorithm factory +# =========================================================================== + + +def test_make_algorithm_returns_sac_for_sac_policy(): + sac_cfg = _make_sac_config() + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + +def test_make_optimizers_creates_expected_keys(): + """make_optimizers() should populate the algorithm with Adam optimizers.""" + sac_cfg = _make_sac_config() + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + optimizers = algorithm.make_optimizers() + assert "actor" in optimizers + assert "critic" in optimizers + assert "temperature" in optimizers + assert all(isinstance(v, torch.optim.Adam) for v in optimizers.values()) + assert algorithm.get_optimizers() is optimizers + + +def test_actor_side_no_optimizers(): + """Actor-side usage: no optimizers needed, make_optimizers is not called.""" + sac_cfg = _make_sac_config() + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + +def test_make_algorithm_copies_config_fields(): + sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3) + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert algorithm.config.utd_ratio == 5 + assert algorithm.config.policy_update_freq == 3 + + +def test_make_algorithm_raises_for_unknown_type(): + class FakeConfig: + type = "unknown_algo" + + with pytest.raises(ValueError, match="No RLAlgorithmConfig"): + make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo") + + +# =========================================================================== +# load_weights (round-trip with get_weights) +# =========================================================================== + + +def test_load_weights_round_trip(): + """get_weights -> load_weights should restore identical parameters on a fresh policy.""" + algo_src, _ = _make_algorithm(state_dim=10, action_dim=6) + algo_src.update(_batch_iterator()) + + sac_cfg = _make_sac_config(state_dim=10, action_dim=6) + policy_dst = SACPolicy(config=sac_cfg) + algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config) + + weights = algo_src.get_weights() + algo_dst.load_weights(weights, device="cpu") + + for key in weights["policy"]: + assert torch.equal( + algo_dst.policy.actor.state_dict()[key].cpu(), + weights["policy"][key].cpu(), + ), f"Actor param '{key}' mismatch after load_weights" + + +def test_load_weights_round_trip_with_discrete_critic(): + algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + algo_src.update(_batch_iterator(action_dim=7)) + + sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) + policy_dst = SACPolicy(config=sac_cfg) + algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config) + + weights = algo_src.get_weights() + algo_dst.load_weights(weights, device="cpu") + + for key in weights["discrete_critic"]: + assert torch.equal( + algo_dst.policy.discrete_critic.state_dict()[key].cpu(), + weights["discrete_critic"][key].cpu(), + ), f"Discrete critic param '{key}' mismatch after load_weights" + + +def test_load_weights_ignores_missing_discrete_critic(): + """load_weights should not fail when weights lack discrete_critic on a non-discrete policy.""" + algorithm, _ = _make_algorithm() + weights = {"policy": algorithm.get_weights()["policy"]} + algorithm.load_weights(weights, device="cpu") + + +# =========================================================================== +# TrainingStats generic losses dict +# =========================================================================== + + +def test_training_stats_generic_losses(): + stats = TrainingStats( + losses={"loss_bc": 0.5, "loss_q": 1.2}, + extra={"temperature": 0.1}, + ) + assert stats.losses["loss_bc"] == 0.5 + assert stats.losses["loss_q"] == 1.2 + assert stats.extra["temperature"] == 0.1 + + +# =========================================================================== +# Registry-driven build_algorithm +# =========================================================================== + + +def test_build_algorithm_via_config(): + """SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm.""" + sac_cfg = _make_sac_config(utd_ratio=2) + algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) + policy = SACPolicy(config=sac_cfg) + + algorithm = algo_config.build_algorithm(policy) + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.config.utd_ratio == 2 + + +def test_make_algorithm_uses_build_algorithm(): + """make_algorithm should delegate to config.build_algorithm (no hardcoded if/else).""" + sac_cfg = _make_sac_config() + policy = SACPolicy(config=sac_cfg) + algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") + assert isinstance(algorithm, SACAlgorithm) diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py new file mode 100644 index 000000000..bbe60c00b --- /dev/null +++ b/tests/rl/test_trainer.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import Tensor + +from lerobot.rl.algorithms.base import RLAlgorithm, TrainingStats +from lerobot.rl.trainer import RLTrainer +from lerobot.utils.constants import ACTION, OBS_STATE + + +class _CountingAlgorithm(RLAlgorithm): + def __init__(self): + self.configure_calls = 0 + self.update_calls = 0 + + def select_action(self, observation: dict[str, Tensor]) -> Tensor: + return torch.zeros(1) + + def configure_data_iterator( + self, + data_mixer, + batch_size: int, + *, + async_prefetch: bool = True, + queue_size: int = 2, + ): + self.configure_calls += 1 + return data_mixer.get_iterator( + batch_size=batch_size, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + def make_optimizers(self): + return {} + + def update(self, batch_iterator): + self.update_calls += 1 + _ = next(batch_iterator) + return TrainingStats(losses={"dummy": 1.0}) + + def load_weights(self, weights, device="cpu") -> None: + _ = (weights, device) + + +class _SimpleMixer: + def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2): + _ = (async_prefetch, queue_size) + while True: + yield { + "state": {OBS_STATE: torch.randn(batch_size, 3)}, + ACTION: torch.randn(batch_size, 2), + "reward": torch.randn(batch_size), + "next_state": {OBS_STATE: torch.randn(batch_size, 3)}, + "done": torch.zeros(batch_size), + "truncated": torch.zeros(batch_size), + "complementary_info": None, + } + + +def test_trainer_lazy_iterator_lifecycle_and_reset(): + algo = _CountingAlgorithm() + mixer = _SimpleMixer() + trainer = RLTrainer(algorithm=algo, data_mixer=mixer, batch_size=4, async_prefetch=False) + + # First call builds iterator once. + trainer.training_step() + assert algo.configure_calls == 1 + assert algo.update_calls == 1 + + # Second call reuses existing iterator. + trainer.training_step() + assert algo.configure_calls == 1 + assert algo.update_calls == 2 + + # Explicit reset forces lazy rebuild on next step. + trainer.reset_data_iterator() + trainer.training_step() + assert algo.configure_calls == 2 + assert algo.update_calls == 3 + + +def test_trainer_set_data_mixer_resets_by_default(): + algo = _CountingAlgorithm() + mixer_a = _SimpleMixer() + mixer_b = _SimpleMixer() + trainer = RLTrainer(algorithm=algo, data_mixer=mixer_a, batch_size=2, async_prefetch=False) + + trainer.training_step() + assert algo.configure_calls == 1 + + trainer.set_data_mixer(mixer_b, reset=True) + trainer.training_step() + assert algo.configure_calls == 2 + + +def test_algorithm_optimization_step_contract_defaults(): + algo = _CountingAlgorithm() + assert algo.optimization_step == 0 + algo.optimization_step = 11 + assert algo.optimization_step == 11