From e963e5a0c496f4760f0c13f166de159cc33cebd1 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Tue, 12 May 2026 15:49:54 +0200 Subject: [PATCH] RL stack refactoring (#3075) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring * chore: clarify torch.compile disabled note in SACAlgorithm * fix(teleop): keyboard EE teleop not registering special keys and losing intervention state Fixes #2345 Co-authored-by: jpizarrom * fix: remove leftover normalization calls from reward classifier predict_reward Fixes #2355 * fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample() * refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference * perf: remove redundant CPU→GPU→CPU transition move in learner * Fix: add kwargs in reward classifier __init__() * fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer * fix: add try/finally to control_loop to ensure image writer cleanup on exit * fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error * fix: skip tests that require grpc if not available * fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests * fix(tests): skip tests that require grpc if not available * refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages * fix(config): update vision encoder model name to lerobot/resnet10 * fix(sac): clarify torch.compile status * refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity * refactor(sac): simplify optimizer return structure * perf(rl): use async iterators in OnlineOfflineMixer.get_iterator * refactor(sac): decouple algorithm hyperparameters from policy config * update losses names in tests * fix docstring * remove unused type alias * fix test for flat dict structure * refactor(policies): rename policies/sac → policies/gaussian_actor * refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic * perf(observation_processor): add CUDA support for image processing * fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline (cherry picked from commit 9c2af818ff4bfef2603348e0609aa249c3ff62b1) * fix(rl): add time limit processor to environment pipeline (cherry picked from commit cd105f65cb213c4a9c9768926cc3304ca52eb5f4) * fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100 (cherry picked from commit 494f469a2b9dfb792dde6d9d79d8646ef4fcff54) * fix(rl): update neutral gripper action (cherry picked from commit 9c9064e5befe82e981286c6562194f524e16045e) * fix(rl): merge environment and action-processor info in transition processing (cherry picked from commit 30e1886b6466b8753ec41b3016c09a17dd3e960b) * fix(rl): mirror gym_manipulator in actor (cherry picked from commit d2a046dfc5b6f79df34577aa45f32403d897c0a3) * fix(rl): postprocess action in actor (cherry picked from commit c2556439e550ee3fe5bae6060c57cf227101fcaf) * fix(rl): improve action processing for discrete and continuous actions (cherry picked from commit f887ab3f6ace140c4ea6b6186c26473d785b0727) * fix(rl): enhance intervention handling in actor and learner (cherry picked from commit ef8bfffbd72e9d0951de576553f89c7c281315de) * Revert "perf(observation_processor): add CUDA support for image processing" This reverts commit 
38b88c414cdc1f53ebaab3211e688fe87522b732. * refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable * refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation * refactor(rl): add type property to RLAlgorithmConfig for better clarity * refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility * refactor(tests): remove grpc import checks from test files for cleaner code * fix(tests): gate RL tests on the `datasets` extra * refactor: simplify docstrings for clarity and conciseness across multiple files * fix(rl): update gripper position key and handle action absence during reset * fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset * refactor: clean up import statements * chore: address reviewer comments * chore: improve visual stats reshaping logic and update docstring for clarity * refactor: enforce mandatory config_class and name attributes in RLAlgorithm * refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer * refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests * refactor: add require_package calls for grpcio and gym-hil in relevant modules * refactor(rl): move grpcio guards to runtime entry points * feat(rl): consolidate HIL-SERL checkpoint into HF-style components Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract `state_dict()` / `load_state_dict()` for critic ensemble, target nets and `log_alpha`, and persist them as a sibling `algorithm/` component next to `pretrained_model/`. Replace the pickled `training_state.pt` with an enriched `training_step.json` carrying `step` and `interaction_step`, so resume restores actor + critics + target nets + temperature + optimizers + RNG + counters from HF-standard files. 
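For orientation, the learner-side flow after this refactor — mirroring the updated `examples/tutorial/rl/hilserl_example.py` in this patch — looks roughly like the sketch below. The `SACAlgorithm` / `SACAlgorithmConfig` calls follow the tutorial example; the wrapper function name and the `batches` iterable are illustrative only, and exact signatures may differ in detail:

```python
# Sketch of the algorithm-centric learner loop introduced by this refactor.
# Mirrors examples/tutorial/rl/hilserl_example.py from this patch; treat it as
# illustrative rather than a drop-in implementation.
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig


def run_learner_updates(policy: GaussianActorPolicy, batches, log_every: int = 100):
    """Run SAC updates over an iterable of training batches and return actor weights."""
    # Hyperparameters (learning rates, discount, temperature_init, ...) now live on
    # the algorithm config, not on the policy config.
    algo_config = SACAlgorithmConfig.from_policy_config(policy.config)
    algorithm = SACAlgorithm(policy=policy, config=algo_config)
    algorithm.make_optimizers_and_scheduler()

    for step, batch in enumerate(batches, start=1):

        def batch_iter(b=batch):
            # update() consumes an iterator of batches, as in the tutorial example.
            while True:
                yield b

        stats = algorithm.update(batch_iter())
        if step % log_every == 0:
            print(f"step {step}: {stats.to_log_dict()}")

    # Actor weight sync now goes through the algorithm (wire format moved off the policy).
    return algorithm.get_weights()
```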
* refactor(rl): move actor weight-sync wire format from policy to algorithm * refactor(rl): update type hints for learner and actor functions * refactor(rl): hoist grpcio guard to module top in actor/learner * chore(rl): manage import pattern in actor (#3564) * chore(rl): manage import pattern in actor * chore(rl): optional grpc imports in learner; quote grpc ServicerContext types --------- Co-authored-by: Khalil Meftah * update uv.lock * chore(doc): update doc --------- Co-authored-by: jpizarrom Co-authored-by: Steven Palma --- docs/source/hilserl.mdx | 77 +- examples/tutorial/rl/hilserl_example.py | 45 +- pyproject.toml | 2 +- src/lerobot/common/train_utils.py | 1 + src/lerobot/configs/train.py | 7 - src/lerobot/policies/__init__.py | 10 +- src/lerobot/policies/factory.py | 22 +- .../{sac => gaussian_actor}/__init__.py | 8 +- .../configuration_gaussian_actor.py} | 96 +-- .../modeling_gaussian_actor.py} | 495 ++----------- .../processor_gaussian_actor.py} | 10 +- src/lerobot/processor/hil_processor.py | 42 +- src/lerobot/processor/normalize_processor.py | 21 + .../classifier/configuration_classifier.py | 2 +- .../rewards/classifier/modeling_classifier.py | 1 + src/lerobot/rl/__init__.py | 42 +- src/lerobot/rl/actor.py | 191 +++-- src/lerobot/rl/algorithms/__init__.py | 20 + src/lerobot/rl/algorithms/base.py | 207 ++++++ src/lerobot/rl/algorithms/configs.py | 138 ++++ src/lerobot/rl/algorithms/factory.py | 99 +++ src/lerobot/rl/algorithms/sac/__init__.py | 18 + .../rl/algorithms/sac/configuration_sac.py | 99 +++ .../rl/algorithms/sac/sac_algorithm.py | 672 ++++++++++++++++++ src/lerobot/rl/buffer.py | 6 +- src/lerobot/rl/crop_dataset_roi.py | 4 +- src/lerobot/rl/data_sources/__init__.py | 19 + src/lerobot/rl/data_sources/data_mixer.py | 97 +++ src/lerobot/rl/eval_policy.py | 2 +- src/lerobot/rl/gym_manipulator.py | 181 +++-- src/lerobot/rl/learner.py | 432 ++++------- src/lerobot/rl/learner_service.py | 31 +- src/lerobot/rl/train_rl.py | 50 ++ src/lerobot/rl/trainer.py | 101 +++ .../so_follower/robot_kinematic_processor.py | 11 +- .../teleoperators/keyboard/teleop_keyboard.py | 11 +- .../templates/lerobot_modelcard_template.md | 4 +- src/lerobot/types.py | 1 + src/lerobot/utils/constants.py | 1 + src/lerobot/utils/import_utils.py | 1 + ...onfig.py => test_gaussian_actor_config.py} | 55 +- tests/policies/test_gaussian_actor_policy.py | 528 ++++++++++++++ tests/policies/test_sac_policy.py | 546 -------------- ...or.py => test_gaussian_actor_processor.py} | 48 +- tests/processor/test_normalize_processor.py | 21 +- tests/rewards/test_modeling_classifier.py | 13 - tests/rl/test_actor_learner.py | 171 ++++- tests/rl/test_data_mixer.py | 89 +++ tests/rl/test_queue.py | 2 +- tests/rl/test_sac_algorithm.py | 606 ++++++++++++++++ tests/rl/test_trainer.py | 133 ++++ tests/utils/test_process.py | 2 +- tests/utils/test_replay_buffer.py | 1 - uv.lock | 7 + 54 files changed, 3755 insertions(+), 1744 deletions(-) rename src/lerobot/policies/{sac => gaussian_actor}/__init__.py (67%) rename src/lerobot/policies/{sac/configuration_sac.py => gaussian_actor/configuration_gaussian_actor.py} (73%) rename src/lerobot/policies/{sac/modeling_sac.py => gaussian_actor/modeling_gaussian_actor.py} (54%) rename src/lerobot/policies/{sac/processor_sac.py => gaussian_actor/processor_gaussian_actor.py} (92%) create mode 100644 src/lerobot/rl/algorithms/__init__.py create mode 100644 src/lerobot/rl/algorithms/base.py create mode 100644 src/lerobot/rl/algorithms/configs.py create mode 100644 
src/lerobot/rl/algorithms/factory.py create mode 100644 src/lerobot/rl/algorithms/sac/__init__.py create mode 100644 src/lerobot/rl/algorithms/sac/configuration_sac.py create mode 100644 src/lerobot/rl/algorithms/sac/sac_algorithm.py create mode 100644 src/lerobot/rl/data_sources/__init__.py create mode 100644 src/lerobot/rl/data_sources/data_mixer.py create mode 100644 src/lerobot/rl/train_rl.py create mode 100644 src/lerobot/rl/trainer.py rename tests/policies/{test_sac_config.py => test_gaussian_actor_config.py} (81%) create mode 100644 tests/policies/test_gaussian_actor_policy.py delete mode 100644 tests/policies/test_sac_policy.py rename tests/processor/{test_sac_processor.py => test_gaussian_actor_processor.py} (89%) create mode 100644 tests/rl/test_data_mixer.py create mode 100644 tests/rl/test_sac_algorithm.py create mode 100644 tests/rl/test_trainer.py diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index 5b9439d51..76e985cfe 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -62,7 +62,7 @@ pip install -e ".[hilserl]" ### Understanding Configuration -The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs: +The training process begins with proper configuration for the HILSERl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs: ```python @@ -95,6 +95,7 @@ class HILSerlProcessorConfig: class ObservationConfig: add_joint_velocity_to_observation: bool = False # Add joint velocities to state add_current_to_observation: bool = False # Add motor currents to state + add_ee_pose_to_observation: bool = False # Add end-effector pose to state display_cameras: bool = False # Display camera feeds during execution class ImagePreprocessingConfig: @@ -326,14 +327,22 @@ lerobot-find-joint-limits \ Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0] Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0] ``` -3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field +3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`) **Example Configuration** ```json -"end_effector_bounds": { - "max": [0.24, 0.20, 0.10], - "min": [0.16, -0.08, 0.03] +{ + "env": { + "processor": { + "inverse_kinematics": { + "end_effector_bounds": { + "max": [0.24, 0.2, 0.1], + "min": [0.16, -0.08, 0.03] + } + } + } + } } ``` @@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot. HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements. -For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space. 
+The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are: ```python -class SO100FollowerEndEffectorConfig(SO100FollowerConfig): - """Configuration for the SO100FollowerEndEffector robot.""" +class InverseKinematicsConfig: + """Configuration for inverse kinematics processing.""" - # Default bounds for the end-effector position (in meters) - end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction - default_factory=lambda: { - "min": [-1.0, -1.0, -1.0], # min x, y, z - "max": [1.0, 1.0, 1.0], # max x, y, z - } - ) + urdf_path: str | None = None + target_frame_name: str | None = None + # bounds for the end-effector in x,y,z direction + end_effector_bounds: dict[str, list[float]] | None = None + # maximum step size for the end-effector in x,y,z direction + end_effector_step_sizes: dict[str, float] | None = None - max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at - - end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction - default_factory=lambda: { - "x": 0.02, - "y": 0.02, - "z": 0.02, - } - ) +class HILSerlProcessorConfig: + ... + # maximum gripper position that the gripper will be open at + max_gripper_pos: float | None = 100.0 ``` @@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf **Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device. -The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings. +The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings. **Collecting a Dataset for the reward classifier** -Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards. +Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards. To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig. @@ -658,7 +661,7 @@ Example configuration section for data collection: }, "dataset": { "repo_id": "hf_username/dataset_name", - "dataset_root": "data/your_dataset", + "root": "data/your_dataset", "task": "reward_classifier_task", "num_episodes_to_record": 20, "replay_episode": null, @@ -671,7 +674,7 @@ Example configuration section for data collection: **Reward Classifier Configuration** -The reward classifier is configured using `configuration_classifier.py`. 
Here are the key parameters: +The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters: - **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`) - **model_type**: `"cnn"` or `"transformer"` @@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c "repo_id": "hf_username/dataset_name", "root": null }, - "policy": { + "reward_model": { "type": "reward_classifier", "model_name": "helper2424/resnet10", "model_type": "cnn", @@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c "dropout_rate": 0.1, "learning_rate": 1e-4, "device": "cuda", - "use_amp": true, "input_features": { "observation.images.front": { "type": "VISUAL", @@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T **Configuration Setup** -Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`. +Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`. -1. Configure the policy settings (`type="sac"`, `device`, etc.) -2. Set `dataset` to your cropped dataset -3. Configure environment settings with crop parameters -4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79). -5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task. +1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.) +2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`). +3. Set `dataset` to your cropped dataset +4. Configure environment settings with crop parameters +5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79). +6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task. **Starting the Learner** @@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during Some configuration values have a disproportionate impact on training stability and speed: -- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. +- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. 
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency. - **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second. diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py index 71b50e97c..f82c9b048 100644 --- a/examples/tutorial/rl/hilserl_example.py +++ b/examples/tutorial/rl/hilserl_example.py @@ -4,13 +4,13 @@ from pathlib import Path from queue import Empty, Full import torch -import torch.optim as optim from lerobot.datasets import LeRobotDataset from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig -from lerobot.policies import SACConfig -from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.policies import GaussianActorConfig +from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy from lerobot.rewards.classifier.modeling_classifier import Classifier +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.gym_manipulator import make_robot_env from lerobot.robots.so_follower import SO100FollowerConfig @@ -28,7 +28,7 @@ def run_learner( transitions_queue: mp.Queue, parameters_queue: mp.Queue, shutdown_event: mp.Event, - policy_learner: SACPolicy, + policy_learner: GaussianActorPolicy, online_buffer: ReplayBuffer, offline_buffer: ReplayBuffer, lr: float = 3e-4, @@ -40,8 +40,9 @@ def run_learner( policy_learner.train() policy_learner.to(device) - # Create Adam optimizer from scratch - simple and clean - optimizer = optim.Adam(policy_learner.parameters(), lr=lr) + algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config) + algorithm = SACAlgorithm(policy=policy_learner, config=algo_config) + algorithm.make_optimizers_and_scheduler() print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}") print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}") @@ -83,24 +84,26 @@ def run_learner( else: batch[key] = online_batch[key] - loss, _ = policy_learner.forward(batch) + def batch_iter(b=batch): + while True: + yield b - optimizer.zero_grad() - loss.backward() - optimizer.step() + stats = algorithm.update(batch_iter()) training_step += 1 if training_step % LOG_EVERY == 0: + log_dict = stats.to_log_dict() print( - f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, " + f"[LEARNER] Training step {training_step}, " + f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, " f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}" ) # Send updated parameters to actor every 10 training steps if training_step % SEND_EVERY == 0: try: - state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()} - parameters_queue.put_nowait(state_dict) + weights = algorithm.get_weights() + parameters_queue.put_nowait(weights) print("[LEARNER] Sent updated parameters to actor") except Full: # Missing write due to queue not being consumed (should happen rarely) @@ -113,7 +116,7 @@ 
def run_actor( transitions_queue: mp.Queue, parameters_queue: mp.Queue, shutdown_event: mp.Event, - policy_actor: SACPolicy, + policy_actor: GaussianActorPolicy, reward_classifier: Classifier, env_cfg: HILSerlRobotEnvConfig, device: torch.device = "mps", @@ -144,15 +147,15 @@ def run_actor( while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set(): try: - new_params = parameters_queue.get_nowait() - policy_actor.load_state_dict(new_params) + new_weights = parameters_queue.get_nowait() + policy_actor.load_state_dict(new_weights) print("[ACTOR] Updated policy parameters from learner") except Empty: # No new updated parameters available from learner, waiting pass - # Get action from policy + # Get action from policy (returns full action: continuous + discrete) policy_obs = make_policy_obs(obs, device=device) - action_tensor = policy_actor.select_action(policy_obs) # predicts a single action + action_tensor = policy_actor.select_action(policy_obs) action = action_tensor.squeeze(0).cpu().numpy() # Step environment @@ -261,14 +264,14 @@ def main(): action_features = hw_to_dataset_features(env.robot.action_features, "action") # Create SAC policy for action selection - policy_cfg = SACConfig( + policy_cfg = GaussianActorConfig( device=device, input_features=obs_features, output_features=action_features, ) - policy_actor = SACPolicy(policy_cfg) - policy_learner = SACPolicy(policy_cfg) + policy_actor = GaussianActorPolicy(policy_cfg) + policy_learner = GaussianActorPolicy(policy_cfg) demonstrations_repo_id = "lerobot/example_hil_serl_dataset" offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id) diff --git a/pyproject.toml b/pyproject.toml index 4deb34034..870f7b62b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -195,7 +195,7 @@ groot = [ sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] -hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 3e96e1330..21ee514de 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -99,6 +99,7 @@ def save_checkpoint( optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. preprocessor: The preprocessor/pipeline to save. Defaults to None. + postprocessor: The postprocessor/pipeline to save. Defaults to None. 
""" pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 318821166..388de9437 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -269,10 +269,3 @@ class TrainPipelineConfig(HubMixin): with draccus.config_type("json"): return draccus.parse(cls, config_file, args=cli_args) - - -@dataclass(kw_only=True) -class TrainRLServerPipelineConfig(TrainPipelineConfig): - # NOTE: In RL, we don't need an offline dataset - # TODO: Make `TrainPipelineConfig.dataset` optional - dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 2633d04ad..3a6b8e5d2 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -18,13 +18,13 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .eo1.configuration_eo1 import EO1Config as EO1Config from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors +from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig from .groot.configuration_groot import GrootConfig as GrootConfig from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config from .pretrained import PreTrainedPolicy as PreTrainedPolicy -from .sac.configuration_sac import SACConfig as SACConfig from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .utils import make_robot_action, prepare_observation_for_inference @@ -32,21 +32,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig from .wall_x.configuration_wall_x import WallXConfig as WallXConfig from .xvla.configuration_xvla import XVLAConfig as XVLAConfig -# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here. +# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here. # They have heavy optional dependencies and are loaded lazily via get_policy_class(). 
-# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy`` +# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy`` __all__ = [ # Configuration classes "ACTConfig", "DiffusionConfig", + "EO1Config", + "GaussianActorConfig", "GrootConfig", "MultiTaskDiTConfig", - "EO1Config", "PI0Config", "PI0FastConfig", "PI05Config", - "SACConfig", "SmolVLAConfig", "TDMPCConfig", "VQBeTConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 3609cc7c3..8937bc6ae 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -47,12 +47,12 @@ from lerobot.utils.feature_utils import dataset_to_policy_features from .act.configuration_act import ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig from .eo1.configuration_eo1 import EO1Config +from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig from .groot.configuration_groot import GrootConfig from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config from .pi05.configuration_pi05 import PI05Config from .pretrained import PreTrainedPolicy -from .sac.configuration_sac import SACConfig from .smolvla.configuration_smolvla import SmolVLAConfig from .tdmpc.configuration_tdmpc import TDMPCConfig from .utils import validate_visual_features_consistency @@ -88,7 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x". + "multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x". Returns: The policy class corresponding to the given name. @@ -127,10 +127,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .pi05.modeling_pi05 import PI05Policy return PI05Policy - elif name == "sac": - from .sac.modeling_sac import SACPolicy + elif name == "gaussian_actor": + from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy - return SACPolicy + return GaussianActorPolicy elif name == "smolvla": from .smolvla.modeling_smolvla import SmolVLAPolicy @@ -167,7 +167,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", - "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", + "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x". **kwargs: Keyword arguments to be passed to the configuration class constructor. 
@@ -191,8 +191,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi05": return PI05Config(**kwargs) - elif policy_type == "sac": - return SACConfig(**kwargs) + elif policy_type == "gaussian_actor": + return GaussianActorConfig(**kwargs) elif policy_type == "smolvla": return SmolVLAConfig(**kwargs) elif policy_type == "groot": @@ -365,10 +365,10 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) - elif isinstance(policy_cfg, SACConfig): - from .sac.processor_sac import make_sac_pre_post_processors + elif isinstance(policy_cfg, GaussianActorConfig): + from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors - processors = make_sac_pre_post_processors( + processors = make_gaussian_actor_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) diff --git a/src/lerobot/policies/sac/__init__.py b/src/lerobot/policies/gaussian_actor/__init__.py similarity index 67% rename from src/lerobot/policies/sac/__init__.py rename to src/lerobot/policies/gaussian_actor/__init__.py index cf5f149f3..c3c5855ac 100644 --- a/src/lerobot/policies/sac/__init__.py +++ b/src/lerobot/policies/gaussian_actor/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .configuration_sac import SACConfig -from .modeling_sac import SACPolicy -from .processor_sac import make_sac_pre_post_processors +from .configuration_gaussian_actor import GaussianActorConfig +from .modeling_gaussian_actor import GaussianActorPolicy +from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors -__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"] +__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"] diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py similarity index 73% rename from src/lerobot/policies/sac/configuration_sac.py rename to src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py index db0a77672..e51653992 100644 --- a/src/lerobot/policies/sac/configuration_sac.py +++ b/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py @@ -1,4 +1,4 @@ -# !/usr/bin/env python +#!/usr/bin/env python # Copyright 2025 The HuggingFace Inc. team. # All rights reserved. @@ -75,18 +75,19 @@ class PolicyConfig: init_final: float = 0.05 -@PreTrainedConfig.register_subclass("sac") +@PreTrainedConfig.register_subclass("gaussian_actor") @dataclass -class SACConfig(PreTrainedConfig): - """Soft Actor-Critic (SAC) configuration. +class GaussianActorConfig(PreTrainedConfig): + """Gaussian actor configuration. - SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy - reinforcement learning framework. It learns a policy and a Q-function simultaneously - using experience collected from the environment. + This configures the policy-side (actor + observation encoder) of a Gaussian + policy, as used by SAC and related maximum-entropy continuous-control algorithms. + By default the actor output is a tanh-squashed diagonal Gaussian + (``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via + ``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update + logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``). 
- This configuration class contains all the parameters needed to define a SAC agent, - including network architectures, optimization settings, and algorithm-specific - hyperparameters. + CLI: ``--policy.type=gaussian_actor``. """ # Mapping of feature types to normalization modes @@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig): device: str = "cpu" # Device to store the model on storage_device: str = "cpu" - # Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10) + # Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10) vision_encoder_name: str | None = None # Whether to freeze the vision encoder during training freeze_vision_encoder: bool = True @@ -135,7 +136,13 @@ class SACConfig(PreTrainedConfig): # Dimension of the image embedding pooling image_embedding_pooling_dim: int = 8 - # Training parameter + # Encoder architecture + # Hidden dimension size for the state encoder + state_encoder_hidden_dim: int = 256 + # Dimension of the latent space + latent_dim: int = 256 + + # Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig) # Number of steps for online training online_steps: int = 1000000 # Capacity of the online replay buffer @@ -146,67 +153,38 @@ class SACConfig(PreTrainedConfig): async_prefetch: bool = False # Number of steps before learning starts online_step_before_learning: int = 100 - # Frequency of policy updates - policy_update_freq: int = 1 - # SAC algorithm parameters - # Discount factor for the SAC algorithm - discount: float = 0.99 - # Initial temperature value - temperature_init: float = 1.0 - # Number of critics in the ensemble - num_critics: int = 2 - # Number of subsampled critics for training - num_subsample_critics: int | None = None - # Learning rate for the critic network - critic_lr: float = 3e-4 - # Learning rate for the actor network - actor_lr: float = 3e-4 - # Learning rate for the temperature parameter - temperature_lr: float = 3e-4 - # Weight for the critic target update - critic_target_update_weight: float = 0.005 - # Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1) - utd_ratio: int = 1 - # Hidden dimension size for the state encoder - state_encoder_hidden_dim: int = 256 - # Dimension of the latent space - latent_dim: int = 256 - # Target entropy for the SAC algorithm - target_entropy: float | None = None - # Whether to use backup entropy for the SAC algorithm - use_backup_entropy: bool = True - # Gradient clipping norm for the SAC algorithm - grad_clip_norm: float = 40.0 - - # Network configuration - # Configuration for the critic network architecture - critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) - # Configuration for the actor network architecture - actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) - # Configuration for the policy parameters - policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) - # Configuration for the discrete critic network - discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig). 
# Configuration for actor-learner architecture actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) # Configuration for concurrency settings (you can use threads or processes for the actor and learner) concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) - # Optimizations - use_torch_compile: bool = True + # Network architecture + # Configuration for the actor network architecture + actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) + # Configuration for the policy parameters (Gaussian head) + policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) + # Configuration for the discrete critic network + discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) def __post_init__(self): super().__post_init__() - # Any validation specific to SAC configuration + # Any validation specific to GaussianActor configuration def get_optimizer_preset(self) -> MultiAdamConfig: + # Default learning rate used to satisfy the abstract ``get_optimizer_preset()`` + # contract from ``PreTrainedConfig``. The actual optimizers used during RL + # training are built by ``SACAlgorithm.make_optimizers_and_scheduler()`` from + # ``SACAlgorithmConfig.{actor_lr,critic_lr,temperature_lr}`` and fully bypass + # this preset. + default_lr = 3e-4 return MultiAdamConfig( weight_decay=0.0, optimizer_groups={ - "actor": {"lr": self.actor_lr}, - "critic": {"lr": self.critic_lr}, - "temperature": {"lr": self.temperature_lr}, + "actor": {"lr": default_lr}, + "critic": {"lr": default_lr}, + "temperature": {"lr": default_lr}, }, ) diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py similarity index 54% rename from src/lerobot/policies/sac/modeling_sac.py rename to src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py index cc7030ce2..a833d01cc 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py @@ -15,16 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math from collections.abc import Callable from dataclasses import asdict -from typing import Literal -import einops -import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F # noqa: N812 from torch import Tensor from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution @@ -32,20 +27,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE from ..pretrained import PreTrainedPolicy from ..utils import get_device_from_parameters -from .configuration_sac import SACConfig, is_image_feature +from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension -class SACPolicy( +class GaussianActorPolicy( PreTrainedPolicy, ): - config_class = SACConfig - name = "sac" + config_class = GaussianActorConfig + name = "gaussian_actor" def __init__( self, - config: SACConfig | None = None, + config: GaussianActorConfig | None = None, ): super().__init__(config) config.validate_features() @@ -54,9 +49,8 @@ class SACPolicy( # Determine action dimension and initialize all components continuous_action_dim = config.output_features[ACTION].shape[0] self._init_encoders() - self._init_critics(continuous_action_dim) self._init_actor(continuous_action_dim) - self._init_temperature() + self._init_discrete_critic() def get_optim_params(self) -> dict: optim_params = { @@ -65,11 +59,7 @@ class SACPolicy( for n, p in self.actor.named_parameters() if not n.startswith("encoder") or not self.shared_encoder ], - "critic": self.critic_ensemble.parameters(), - "temperature": self.log_alpha, } - if self.config.num_discrete_actions is not None: - optim_params["discrete_critic"] = self.discrete_critic.parameters() return optim_params def reset(self): @@ -79,7 +69,9 @@ class SACPolicy( @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" - raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!") + raise NotImplementedError( + "GaussianActorPolicy does not support action chunking. It returns single actions!" + ) @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -92,360 +84,43 @@ class SACPolicy( actions, _, _ = self.actor(batch, observations_features) if self.config.num_discrete_actions is not None: - discrete_action_value = self.discrete_critic(batch, observations_features) - discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) + if self.discrete_critic is not None: + discrete_action_value = self.discrete_critic(batch, observations_features) + discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) + else: + discrete_action = torch.ones( + (*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype + ) actions = torch.cat([actions, discrete_action], dim=-1) return actions - def critic_forward( - self, - observations: dict[str, Tensor], - actions: Tensor, - use_target: bool = False, - observation_features: Tensor | None = None, - ) -> Tensor: - """Forward pass through a critic network ensemble + def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]: + """Actor forward pass: sample actions and return log-probabilities. 
Args: - observations: Dictionary of observations - actions: Action tensor - use_target: If True, use target critics, otherwise use ensemble critics + batch: A flat observation dict, or a training dict containing + ``"state"`` (observations) and optionally ``"observation_feature"`` + (pre-computed encoder features). Returns: - Tensor of Q-values from all critics + Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors. """ - - critics = self.critic_target if use_target else self.critic_ensemble - q_values = critics(observations, actions, observation_features) - return q_values - - def discrete_critic_forward( - self, observations, use_target=False, observation_features=None - ) -> torch.Tensor: - """Forward pass through a discrete critic network - - Args: - observations: Dictionary of observations - use_target: If True, use target critics, otherwise use ensemble critics - observation_features: Optional pre-computed observation features to avoid recomputing encoder output - - Returns: - Tensor of Q-values from the discrete critic network - """ - discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic - q_values = discrete_critic(observations, observation_features) - return q_values - - def forward( - self, - batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic", - ) -> dict[str, Tensor]: - """Compute the loss for the given model - - Args: - batch: Dictionary containing: - - action: Action tensor - - reward: Reward tensor - - state: Observations tensor dict - - next_state: Next observations tensor dict - - done: Done mask tensor - - observation_feature: Optional pre-computed observation features - - next_observation_feature: Optional pre-computed next observation features - model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature") - - Returns: - The computed loss tensor - """ - # Extract common components from batch - actions: Tensor = batch[ACTION] - observations: dict[str, Tensor] = batch["state"] - observation_features: Tensor = batch.get("observation_feature") - - if model == "critic": - # Extract critic-specific components - rewards: Tensor = batch["reward"] - next_observations: dict[str, Tensor] = batch["next_state"] - done: Tensor = batch["done"] - next_observation_features: Tensor = batch.get("next_observation_feature") - - loss_critic = self.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) - - return {"loss_critic": loss_critic} - - if model == "discrete_critic" and self.config.num_discrete_actions is not None: - # Extract critic-specific components - rewards: Tensor = batch["reward"] - next_observations: dict[str, Tensor] = batch["next_state"] - done: Tensor = batch["done"] - next_observation_features: Tensor = batch.get("next_observation_feature") - complementary_info = batch.get("complementary_info") - loss_discrete_critic = self.compute_loss_discrete_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - complementary_info=complementary_info, - ) - return {"loss_discrete_critic": loss_discrete_critic} - if model == "actor": - return { - "loss_actor": 
self.compute_loss_actor( - observations=observations, - observation_features=observation_features, - ) - } - - if model == "temperature": - return { - "loss_temperature": self.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - ) - } - - raise ValueError(f"Unknown model type: {model}") - - def update_target_networks(self): - """Update target networks with exponential moving average""" - for target_param, param in zip( - self.critic_target.parameters(), - self.critic_ensemble.parameters(), - strict=True, - ): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) - if self.config.num_discrete_actions is not None: - for target_param, param in zip( - self.discrete_critic_target.parameters(), - self.discrete_critic.parameters(), - strict=True, - ): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) - - @property - def temperature(self) -> float: - """Return the current temperature value, always in sync with log_alpha.""" - return self.log_alpha.exp().item() - - def compute_loss_critic( - self, - observations, - actions, - rewards, - next_observations, - done, - observation_features: Tensor | None = None, - next_observation_features: Tensor | None = None, - ) -> Tensor: - with torch.no_grad(): - next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) - - # 2- compute q targets - q_targets = self.critic_forward( - observations=next_observations, - actions=next_action_preds, - use_target=True, - observation_features=next_observation_features, - ) - - # subsample critics to prevent overfitting if use high UTD (update to date) - # TODO: Get indices before forward pass to avoid unnecessary computation - if self.config.num_subsample_critics is not None: - indices = torch.randperm(self.config.num_critics) - indices = indices[: self.config.num_subsample_critics] - q_targets = q_targets[indices] - - # critics subsample size - min_q, _ = q_targets.min(dim=0) # Get values from min operation - if self.config.use_backup_entropy: - min_q = min_q - (self.temperature * next_log_probs) - - td_target = rewards + (1 - done) * self.config.discount * min_q - - # 3- compute predicted qs - if self.config.num_discrete_actions is not None: - # NOTE: We only want to keep the continuous action part - # In the buffer we have the full action space (continuous + discrete) - # We need to split them before concatenating them in the critic forward - actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] - q_preds = self.critic_forward( - observations=observations, - actions=actions, - use_target=False, - observation_features=observation_features, - ) - - # 4- Calculate loss - # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. 
- td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) - # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up - critics_loss = ( - F.mse_loss( - input=q_preds, - target=td_target_duplicate, - reduction="none", - ).mean(dim=1) - ).sum() - return critics_loss - - def compute_loss_discrete_critic( - self, - observations, - actions, - rewards, - next_observations, - done, - observation_features=None, - next_observation_features=None, - complementary_info=None, - ): - # NOTE: We only want to keep the discrete action part - # In the buffer we have the full action space (continuous + discrete) - # We need to split them before concatenating them in the critic forward - actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() - actions_discrete = torch.round(actions_discrete) - actions_discrete = actions_discrete.long() - - discrete_penalties: Tensor | None = None - if complementary_info is not None: - discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty") - - with torch.no_grad(): - # For DQN, select actions using online network, evaluate with target network - next_discrete_qs = self.discrete_critic_forward( - next_observations, use_target=False, observation_features=next_observation_features - ) - best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) - - # Get target Q-values from target network - target_next_discrete_qs = self.discrete_critic_forward( - observations=next_observations, - use_target=True, - observation_features=next_observation_features, - ) - - # Use gather to select Q-values for best actions - target_next_discrete_q = torch.gather( - target_next_discrete_qs, dim=1, index=best_next_discrete_action - ).squeeze(-1) - - # Compute target Q-value with Bellman equation - rewards_discrete = rewards - if discrete_penalties is not None: - rewards_discrete = rewards + discrete_penalties - target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q - - # Get predicted Q-values for current observations - predicted_discrete_qs = self.discrete_critic_forward( - observations=observations, use_target=False, observation_features=observation_features - ) - - # Use gather to select Q-values for taken actions - predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1) - - # Compute MSE loss between predicted and target Q-values - discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q) - return discrete_critic_loss - - def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: - """Compute the temperature loss""" - # calculate temperature loss - with torch.no_grad(): - _, log_probs, _ = self.actor(observations, observation_features) - temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() - return temperature_loss - - def compute_loss_actor( - self, - observations, - observation_features: Tensor | None = None, - ) -> Tensor: - actions_pi, log_probs, _ = self.actor(observations, observation_features) - - q_preds = self.critic_forward( - observations=observations, - actions=actions_pi, - use_target=False, - observation_features=observation_features, - ) - min_q_preds = q_preds.min(dim=0)[0] - - actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() - return actor_loss + observations = batch.get("state", batch) + observation_features = 
batch.get("observation_feature") if isinstance(batch, dict) else None + actions, log_probs, means = self.actor(observations, observation_features) + return {"action": actions, "log_prob": log_probs, "action_mean": means} def _init_encoders(self): """Initialize shared or separate encoders for actor and critic.""" self.shared_encoder = self.config.shared_encoder - self.encoder_critic = SACObservationEncoder(self.config) + self.encoder_critic = GaussianActorObservationEncoder(self.config) self.encoder_actor = ( - self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config) + self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config) ) - def _init_critics(self, continuous_action_dim): - """Build critic ensemble, targets, and optional discrete critic.""" - heads = [ - CriticHead( - input_dim=self.encoder_critic.output_dim + continuous_action_dim, - **asdict(self.config.critic_network_kwargs), - ) - for _ in range(self.config.num_critics) - ] - self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads) - target_heads = [ - CriticHead( - input_dim=self.encoder_critic.output_dim + continuous_action_dim, - **asdict(self.config.critic_network_kwargs), - ) - for _ in range(self.config.num_critics) - ] - self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads) - self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) - - if self.config.use_torch_compile: - self.critic_ensemble = torch.compile(self.critic_ensemble) - self.critic_target = torch.compile(self.critic_target) - - if self.config.num_discrete_actions is not None: - self._init_discrete_critics() - - def _init_discrete_critics(self): - """Build discrete discrete critic ensemble and target networks.""" - self.discrete_critic = DiscreteCritic( - encoder=self.encoder_critic, - input_dim=self.encoder_critic.output_dim, - output_dim=self.config.num_discrete_actions, - **asdict(self.config.discrete_critic_network_kwargs), - ) - self.discrete_critic_target = DiscreteCritic( - encoder=self.encoder_critic, - input_dim=self.encoder_critic.output_dim, - output_dim=self.config.num_discrete_actions, - **asdict(self.config.discrete_critic_network_kwargs), - ) - - # TODO: (maractingi, azouitine) Compile the discrete critic - self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict()) - def _init_actor(self, continuous_action_dim): - """Initialize policy actor network and default target entropy.""" + """Initialize policy actor network.""" # NOTE: The actor select only the continuous action part self.actor = Policy( encoder=self.encoder_actor, @@ -455,21 +130,25 @@ class SACPolicy( **asdict(self.config.policy_kwargs), ) - self.target_entropy = self.config.target_entropy - if self.target_entropy is None: - dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) - self.target_entropy = -np.prod(dim) / 2 + def _init_discrete_critic(self) -> None: + """Initialize discrete critic network.""" + if self.config.num_discrete_actions is None: + self.discrete_critic = None + return - def _init_temperature(self) -> None: - """Set up temperature parameter (log_alpha).""" - temp_init = self.config.temperature_init - self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + # TODO(Khalil): Compile the discrete critic + self.discrete_critic = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + 
**asdict(self.config.discrete_critic_network_kwargs), + ) -class SACObservationEncoder(nn.Module): +class GaussianActorObservationEncoder(nn.Module): """Encode image and/or state vector observations.""" - def __init__(self, config: SACConfig) -> None: + def __init__(self, config: GaussianActorConfig) -> None: super().__init__() self.config = config self._init_image_layers() @@ -677,84 +356,6 @@ class MLP(nn.Module): return self.net(x) -class CriticHead(nn.Module): - def __init__( - self, - input_dim: int, - hidden_dims: list[int], - activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), - activate_final: bool = False, - dropout_rate: float | None = None, - init_final: float | None = None, - final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, - ): - super().__init__() - self.net = MLP( - input_dim=input_dim, - hidden_dims=hidden_dims, - activations=activations, - activate_final=activate_final, - dropout_rate=dropout_rate, - final_activation=final_activation, - ) - self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) - if init_final is not None: - nn.init.uniform_(self.output_layer.weight, -init_final, init_final) - nn.init.uniform_(self.output_layer.bias, -init_final, init_final) - else: - orthogonal_init()(self.output_layer.weight) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.output_layer(self.net(x)) - - -class CriticEnsemble(nn.Module): - """ - CriticEnsemble wraps multiple CriticHead modules into an ensemble. - - Args: - encoder (SACObservationEncoder): encoder for observations. - ensemble (List[CriticHead]): list of critic heads. - init_final (float | None): optional initializer scale for final layers. - - Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. 
- """ - - def __init__( - self, - encoder: SACObservationEncoder, - ensemble: list[CriticHead], - init_final: float | None = None, - ): - super().__init__() - self.encoder = encoder - self.init_final = init_final - self.critics = nn.ModuleList(ensemble) - - def forward( - self, - observations: dict[str, torch.Tensor], - actions: torch.Tensor, - observation_features: torch.Tensor | None = None, - ) -> torch.Tensor: - device = get_device_from_parameters(self) - # Move each tensor in observations to device - observations = {k: v.to(device) for k, v in observations.items()} - - obs_enc = self.encoder(observations, cache=observation_features) - - inputs = torch.cat([obs_enc, actions], dim=-1) - - # Loop through critics and collect outputs - q_values = [] - for critic in self.critics: - q_values.append(critic(inputs)) - - # Stack outputs to match expected shape [num_critics, batch_size] - q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0) - return q_values - - class DiscreteCritic(nn.Module): def __init__( self, @@ -800,7 +401,7 @@ class DiscreteCritic(nn.Module): class Policy(nn.Module): def __init__( self, - encoder: SACObservationEncoder, + encoder: GaussianActorObservationEncoder, network: nn.Module, action_dim: int, std_min: float = -5, @@ -811,7 +412,7 @@ class Policy(nn.Module): encoder_is_shared: bool = False, ): super().__init__() - self.encoder: SACObservationEncoder = encoder + self.encoder: GaussianActorObservationEncoder = encoder self.network = network self.action_dim = action_dim self.std_min = std_min @@ -885,7 +486,7 @@ class Policy(nn.Module): class DefaultImageEncoder(nn.Module): - def __init__(self, config: SACConfig): + def __init__(self, config: GaussianActorConfig): super().__init__() image_key = next(key for key in config.input_features if is_image_feature(key)) self.image_enc_layers = nn.Sequential( @@ -931,12 +532,12 @@ def freeze_image_encoder(image_encoder: nn.Module): class PretrainedImageEncoder(nn.Module): - def __init__(self, config: SACConfig): + def __init__(self, config: GaussianActorConfig): super().__init__() self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) - def _load_pretrained_vision_encoder(self, config: SACConfig): + def _load_pretrained_vision_encoder(self, config: GaussianActorConfig): """Set up CNN encoder""" from transformers import AutoModel diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/gaussian_actor/processor_gaussian_actor.py similarity index 92% rename from src/lerobot/policies/sac/processor_sac.py rename to src/lerobot/policies/gaussian_actor/processor_gaussian_actor.py index 3409307c2..1e930d178 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/gaussian_actor/processor_gaussian_actor.py @@ -32,18 +32,18 @@ from lerobot.processor import ( ) from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME -from .configuration_sac import SACConfig +from .configuration_gaussian_actor import GaussianActorConfig -def make_sac_pre_post_processors( - config: SACConfig, +def make_gaussian_actor_pre_post_processors( + config: GaussianActorConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, ) -> tuple[ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], ]: """ - Constructs pre-processor and post-processor pipelines for the SAC policy. 
+ Constructs pre-processor and post-processor pipelines for the Gaussian actor policy. The pre-processing pipeline prepares input data for the model by: 1. Renaming features to match pretrained configurations. @@ -56,7 +56,7 @@ def make_sac_pre_post_processors( 2. Unnormalizing the output features to their original scale. Args: - config: The configuration object for the SAC policy. + config: The configuration object for the tanh-Gaussian policy. dataset_stats: A dictionary of statistics for normalization. Returns: diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 49dbb8106..e7351827b 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -4,7 +4,6 @@ # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -# You may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 @@ -321,6 +320,7 @@ class GymHILAdapterProcessorStep(ProcessorStep): This step normalizes the `transition` object by: 1. Copying `teleop_action` from `info` to `complementary_data`. 2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key). + 3. Copying `discrete_penalty` from `info` to `complementary_data`. """ def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -330,6 +330,9 @@ class GymHILAdapterProcessorStep(ProcessorStep): if TELEOP_ACTION_KEY in info: complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY] + if DISCRETE_PENALTY_KEY in info: + complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY] + if "is_intervention" in info: info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"] @@ -348,18 +351,24 @@ class GymHILAdapterProcessorStep(ProcessorStep): @ProcessorStepRegistry.register("gripper_penalty_processor") class GripperPenaltyProcessorStep(ProcessorStep): """ - Applies a penalty for inefficient gripper usage. + Applies a small per-transition cost on the discrete gripper action. - This step penalizes actions that attempt to close an already closed gripper or - open an already open one, based on position thresholds. + Fires only when the commanded action would actually transition the gripper + from one extreme to the other (close-while-open or open-while-closed). + This discourages gripper oscillation while leaving "stay" and saturating-further + commands unpenalized. Attributes: penalty: The negative reward value to apply. max_gripper_pos: The maximum position value for the gripper, used for normalization. + open_threshold: Normalized state below which the gripper is considered "open". + closed_threshold: Normalized state above which the gripper is considered "closed". """ - penalty: float = -0.01 + penalty: float = -0.02 max_gripper_pos: float = 30.0 + open_threshold: float = 0.1 + closed_threshold: float = 0.9 def __call__(self, transition: EnvTransition) -> EnvTransition: """ @@ -379,11 +388,15 @@ class GripperPenaltyProcessorStep(ProcessorStep): if raw_joint_positions is None: return new_transition - current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None) + current_gripper_pos = raw_joint_positions.get(f"{GRIPPER_KEY}.pos", None) if current_gripper_pos is None: return new_transition - # Gripper action is a PolicyAction at this stage + # During reset, the transition may not carry any action yet. 
+ if action is None: + return new_transition + + # Gripper action is expected as the last action dimension. gripper_action = action[-1].item() gripper_action_normalized = gripper_action / self.max_gripper_pos @@ -391,9 +404,13 @@ class GripperPenaltyProcessorStep(ProcessorStep): gripper_state_normalized = current_gripper_pos / self.max_gripper_pos # Calculate penalty boolean as in original - gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or ( - gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5 - ) + # - currently open AND target is closed -> close transition + # - currently closed AND target is open -> open transition + is_open = gripper_state_normalized < self.open_threshold + is_closed = gripper_state_normalized > self.closed_threshold + cmd_close = gripper_action_normalized > self.closed_threshold + cmd_open = gripper_action_normalized < self.open_threshold + gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open) gripper_penalty = self.penalty * int(gripper_penalty_bool) @@ -409,11 +426,14 @@ class GripperPenaltyProcessorStep(ProcessorStep): Returns the configuration of the step for serialization. Returns: - A dictionary containing the penalty value and max gripper position. + A dictionary containing the penalty value, max gripper position, + and the open/closed thresholds. """ return { "penalty": self.penalty, "max_gripper_pos": self.max_gripper_pos, + "open_threshold": self.open_threshold, + "closed_threshold": self.closed_threshold, } def reset(self) -> None: diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 7516c7b47..1649b4b31 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -134,6 +134,24 @@ class _NormalizationMixin: if self.dtype is None: self.dtype = torch.float32 self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + self._reshape_visual_stats() + + def _reshape_visual_stats(self) -> None: + """Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting. + + No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats` + (already ``(C, 1, 1)``). Needed by RL training, which can start without + a dataset and supplies stats manually via JSON config. 
+ """ + for key, feature in self.features.items(): + if feature.type != FeatureType.VISUAL: + continue + if key not in self._tensor_stats: + continue + for stat_name, stat_tensor in self._tensor_stats[key].items(): + if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1: + continue + self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1) def to( self, device: torch.device | str | None = None, dtype: torch.dtype | None = None @@ -152,6 +170,7 @@ class _NormalizationMixin: if dtype is not None: self.dtype = dtype self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + self._reshape_visual_stats() return self def state_dict(self) -> dict[str, Tensor]: @@ -201,6 +220,7 @@ class _NormalizationMixin: # Don't load from state_dict, keep the explicitly provided stats # But ensure _tensor_stats is properly initialized self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment] + self._reshape_visual_stats() return # Normal behavior: load stats from state_dict @@ -211,6 +231,7 @@ class _NormalizationMixin: self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( dtype=torch.float32, device=self.device ) + self._reshape_visual_stats() # Reconstruct the original stats dict from tensor stats for compatibility with to() method # and other functions that rely on self.stats diff --git a/src/lerobot/rewards/classifier/configuration_classifier.py b/src/lerobot/rewards/classifier/configuration_classifier.py index a618a2cf7..f1ccaacc6 100644 --- a/src/lerobot/rewards/classifier/configuration_classifier.py +++ b/src/lerobot/rewards/classifier/configuration_classifier.py @@ -30,7 +30,7 @@ class RewardClassifierConfig(RewardModelConfig): latent_dim: int = 256 image_embedding_pooling_dim: int = 8 dropout_rate: float = 0.1 - model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading. + model_name: str = "lerobot/resnet10" device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" num_cameras: int = 2 diff --git a/src/lerobot/rewards/classifier/modeling_classifier.py b/src/lerobot/rewards/classifier/modeling_classifier.py index bedfffbe9..1d8057135 100644 --- a/src/lerobot/rewards/classifier/modeling_classifier.py +++ b/src/lerobot/rewards/classifier/modeling_classifier.py @@ -105,6 +105,7 @@ class Classifier(PreTrainedRewardModel): def __init__( self, config: RewardClassifierConfig, + **kwargs, ): from transformers import AutoModel diff --git a/src/lerobot/rl/__init__.py b/src/lerobot/rl/__init__.py index 6a7c750d3..8b2d18c54 100644 --- a/src/lerobot/rl/__init__.py +++ b/src/lerobot/rl/__init__.py @@ -12,23 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Reinforcement learning modules. +"""Reinforcement learning modules. -Requires: ``pip install 'lerobot[hilserl]'`` - -Available modules (import directly):: - - from lerobot.rl.actor import ... - from lerobot.rl.learner import ... - from lerobot.rl.learner_service import ... - from lerobot.rl.buffer import ... - from lerobot.rl.eval_policy import ... - from lerobot.rl.gym_manipulator import ... +Distributed actor / learner entry points (``actor``, ``learner``, +``learner_service``) require ``pip install 'lerobot[hilserl]'``. 
Algorithms, +buffer, data sources and trainer are gRPC-free and usable standalone. """ -from lerobot.utils.import_utils import require_package +from .algorithms.base import RLAlgorithm as RLAlgorithm +from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats +from .algorithms.factory import ( + make_algorithm as make_algorithm, + make_algorithm_config as make_algorithm_config, +) +from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig +from .buffer import ReplayBuffer as ReplayBuffer +from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer +from .trainer import RLTrainer as RLTrainer -require_package("grpcio", extra="hilserl", import_name="grpc") - -__all__: list[str] = [] +__all__ = [ + "RLAlgorithm", + "RLAlgorithmConfig", + "TrainingStats", + "make_algorithm", + "make_algorithm_config", + "SACAlgorithmConfig", + "RLTrainer", + "ReplayBuffer", + "DataMixer", + "OnlineOfflineMixer", +] diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index eab527250..bfc7f1882 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -49,39 +49,53 @@ https://github.com/michel-aractingi/lerobot-hilserl-guide import logging import os import time +from collections.abc import Generator from functools import lru_cache from queue import Empty +from typing import TYPE_CHECKING, Any + +from lerobot.utils.import_utils import _grpc_available, require_package + +if TYPE_CHECKING or _grpc_available: + import grpc + + from lerobot.transport import services_pb2, services_pb2_grpc + from lerobot.transport.utils import ( + bytes_to_state_dict, + grpc_channel_options, + python_object_to_bytes, + receive_bytes_in_chunks, + send_bytes_in_chunks, + transitions_to_bytes, + ) +else: + grpc = None + services_pb2 = None + services_pb2_grpc = None + bytes_to_state_dict = None + grpc_channel_options = None + python_object_to_bytes = None + receive_bytes_in_chunks = None + send_bytes_in_chunks = None + transitions_to_bytes = None -import grpc import torch from torch import nn -from torch.multiprocessing import Event, Queue +from torch.multiprocessing import Queue from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser -from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.policies import make_policy -from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.policies import make_policy, make_pre_post_processors +from lerobot.processor import TransitionKey from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents -from lerobot.transport import services_pb2, services_pb2_grpc -from lerobot.transport.utils import ( - bytes_to_state_dict, - grpc_channel_options, - python_object_to_bytes, - receive_bytes_in_chunks, - send_bytes_in_chunks, - transitions_to_bytes, -) -from lerobot.types import TransitionKey from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( Transition, - move_state_dict_to_device, move_transition_to_device, ) from lerobot.utils.utils import ( @@ -89,19 +103,24 @@ from lerobot.utils.utils import ( init_logging, ) +from .algorithms.base import RLAlgorithm +from .algorithms.factory import make_algorithm from 
.gym_manipulator import ( - create_transition, make_processors, make_robot_env, + reset_and_build_transition, step_env_and_process_transition, ) from .queue import get_last_item_from_queue +from .train_rl import TrainRLServerPipelineConfig # Main entry point @parser.wrap() def actor_cli(cfg: TrainRLServerPipelineConfig): + # Fail fast with a friendly error if the optional ``hilserl`` extra is missing. + require_package("grpcio", extra="hilserl", import_name="grpc") cfg.validate() display_pid = False if not use_threads(cfg): @@ -212,7 +231,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig): def act_with_policy( cfg: TrainRLServerPipelineConfig, - shutdown_event: any, # Event, + shutdown_event: Any, # Event parameters_queue: Queue, transitions_queue: Queue, interactions_queue: Queue, @@ -252,22 +271,24 @@ def act_with_policy( logging.info("make_policy") ### Instantiate the policy in both the actor and learner processes - ### To avoid sending a SACPolicy object through the port, we create a policy instance + ### To avoid sending a policy object through the port, we create a policy instance ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters - policy: SACPolicy = make_policy( + policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, ) - policy = policy.eval() + policy = policy.to(device).eval() assert isinstance(policy, nn.Module) - obs, info = online_env.reset() - env_processor.reset() - action_processor.reset() + # Build the algorithm + algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy) - # Process initial observation - transition = create_transition(observation=obs, info=info) - transition = env_processor(transition) + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + dataset_stats=cfg.policy.dataset_stats, + ) + + transition = reset_and_build_transition(online_env, env_processor, action_processor) # NOTE: For the moment we will solely handle the case of a single environment sum_reward_episode = 0 @@ -291,8 +312,17 @@ def act_with_policy( # Time policy inference and check if it meets FPS requirement with policy_timer: - # Extract observation from transition for policy - action = policy.select_action(batch=observation) + normalized_observation = preprocessor.process_observation(observation) + action = policy.select_action(batch=normalized_observation) + # Unnormalize only the continuous part. 
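+ # Descriptive note: when a discrete gripper head is configured, the last action
+ # dimension carries the discrete command, so only the leading continuous dimensions
+ # go through the postprocessor; the discrete part is re-attached below without unnormalization.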
+ if cfg.policy.num_discrete_actions is not None: + continuous_action = postprocessor.process_action(action[..., :-1]) + discrete_action = action[..., -1:].to( + device=continuous_action.device, dtype=continuous_action.dtype + ) + action = torch.cat([continuous_action, discrete_action], dim=-1) + else: + action = postprocessor.process_action(action) policy_fps = policy_timer.fps_last log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) @@ -326,7 +356,8 @@ def act_with_policy( # Check for intervention from transition info intervention_info = new_transition[TransitionKey.INFO] - if intervention_info.get(TeleopEvents.IS_INTERVENTION, False): + is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False)) + if is_intervention: episode_intervention = True episode_intervention_steps += 1 @@ -334,6 +365,7 @@ def act_with_policy( "discrete_penalty": torch.tensor( [new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)] ), + TeleopEvents.IS_INTERVENTION.value: is_intervention, } # Create transition for learner (convert to old format) list_transition_to_send_to_learner.append( @@ -354,7 +386,7 @@ def act_with_policy( if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") - update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) + update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device) if len(list_transition_to_send_to_learner) > 0: push_transitions_to_transport_queue( @@ -390,14 +422,7 @@ def act_with_policy( episode_intervention_steps = 0 episode_total_steps = 0 - # Reset environment and processors - obs, info = online_env.reset() - env_processor.reset() - action_processor.reset() - - # Process initial observation - transition = create_transition(observation=obs, info=info) - transition = env_processor(transition) + transition = reset_and_build_transition(online_env, env_processor, action_processor) if cfg.env.fps is not None: dt_time = time.perf_counter() - start_time @@ -408,10 +433,10 @@ def act_with_policy( def establish_learner_connection( - stub: services_pb2_grpc.LearnerServiceStub, - shutdown_event: Event, # type: ignore + stub: "services_pb2_grpc.LearnerServiceStub", + shutdown_event: Any, # Event attempts: int = 30, -): +) -> bool: """Establish a connection with the learner. Args: @@ -441,12 +466,14 @@ def establish_learner_connection( def learner_service_client( host: str = "127.0.0.1", port: int = 50051, -) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: - """ - Returns a client for the learner service. +) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]": + """Return a client for the learner service. GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. So we need to create only one client and reuse it. + + Returns: + tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel. 
""" channel = grpc.insecure_channel( @@ -461,16 +488,18 @@ def learner_service_client( def receive_policy( cfg: TrainRLServerPipelineConfig, parameters_queue: Queue, - shutdown_event: Event, # type: ignore - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, -): + shutdown_event: Any, # Event + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, +) -> None: """Receive parameters from the learner. Args: cfg (TrainRLServerPipelineConfig): The configuration for the actor. parameters_queue (Queue): The queue to receive the parameters. shutdown_event (Event): The event to check if the process should shutdown. + learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub. + grpc_channel (grpc.Channel | None): Optional pre-created channel. """ logging.info("[ACTOR] Start receiving parameters from the Learner") if not use_threads(cfg): @@ -513,12 +542,11 @@ def receive_policy( def send_transitions( cfg: TrainRLServerPipelineConfig, transitions_queue: Queue, - shutdown_event: any, # Event, - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, -) -> services_pb2.Empty: - """ - Sends transitions to the learner. + shutdown_event: Any, # Event + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, +) -> None: + """Send transitions to the learner. This function continuously retrieves messages from the queue and processes: @@ -526,6 +554,13 @@ def send_transitions( - A batch of transitions (observation, action, reward, next observation) is collected. - Transitions are moved to the CPU and serialized using PyTorch. - The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner. + + Args: + cfg (TrainRLServerPipelineConfig): The configuration for the actor. + transitions_queue (Queue): The queue to receive the transitions. + shutdown_event (Event): The event to check if the process should shutdown. + learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub. + grpc_channel (grpc.Channel | None): Optional pre-created channel. """ if not use_threads(cfg): @@ -563,18 +598,24 @@ def send_transitions( def send_interactions( cfg: TrainRLServerPipelineConfig, interactions_queue: Queue, - shutdown_event: Event, # type: ignore - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, -) -> services_pb2.Empty: - """ - Sends interactions to the learner. + shutdown_event: Any, # Event + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, +) -> None: + """Send interactions to the learner. This function continuously retrieves messages from the queue and processes: - Interaction Messages: - Contains useful statistics about episodic rewards and policy timings. - The message is serialized using `pickle` and sent to the learner. + + Args: + cfg (TrainRLServerPipelineConfig): The configuration for the actor. + interactions_queue (Queue): The queue to receive the interactions. + shutdown_event (Event): The event to check if the process should shutdown. + learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub. + grpc_channel (grpc.Channel | None): Optional pre-created channel. 
""" if not use_threads(cfg): @@ -613,7 +654,11 @@ def send_interactions( logging.info("[ACTOR] Interactions process stopped") -def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore +def transitions_stream( + shutdown_event: Any, # Event + transitions_queue: Queue, + timeout: float, +) -> "Generator[Any, None, services_pb2.Empty]": while not shutdown_event.is_set(): try: message = transitions_queue.get(block=True, timeout=timeout) @@ -629,10 +674,10 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: def interactions_stream( - shutdown_event: Event, + shutdown_event: Any, # Event interactions_queue: Queue, - timeout: float, # type: ignore -) -> services_pb2.Empty: + timeout: float, +) -> "Generator[Any, None, services_pb2.Empty]": while not shutdown_event.is_set(): try: message = interactions_queue.get(block=True, timeout=timeout) @@ -652,7 +697,8 @@ def interactions_stream( # Policy functions -def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): +def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device): + """Drain the latest learner-pushed weights into ``algorithm.policy``.""" bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) if bytes_state_dict is not None: logging.info("[ACTOR] Load new parameters from Learner.") @@ -667,18 +713,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) # - Send critic's encoder state when shared_encoder=True # - Skip encoder params entirely when freeze_vision_encoder=True # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) - - # Load actor state dict - actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device) - policy.actor.load_state_dict(actor_state_dict) - - # Load discrete critic if present - if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts: - discrete_critic_state_dict = move_state_dict_to_device( - state_dicts["discrete_critic"], device=device - ) - policy.discrete_critic.load_state_dict(discrete_critic_state_dict) - logging.info("[ACTOR] Loaded discrete critic parameters from Learner.") + algorithm.load_weights(state_dicts, device=device) # Utilities functions diff --git a/src/lerobot/rl/algorithms/__init__.py b/src/lerobot/rl/algorithms/__init__.py new file mode 100644 index 000000000..c09bd26fc --- /dev/null +++ b/src/lerobot/rl/algorithms/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .sac import SACAlgorithm, SACAlgorithmConfig + +__all__ = [ + "SACAlgorithm", + "SACAlgorithmConfig", +] diff --git a/src/lerobot/rl/algorithms/base.py b/src/lerobot/rl/algorithms/base.py new file mode 100644 index 000000000..01c34584b --- /dev/null +++ b/src/lerobot/rl/algorithms/base.py @@ -0,0 +1,207 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +import builtins +import os +from collections.abc import Iterator +from pathlib import Path +from typing import TYPE_CHECKING, Any, TypeVar + +import torch +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from huggingface_hub.errors import HfHubHTTPError +from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors +from torch.optim import Optimizer + +from lerobot.types import BatchType +from lerobot.utils.hub import HubMixin + +from .configs import RLAlgorithmConfig, TrainingStats + +if TYPE_CHECKING: + from torch import nn + + from ..data_sources.data_mixer import DataMixer + +T = TypeVar("T", bound="RLAlgorithm") + + +class RLAlgorithm(HubMixin, abc.ABC): + """Base for all RL algorithms.""" + + config_class: type[RLAlgorithmConfig] + name: str + config: RLAlgorithmConfig + + @abc.abstractmethod + def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + """One complete training step. + + The algorithm calls ``next(batch_iterator)`` as many times as it + needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches. + The iterator is owned by the trainer; the algorithm just consumes + from it. + """ + raise NotImplementedError + + def configure_data_iterator( + self, + data_mixer: DataMixer, + batch_size: int, + *, + async_prefetch: bool = True, + queue_size: int = 2, + ) -> Iterator[BatchType]: + """Create the data iterator this algorithm needs. + + The default implementation uses the standard ``data_mixer.get_iterator()``. + Algorithms that need specialised sampling should override this method. + """ + return data_mixer.get_iterator( + batch_size=batch_size, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + @abc.abstractmethod + def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]: + """Build and return the optimizers used during training. + + Called once on the learner side after construction. + """ + raise NotImplementedError + + def get_optimizers(self) -> dict[str, Optimizer]: + """Return optimizers for checkpointing / external scheduling.""" + return {} + + @property + def optimization_step(self) -> int: + """Current learner optimization step. + + Part of the stable contract for checkpoint/resume. Algorithms can + either use this default storage or override for custom behavior. 
+ """ + return getattr(self, "_optimization_step", 0) + + @optimization_step.setter + def optimization_step(self, value: int) -> None: + self._optimization_step = int(value) + + def get_weights(self) -> dict[str, Any]: + """Policy state-dict to push to actors.""" + return {} + + @abc.abstractmethod + def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: + """Load policy state-dict received from the learner.""" + raise NotImplementedError + + @abc.abstractmethod + def state_dict(self) -> dict[str, torch.Tensor]: + """Algorithm-owned trainable tensors. + + Must return a flat tensor mapping for everything the algorithm owns + that is not part of the policy (e.g. critic ensembles, target networks, + temperature parameters). Algorithms with no training-only tensors + should explicitly return an empty dict. + """ + raise NotImplementedError + + @abc.abstractmethod + def load_state_dict( + self, + state_dict: dict[str, torch.Tensor], + device: str | torch.device = "cpu", + ) -> None: + """In-place load of algorithm-owned tensors. + + Implementations MUST keep the identity of any ``nn.Parameter`` that an + optimizer references (e.g. SAC's ``log_alpha``) by using ``.copy_()`` + rather than rebinding the attribute. + """ + raise NotImplementedError + + def _save_pretrained(self, save_directory: Path) -> None: + """Persist the algorithm's tensors and config to ``save_directory``. + + Writes ``model.safetensors`` (algorithm tensors via :meth:`state_dict`) + and ``config.json`` (via :meth:`RLAlgorithmConfig.save_pretrained`). + """ + tensors = {k: v.detach().cpu().contiguous() for k, v in self.state_dict().items()} + save_safetensors(tensors, str(save_directory / SAFETENSORS_SINGLE_FILE)) + self.config._save_pretrained(save_directory) + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + policy: nn.Module, + config: RLAlgorithmConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + device: str | torch.device = "cpu", + **algo_kwargs: Any, + ) -> T: + """Build an algorithm and load its weights from ``pretrained_name_or_path``.""" + if config is None: + config = cls.config_class.from_pretrained( + pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + if hasattr(config, "policy_config"): + config.policy_config = policy.config + + instance = cls(policy=policy, config=config, **algo_kwargs) + + model_id = str(pretrained_name_or_path) + if os.path.isdir(model_id): + model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) + else: + try: + model_file = hf_hub_download( + repo_id=model_id, + filename=SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}" + ) from e + + tensors = load_safetensors(model_file) + instance.load_state_dict(tensors, device=device) + return instance diff --git a/src/lerobot/rl/algorithms/configs.py b/src/lerobot/rl/algorithms/configs.py 
new file mode 100644 index 000000000..9448afeb3 --- /dev/null +++ b/src/lerobot/rl/algorithms/configs.py @@ -0,0 +1,138 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +import builtins +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, TypeVar + +import draccus +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import CONFIG_NAME +from huggingface_hub.errors import HfHubHTTPError + +from lerobot.utils.hub import HubMixin + +T = TypeVar("T", bound="RLAlgorithmConfig") + +logger = logging.getLogger(__name__) + + +@dataclass +class TrainingStats: + """Returned by ``algorithm.update()`` for logging and checkpointing.""" + + losses: dict[str, float] = field(default_factory=dict) + grad_norms: dict[str, float] = field(default_factory=dict) + extra: dict[str, float] = field(default_factory=dict) + + def to_log_dict(self) -> dict[str, float]: + """Flatten all stats into a single dict for logging.""" + + d: dict[str, float] = {} + for name, val in self.losses.items(): + d[name] = val + for name, val in self.grad_norms.items(): + d[f"{name}_grad_norm"] = val + for name, val in self.extra.items(): + d[name] = val + return d + + +@dataclass +class RLAlgorithmConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): + """Registry for algorithm configs.""" + + @property + def type(self) -> str: + """Registered name of this algorithm config (e.g. ``"sac"``).""" + choice_name = self.get_choice_name(self.__class__) + if not isinstance(choice_name, str): + raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}") + return choice_name + + @classmethod + @abc.abstractmethod + def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig: + """Build an algorithm config from a policy config. + + Must be overridden by every registered config subclass. 
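+
+ Example (illustrative, using the SAC subclass added in this PR; ``policy_cfg`` is a
+ placeholder for an existing policy config)::
+
+ algo_cfg = SACAlgorithmConfig.from_policy_config(policy_cfg)
+ assert algo_cfg.policy_config is policy_cfg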
+ """ + raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()") + + def _save_pretrained(self, save_directory: Path) -> None: + """Serialize this config as ``config.json`` inside ``save_directory``.""" + with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"): + draccus.dump(self, f, indent=4) + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict[Any, Any] | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + **algo_kwargs: Any, + ) -> T: + model_id = str(pretrained_name_or_path) + config_file: str | None = None + if Path(model_id).is_dir(): + if CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, CONFIG_NAME) + else: + logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}") + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" + ) from e + + if config_file is None: + raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}") + + with draccus.config_type("json"): + instance = draccus.parse(RLAlgorithmConfig, config_file, args=[]) + + if cls is not RLAlgorithmConfig and not isinstance(instance, cls): + raise TypeError( + f"Config at {model_id} has type '{instance.type}' but was loaded via " + f"{cls.__name__}; use the matching subclass or RLAlgorithmConfig.from_pretrained()." + ) + + for key, value in algo_kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + return instance diff --git a/src/lerobot/rl/algorithms/factory.py b/src/lerobot/rl/algorithms/factory.py new file mode 100644 index 000000000..2a5d9dea7 --- /dev/null +++ b/src/lerobot/rl/algorithms/factory.py @@ -0,0 +1,99 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from .base import RLAlgorithm +from .configs import RLAlgorithmConfig + + +def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig: + """Instantiate an `RLAlgorithmConfig` from its registered type name. + + Args: + algorithm_type: Registry key of the algorithm (e.g. ``"sac"``). + **kwargs: Keyword arguments forwarded to the config class constructor. + + Returns: + An instance of the matching ``RLAlgorithmConfig`` subclass. + + Raises: + ValueError: If ``algorithm_type`` is not registered. 
+ """ + try: + cls = RLAlgorithmConfig.get_choice_class(algorithm_type) + except KeyError as err: + raise ValueError( + f"Algorithm type '{algorithm_type}' is not registered. " + f"Available: {list(RLAlgorithmConfig.get_known_choices().keys())}" + ) from err + return cls(**kwargs) + + +def get_algorithm_class(name: str) -> type[RLAlgorithm]: + """ + Retrieves an RL algorithm class by its registered name. + + This function uses dynamic imports to avoid loading all algorithm classes into + memory at once, improving startup time and reducing dependencies. + + Args: + name: The name of the algorithm. Supported names are "sac". + + Returns: + The algorithm class corresponding to the given name. + + Raises: + ValueError: If the algorithm name is not recognized. + """ + if name == "sac": + from .sac.sac_algorithm import SACAlgorithm + + return SACAlgorithm + raise ValueError( + f"Algorithm type '{name}' is not available. " + f"Known: {list(RLAlgorithmConfig.get_known_choices().keys())}" + ) + + +def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm: + """ + Instantiate an RL algorithm. + + This factory function looks up the :class:`RLAlgorithm` subclass that matches + ``cfg.type`` and instantiates it with the provided policy. It also enforces + that ``cfg.policy_config`` has been populated before construction (this is + normally handled by :meth:`TrainRLServerPipelineConfig.validate`). + + Args: + cfg: The algorithm configuration. Must have ``policy_config`` set. + policy: The policy module the algorithm will train. + + Returns: + An instantiated :class:`RLAlgorithm`. + + Raises: + ValueError: If ``cfg.policy_config`` is ``None`` or ``cfg.type`` is not + registered. + """ + if getattr(cfg, "policy_config", None) is None: + raise ValueError( + f"{type(cfg).__name__}.policy_config is None. " + "It must be populated (typically by TrainRLServerPipelineConfig.validate) " + "before calling make_algorithm()." + ) + cls = get_algorithm_class(cfg.type) + return cls(policy=policy, config=cfg) diff --git a/src/lerobot/rl/algorithms/sac/__init__.py b/src/lerobot/rl/algorithms/sac/__init__.py new file mode 100644 index 000000000..9d076bcbb --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_sac import SACAlgorithmConfig +from .sac_algorithm import SACAlgorithm + +__all__ = ["SACAlgorithm", "SACAlgorithmConfig"] diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py new file mode 100644 index 000000000..c4e9b334a --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py @@ -0,0 +1,99 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.gaussian_actor.configuration_gaussian_actor import ( + CriticNetworkConfig, + GaussianActorConfig, +) + +from ..configs import RLAlgorithmConfig + + +@RLAlgorithmConfig.register_subclass("sac") +@dataclass +class SACAlgorithmConfig(RLAlgorithmConfig): + """Soft Actor-Critic (SAC) algorithm configuration. + + SAC is an off-policy actor-critic deep RL algorithm based on the maximum + entropy reinforcement learning framework. It learns a policy and a Q-function + simultaneously using experience collected from the environment. + + This configuration class contains the algorithm-side hyperparameters: critic + ensemble, target networks, temperature / entropy tuning, and the Bellman + update loop. The policy-side (actor + observation encoder) lives in + :class:`~lerobot.policies.gaussian_actor.GaussianActorConfig` and is + referenced via :attr:`policy_config`. + """ + + # Optimizer learning rates + # Learning rate for the actor network + actor_lr: float = 3e-4 + # Learning rate for the critic network + critic_lr: float = 3e-4 + # Learning rate for the temperature parameter + temperature_lr: float = 3e-4 + + # Bellman update + # Discount factor for the SAC algorithm + discount: float = 0.99 + # Whether to use backup entropy for the SAC algorithm + use_backup_entropy: bool = True + # Weight for the critic target update + critic_target_update_weight: float = 0.005 + + # Critic ensemble + # Number of critics in the ensemble + num_critics: int = 2 + # Number of subsampled critics for training + num_subsample_critics: int | None = None + # Configuration for the critic network architecture + critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for the discrete critic network + discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + + # Temperature / entropy + # Initial temperature value + temperature_init: float = 1.0 + # Target entropy for automatic temperature tuning. If ``None``, defaults to + # ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if + # there is a discrete action head). + target_entropy: float | None = None + + # Update loop + # Update-to-data ratio. Set to >1 to enable extra critic updates per env step. 
+ utd_ratio: int = 1 + # Frequency of policy updates + policy_update_freq: int = 1 + # Gradient clipping norm for the SAC algorithm + grad_clip_norm: float = 40.0 + + # Optimizations + # torch.compile is currently disabled by default + use_torch_compile: bool = False + + # Policy config + policy_config: PreTrainedConfig | None = None + + @classmethod + def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig: + """Build an algorithm config with default hyperparameters for a given policy.""" + return cls( + policy_config=policy_cfg, + discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs, + ) diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py new file mode 100644 index 000000000..81c44068f --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -0,0 +1,672 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from collections.abc import Callable, Iterator +from dataclasses import asdict +from typing import Any + +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor +from torch.optim import Optimizer + +from lerobot.policies.gaussian_actor.modeling_gaussian_actor import ( + DISCRETE_DIMENSION_INDEX, + MLP, + DiscreteCritic, + GaussianActorObservationEncoder, + GaussianActorPolicy, + orthogonal_init, +) +from lerobot.policies.utils import get_device_from_parameters +from lerobot.types import BatchType +from lerobot.utils.constants import ACTION +from lerobot.utils.transition import move_state_dict_to_device + +from ..base import RLAlgorithm +from ..configs import TrainingStats +from .configuration_sac import SACAlgorithmConfig + + +class SACAlgorithm(RLAlgorithm): + """Soft Actor-Critic. 
Owns critics, targets, temperature, and loss computation.""" + + config_class = SACAlgorithmConfig + name = "sac" + + def __init__( + self, + policy: GaussianActorPolicy, + config: SACAlgorithmConfig, + ): + self.config = config + self.policy_config = config.policy_config + self.policy = policy + self.optimizers: dict[str, Optimizer] = {} + self._optimization_step: int = 0 + + action_dim = self.policy.config.output_features[ACTION].shape[0] + self._init_critics(action_dim) + self._init_temperature(action_dim) + + self._device = torch.device(self.policy.config.device) + self._move_to_device() + + def _init_critics(self, action_dim) -> None: + """Build critic ensemble, targets.""" + encoder = self.policy.encoder_critic + + heads = [ + CriticHead( + input_dim=encoder.output_dim + action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads) + target_heads = [ + CriticHead( + input_dim=encoder.output_dim + action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads) + self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + + # TODO(Khalil): Investigate and fix torch.compile + # NOTE: torch.compile is disabled, policy does not converge when enabled. + if self.config.use_torch_compile: + self.critic_ensemble = torch.compile(self.critic_ensemble) + self.critic_target = torch.compile(self.critic_target) + + self.discrete_critic_target = None + if self.policy_config.num_discrete_actions is not None: + self.discrete_critic_target = self._init_discrete_critic_target(encoder) + + def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic: + """Build target discrete critic (main network is owned by the policy).""" + discrete_critic_target = DiscreteCritic( + encoder=encoder, + input_dim=encoder.output_dim, + output_dim=self.policy_config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + # TODO(Khalil): Compile the discrete critic + discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict()) + return discrete_critic_target + + def _init_temperature(self, continuous_action_dim: int) -> None: + """Set up temperature parameter (log_alpha) and target entropy.""" + temp_init = self.config.temperature_init + self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + + self.target_entropy = self.config.target_entropy + if self.target_entropy is None: + total_action_dim = continuous_action_dim + ( + 1 if self.policy_config.num_discrete_actions is not None else 0 + ) + self.target_entropy = -total_action_dim / 2 + + def _move_to_device(self) -> None: + self.policy.to(self._device) + self.critic_ensemble.to(self._device) + self.critic_target.to(self._device) + self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device)) + if self.discrete_critic_target is not None: + self.discrete_critic_target.to(self._device) + + @property + def temperature(self) -> float: + """Return the current temperature value, always in sync with log_alpha.""" + return self.log_alpha.exp().item() + + def _critic_forward( + self, + observations: dict[str, Tensor], + actions: Tensor, + use_target: bool = False, + observation_features: Tensor | None = None, + ) -> Tensor: + """Forward pass through a critic network ensemble + + Args: + observations: Dictionary of observations 
+ actions: Action tensor + use_target: If True, use target critics, otherwise use ensemble critics + + Returns: + Tensor of Q-values from all critics + """ + + critics = self.critic_target if use_target else self.critic_ensemble + q_values = critics(observations, actions, observation_features) + return q_values + + def _discrete_critic_forward( + self, observations, use_target=False, observation_features=None + ) -> torch.Tensor: + """Forward pass through a discrete critic network + + Args: + observations: Dictionary of observations + use_target: If True, use target critics, otherwise use ensemble critics + observation_features: Optional pre-computed observation features to avoid recomputing encoder output + + Returns: + Tensor of Q-values from the discrete critic network + """ + discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic + q_values = discrete_critic(observations, observation_features) + return q_values + + def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + """Run one SAC training step (critic / discrete-critic / actor / temperature). + + Pulls ``utd_ratio`` batches from ``batch_iterator``, computes the relevant + losses, backpropagates each, and updates target networks. + + Args: + batch_iterator: yields batches each containing + - ``action``: Action tensor + - ``reward``: Reward tensor + - ``state``: Observations tensor dict + - ``next_state``: Next observations tensor dict + - ``done``: Done mask tensor + - ``observation_feature``: Optional pre-computed observation features + - ``next_observation_feature``: Optional pre-computed next observation features + - ``complementary_info`` (optional): per-step extras like discrete penalties + + Returns: + TrainingStats with per-component losses and grad norms. 
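+
+ Sketch of the learner-side call pattern (batch size and variable names are
+ illustrative; the trainer owns the iterator)::
+
+ iterator = algorithm.configure_data_iterator(data_mixer, batch_size=256)
+ stats = algorithm.update(iterator)
+ log_dict = stats.to_log_dict()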
+ """ + clip = self.config.grad_clip_norm + + for _ in range(self.config.utd_ratio - 1): + batch = next(batch_iterator) + fb = self._prepare_forward_batch(batch, include_complementary_info=True) + + loss_critic = self._compute_loss_critic(fb) + self.optimizers["critic"].zero_grad() + loss_critic.backward() + torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip) + self.optimizers["critic"].step() + + if self.policy_config.num_discrete_actions is not None: + loss_dc = self._compute_loss_discrete_critic(fb) + self.optimizers["discrete_critic"].zero_grad() + loss_dc.backward() + torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip) + self.optimizers["discrete_critic"].step() + + self._update_target_networks() + + batch = next(batch_iterator) + fb = self._prepare_forward_batch(batch, include_complementary_info=False) + + loss_critic = self._compute_loss_critic(fb) + self.optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad = torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip).item() + self.optimizers["critic"].step() + + stats = TrainingStats( + losses={"loss_critic": loss_critic.item()}, + grad_norms={"critic": critic_grad}, + ) + + if self.policy_config.num_discrete_actions is not None: + loss_dc = self._compute_loss_discrete_critic(fb) + self.optimizers["discrete_critic"].zero_grad() + loss_dc.backward() + dc_grad = torch.nn.utils.clip_grad_norm_( + self.policy.discrete_critic.parameters(), max_norm=clip + ).item() + self.optimizers["discrete_critic"].step() + stats.losses["loss_discrete_critic"] = loss_dc.item() + stats.grad_norms["discrete_critic"] = dc_grad + + if self._optimization_step % self.config.policy_update_freq == 0: + for _ in range(self.config.policy_update_freq): + loss_actor = self._compute_loss_actor(fb) + self.optimizers["actor"].zero_grad() + loss_actor.backward() + actor_grad = torch.nn.utils.clip_grad_norm_( + self.policy.actor.parameters(), max_norm=clip + ).item() + self.optimizers["actor"].step() + + loss_temp = self._compute_loss_temperature(fb) + self.optimizers["temperature"].zero_grad() + loss_temp.backward() + temp_grad = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=clip).item() + self.optimizers["temperature"].step() + + stats.losses["loss_actor"] = loss_actor.item() + stats.losses["loss_temperature"] = loss_temp.item() + stats.grad_norms["actor"] = actor_grad + stats.grad_norms["temperature"] = temp_grad + stats.extra["temperature"] = self.temperature + + self._update_target_networks() + self._optimization_step += 1 + return stats + + def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor: + # Extract common components from batch + observations = batch["state"] + actions = batch[ACTION] + observation_features = batch.get("observation_feature") + # Extract critic-specific components + rewards = batch["reward"] + next_observations = batch["next_state"] + done = batch["done"] + next_observation_features = batch.get("next_observation_feature") + + with torch.no_grad(): + next_action_preds, next_log_probs, _ = self.policy.actor( + next_observations, next_observation_features + ) + + # 2- compute q targets + q_targets = self._critic_forward( + observations=next_observations, + actions=next_action_preds, + use_target=True, + observation_features=next_observation_features, + ) + + # subsample critics to prevent overfitting if use high UTD (update to date) + # TODO: Get indices before forward pass to avoid unnecessary computation + if 
self.config.num_subsample_critics is not None: + indices = torch.randperm(self.config.num_critics) + indices = indices[: self.config.num_subsample_critics] + q_targets = q_targets[indices] + + # critics subsample size + min_q, _ = q_targets.min(dim=0) # Get values from min operation + if self.config.use_backup_entropy: + min_q = min_q - (self.temperature * next_log_probs) + + td_target = rewards + (1 - done) * self.config.discount * min_q + + # 3- compute predicted qs + if self.policy_config.num_discrete_actions is not None: + # NOTE: We only want to keep the continuous action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] + q_preds = self._critic_forward( + observations=observations, + actions=actions, + use_target=False, + observation_features=observation_features, + ) + + # 4- Calculate loss + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) + # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up + critics_loss = ( + F.mse_loss( + input=q_preds, + target=td_target_duplicate, + reduction="none", + ).mean(dim=1) + ).sum() + return critics_loss + + def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + actions = batch[ACTION] + rewards = batch["reward"] + next_observations = batch["next_state"] + done = batch["done"] + observation_features = batch.get("observation_feature") + next_observation_features = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") + + # NOTE: We only want to keep the discrete action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) + actions_discrete = actions_discrete.long() + + discrete_penalties: Tensor | None = None + if complementary_info is not None: + discrete_penalties = complementary_info.get("discrete_penalty") + + with torch.no_grad(): + # For DQN, select actions using online network, evaluate with target network + next_discrete_qs = self._discrete_critic_forward( + next_observations, use_target=False, observation_features=next_observation_features + ) + best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) + + # Get target Q-values from target network + target_next_discrete_qs = self._discrete_critic_forward( + observations=next_observations, + use_target=True, + observation_features=next_observation_features, + ) + + # Use gather to select Q-values for best actions + target_next_discrete_q = torch.gather( + target_next_discrete_qs, dim=1, index=best_next_discrete_action + ).squeeze(-1) + + # Compute target Q-value with Bellman equation + rewards_discrete = rewards + if discrete_penalties is not None: + rewards_discrete = rewards + discrete_penalties + target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q + + # Get predicted Q-values for current observations + predicted_discrete_qs = self._discrete_critic_forward( + observations=observations, use_target=False, observation_features=observation_features + ) + + # Use gather to select Q-values 
for taken actions + predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1) + + # Compute MSE loss between predicted and target Q-values + discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q) + return discrete_critic_loss + + def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + observation_features = batch.get("observation_feature") + + actions_pi, log_probs, _ = self.policy.actor(observations, observation_features) + + q_preds = self._critic_forward( + observations=observations, + actions=actions_pi, + use_target=False, + observation_features=observation_features, + ) + min_q_preds = q_preds.min(dim=0)[0] + + actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() + return actor_loss + + def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor: + """Compute the temperature loss""" + observations = batch["state"] + observation_features = batch.get("observation_feature") + + # calculate temperature loss + with torch.no_grad(): + _, log_probs, _ = self.policy.actor(observations, observation_features) + + temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() + return temperature_loss + + def _update_target_networks(self) -> None: + """Update target networks with exponential moving average""" + for target_p, p in zip( + self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True + ): + target_p.data.copy_( + p.data * self.config.critic_target_update_weight + + target_p.data * (1.0 - self.config.critic_target_update_weight) + ) + if self.policy_config.num_discrete_actions is not None: + for target_p, p in zip( + self.discrete_critic_target.parameters(), + self.policy.discrete_critic.parameters(), + strict=True, + ): + target_p.data.copy_( + p.data * self.config.critic_target_update_weight + + target_p.data * (1.0 - self.config.critic_target_update_weight) + ) + + def _prepare_forward_batch( + self, batch: BatchType, *, include_complementary_info: bool = True + ) -> dict[str, Any]: + observations = batch["state"] + next_observations = batch["next_state"] + observation_features, next_observation_features = self.get_observation_features( + observations, next_observations + ) + forward_batch: dict[str, Any] = { + ACTION: batch[ACTION], + "reward": batch["reward"], + "state": observations, + "next_state": next_observations, + "done": batch["done"], + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + } + if include_complementary_info and "complementary_info" in batch: + forward_batch["complementary_info"] = batch["complementary_info"] + return forward_batch + + def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]: + """ + Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. + + This function sets up Adam optimizers for: + - The **actor network**, ensuring that only relevant parameters are optimized. + - The **critic ensemble**, which evaluates the value function. + - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods. + + It also initializes a learning rate scheduler, though currently, it is set to `None`. + + NOTE: + - If the encoder is shared, its parameters are excluded from the actor's optimization process. + - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor. 
+ + Args: + cfg: Configuration object containing hyperparameters. + policy (nn.Module): The policy model containing the actor, critic, and temperature components. + + Returns: + A dictionary mapping component names ("actor", "critic", "temperature") + to their respective Adam optimizers. + """ + actor_params = self.policy.get_optim_params()["actor"] + self.optimizers = { + "actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr), + "critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr), + "temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr), + } + if self.policy_config.num_discrete_actions is not None: + self.optimizers["discrete_critic"] = torch.optim.Adam( + self.policy.discrete_critic.parameters(), lr=self.config.critic_lr + ) + return self.optimizers + + def get_optimizers(self) -> dict[str, Optimizer]: + return self.optimizers + + def get_weights(self) -> dict[str, Any]: + """Send actor + discrete-critic state dicts.""" + state_dicts: dict[str, Any] = { + "policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"), + } + if self.policy_config.num_discrete_actions is not None: + state_dicts["discrete_critic"] = move_state_dict_to_device( + self.policy.discrete_critic.state_dict(), device="cpu" + ) + return state_dicts + + def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: + """Load actor + discrete-critic weights into the policy.""" + actor_sd = move_state_dict_to_device(weights["policy"], device=device) + self.policy.actor.load_state_dict(actor_sd) + if "discrete_critic" in weights and self.policy.discrete_critic is not None: + discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device) + self.policy.discrete_critic.load_state_dict(discrete_sd) + + def state_dict(self) -> dict[str, torch.Tensor]: + """Algorithm-owned trainable tensors. + + Encoder weights are stripped because they are owned by the policy + (``policy.encoder_critic``) and already saved via ``policy.save_pretrained``. + """ + bundle: dict[str, torch.Tensor] = {} + for k, v in _strip_encoder_keys(self.critic_ensemble.state_dict()).items(): + bundle[f"critic_ensemble.{k}"] = v + for k, v in _strip_encoder_keys(self.critic_target.state_dict()).items(): + bundle[f"critic_target.{k}"] = v + if self.discrete_critic_target is not None: + for k, v in _strip_encoder_keys(self.discrete_critic_target.state_dict()).items(): + bundle[f"discrete_critic_target.{k}"] = v + bundle["log_alpha"] = self.log_alpha.detach() + return bundle + + def load_state_dict( + self, + state_dict: dict[str, torch.Tensor], + device: str | torch.device = "cpu", + ) -> None: + """In-place load of algorithm-owned tensors. + + ``log_alpha`` is restored via ``Parameter.data.copy_`` so the + ``temperature`` optimizer's reference to the parameter object stays + valid after resume. 
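For orientation, the flat bundle produced by ``state_dict()`` above and consumed here looks roughly like the following (the exact sub-module key paths depend on the MLP internals and are illustrative, not an excerpt):

    bundle = algorithm.state_dict()
    # {
    #     "critic_ensemble.critics.0...": Tensor,   # per-head weights, encoder.* stripped
    #     "critic_target.critics.0...": Tensor,
    #     "discrete_critic_target...": Tensor,      # only when num_discrete_actions is set
    #     "log_alpha": Tensor,                      # temperature parameter, shape [1]
    # }
    algorithm.load_state_dict(bundle, device="cpu")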
+ """ + critic_ensemble_state = _split_prefix(state_dict, "critic_ensemble.") + critic_target_state = _split_prefix(state_dict, "critic_target.") + self.critic_ensemble.load_state_dict(critic_ensemble_state, strict=False) + self.critic_target.load_state_dict(critic_target_state, strict=False) + + if self.discrete_critic_target is not None: + discrete_target_state = _split_prefix(state_dict, "discrete_critic_target.") + self.discrete_critic_target.load_state_dict(discrete_target_state, strict=False) + + if "log_alpha" in state_dict: + self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.log_alpha.device)) + + def get_observation_features( + self, observations: Tensor, next_observations: Tensor + ) -> tuple[Tensor | None, Tensor | None]: + """ + Get observation features from the policy encoder. It act as cache for the observation features. + when the encoder is frozen, the observation features are not updated. + We can save compute by caching the observation features. + + Args: + policy: The policy model + observations: The current observations + next_observations: The next observations + + Returns: + tuple: observation_features, next_observation_features + """ + + if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder: + return None, None + + with torch.no_grad(): + observation_features = self.policy.actor.encoder.get_cached_image_features(observations) + next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations) + + return observation_features, next_observation_features + + +def _strip_encoder_keys(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Drop ``encoder.*`` keys from a critic-module state dict.""" + return {k: v for k, v in state.items() if not k.startswith("encoder.")} + + +def _split_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]: + """Return the subset of ``state`` whose keys start with ``prefix``, prefix-stripped.""" + return {k.removeprefix(prefix): v for k, v in state.items() if k.startswith(prefix)} + + +class CriticHead(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + init_final: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output_layer(self.net(x)) + + +class CriticEnsemble(nn.Module): + """ + CriticEnsemble wraps multiple CriticHead modules into an ensemble. + + Args: + encoder (GaussianActorObservationEncoder): encoder for observations. + ensemble (List[CriticHead]): list of critic heads. + init_final (float | None): optional initializer scale for final layers. + + Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. 
+ """ + + def __init__( + self, + encoder: GaussianActorObservationEncoder, + ensemble: list[CriticHead], + init_final: float | None = None, + ): + super().__init__() + self.encoder = encoder + self.init_final = init_final + self.critics = nn.ModuleList(ensemble) + + def forward( + self, + observations: dict[str, torch.Tensor], + actions: torch.Tensor, + observation_features: torch.Tensor | None = None, + ) -> torch.Tensor: + device = get_device_from_parameters(self) + # Move each tensor in observations to device + observations = {k: v.to(device) for k, v in observations.items()} + + obs_enc = self.encoder(observations, cache=observation_features) + + inputs = torch.cat([obs_enc, actions], dim=-1) + + # Loop through critics and collect outputs + q_values = [] + for critic in self.critics: + q_values.append(critic(inputs)) + + # Stack outputs to match expected shape [num_critics, batch_size] + q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0) + return q_values diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 05b8419bd..cec80b723 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -97,8 +97,8 @@ class ReplayBuffer: Args: capacity (int): Maximum number of transitions to store in the buffer. device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu"). - state_keys (List[str]): The list of keys that appear in `state` and `next_state`. - image_augmentation_function (Optional[Callable]): A function that takes a batch of images + state_keys (list[str]): The list of keys that appear in `state` and `next_state`. + image_augmentation_function (Callable | None): A function that takes a batch of images and returns a batch of augmented images. If None, a default augmentation function is used. use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer. storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored. @@ -634,7 +634,7 @@ class ReplayBuffer: If None, you must handle or define default keys. Returns: - transitions (List[Transition]): + transitions (list[Transition]): A list of Transition dictionaries with the same length as `dataset`. """ if state_keys is None: diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py index cc808bcb0..eece13a4c 100644 --- a/src/lerobot/rl/crop_dataset_roi.py +++ b/src/lerobot/rl/crop_dataset_roi.py @@ -176,11 +176,11 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset( Args: original_dataset (LeRobotDataset): The source dataset. - crop_params_dict (Dict[str, Tuple[int, int, int, int]]): + crop_params_dict (dict[str, Tuple[int, int, int, int]]): A dictionary mapping observation keys to crop parameters (top, left, height, width). new_repo_id (str): Repository id for the new dataset. new_dataset_root (str): The root directory where the new dataset will be written. - resize_size (Tuple[int, int], optional): The target size (height, width) after cropping. + resize_size (tuple[int, int], optional): The target size (height, width) after cropping. Defaults to (128, 128). Returns: diff --git a/src/lerobot/rl/data_sources/__init__.py b/src/lerobot/rl/data_sources/__init__.py new file mode 100644 index 000000000..97cfe5001 --- /dev/null +++ b/src/lerobot/rl/data_sources/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.types import BatchType + +from .data_mixer import DataMixer, OnlineOfflineMixer + +__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"] diff --git a/src/lerobot/rl/data_sources/data_mixer.py b/src/lerobot/rl/data_sources/data_mixer.py new file mode 100644 index 000000000..57a2d86be --- /dev/null +++ b/src/lerobot/rl/data_sources/data_mixer.py @@ -0,0 +1,97 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc + +from lerobot.types import BatchType + +from ..buffer import ReplayBuffer, concatenate_batch_transitions + + +class DataMixer(abc.ABC): + """Abstract interface for all data mixing strategies.""" + + @abc.abstractmethod + def sample(self, batch_size: int) -> BatchType: + """Draw one batch of ``batch_size`` transitions.""" + raise NotImplementedError + + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """Infinite iterator that yields batches.""" + while True: + yield self.sample(batch_size) + + +class OnlineOfflineMixer(DataMixer): + """Mixes transitions from an online and an offline replay buffer.""" + + def __init__( + self, + online_buffer: ReplayBuffer, + offline_buffer: ReplayBuffer | None = None, + online_ratio: float = 1.0, + ): + if not 0.0 <= online_ratio <= 1.0: + raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}") + self.online_buffer = online_buffer + self.offline_buffer = offline_buffer + self.online_ratio = online_ratio + + def sample(self, batch_size: int) -> BatchType: + if self.offline_buffer is None: + return self.online_buffer.sample(batch_size) + + n_online = max(1, int(batch_size * self.online_ratio)) + n_offline = batch_size - n_online + + online_batch = self.online_buffer.sample(n_online) + offline_batch = self.offline_buffer.sample(n_offline) + return concatenate_batch_transitions(online_batch, offline_batch) + + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """Yield batches by composing buffer async iterators.""" + + n_online = max(1, int(batch_size * self.online_ratio)) + + online_iter = self.online_buffer.get_iterator( + batch_size=n_online, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + if self.offline_buffer is None: + yield from online_iter + return + + n_offline = batch_size - n_online + offline_iter = 
self.offline_buffer.get_iterator( + batch_size=n_offline, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + while True: + yield concatenate_batch_transitions(next(online_iter), next(offline_iter)) diff --git a/src/lerobot/rl/eval_policy.py b/src/lerobot/rl/eval_policy.py index 4398351c5..0f42d7573 100644 --- a/src/lerobot/rl/eval_policy.py +++ b/src/lerobot/rl/eval_policy.py @@ -17,7 +17,6 @@ import logging from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser -from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.datasets import LeRobotDataset from lerobot.policies import make_policy from lerobot.robots import ( # noqa: F401 @@ -31,6 +30,7 @@ from lerobot.teleoperators import ( ) from .gym_manipulator import make_robot_env +from .train_rl import TrainRLServerPipelineConfig logging.basicConfig(level=logging.INFO) diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 2190070f5..03f7b4eea 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -74,6 +74,7 @@ from lerobot.teleoperators import ( from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD +from lerobot.utils.import_utils import require_package from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say @@ -312,6 +313,7 @@ def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]: # Check if this is a GymHIL simulation environment if cfg.name == "gym_hil": assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop" + require_package("gym-hil", extra="hilserl", import_name="gym_hil") import gym_hil # noqa: F401 # Extract gripper settings with defaults @@ -383,10 +385,21 @@ def make_processors( GymHILAdapterProcessorStep(), Numpy2TorchActionProcessorStep(), VanillaObservationProcessorStep(), - AddBatchDimensionProcessorStep(), - DeviceProcessorStep(device=device), ] + # Add time limit processor if reset config exists + if cfg.processor.reset is not None: + env_pipeline_steps.append( + TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps)) + ) + + env_pipeline_steps.extend( + [ + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=device), + ] + ) + return DataProcessorPipeline( steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition ), DataProcessorPipeline( @@ -551,8 +564,19 @@ def step_env_and_process_transition( terminated = terminated or processed_action_transition[TransitionKey.DONE] truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED] complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy() + + if hasattr(env, "get_raw_joint_positions"): + raw_joint_positions = env.get_raw_joint_positions() + if raw_joint_positions is not None: + complementary_data["raw_joint_positions"] = raw_joint_positions + + # Merge env and action-processor info: env wins for str keys, action-processor + # wins for `TeleopEvents` enum keys + action_info = processed_action_transition[TransitionKey.INFO] new_info = info.copy() - new_info.update(processed_action_transition[TransitionKey.INFO]) + for key, value in action_info.items(): + if isinstance(key, TeleopEvents): + new_info[key] = value new_transition = create_transition( observation=obs, @@ -568,6 +592,24 @@ def 
step_env_and_process_transition( return new_transition +def reset_and_build_transition( + env: gym.Env, + env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], + action_processor: DataProcessorPipeline[EnvTransition, EnvTransition], +) -> EnvTransition: + """Reset env + processors and return the first env-processed transition.""" + obs, info = env.reset() + env_processor.reset() + action_processor.reset() + complementary_data: dict[str, Any] = {} + if hasattr(env, "get_raw_joint_positions"): + raw_joint_positions = env.get_raw_joint_positions() + if raw_joint_positions is not None: + complementary_data["raw_joint_positions"] = raw_joint_positions + transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) + return env_processor(data=transition) + + def control_loop( env: gym.Env, env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], @@ -593,17 +635,7 @@ def control_loop( print("- When not intervening, robot will stay still") print("- Press Ctrl+C to exit") - # Reset environment and processors - obs, info = env.reset() - complementary_data = ( - {"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {} - ) - env_processor.reset() - action_processor.reset() - - # Process initial observation - transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) - transition = env_processor(data=transition) + transition = reset_and_build_transition(env, env_processor, action_processor) # Determine if gripper is used use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True @@ -659,79 +691,82 @@ def control_loop( episode_step = 0 episode_start_time = time.perf_counter() - while episode_idx < cfg.dataset.num_episodes_to_record: - step_start_time = time.perf_counter() + try: + while episode_idx < cfg.dataset.num_episodes_to_record: + step_start_time = time.perf_counter() - # Create a neutral action (no movement) - neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) - if use_gripper: - neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay + # Create a neutral action (no movement) + neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) + if use_gripper: + neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay - # Use the new step function - transition = step_env_and_process_transition( - env=env, - transition=transition, - action=neutral_action, - env_processor=env_processor, - action_processor=action_processor, - ) - terminated = transition.get(TransitionKey.DONE, False) - truncated = transition.get(TransitionKey.TRUNCATED, False) - - if cfg.mode == "record": - observations = { + observation = { k: v.squeeze(0).cpu() for k, v in transition[TransitionKey.OBSERVATION].items() if isinstance(v, torch.Tensor) } - # Use teleop_action if available, otherwise use the action from the transition - action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get( - "teleop_action", transition[TransitionKey.ACTION] + + transition = step_env_and_process_transition( + env=env, + transition=transition, + action=neutral_action, + env_processor=env_processor, + action_processor=action_processor, ) - frame = { - **observations, - ACTION: action_to_record.cpu(), - REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32), - DONE: np.array([terminated or truncated], dtype=bool), - } - if use_gripper: - discrete_penalty = 
transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0) - frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32) + terminated = transition.get(TransitionKey.DONE, False) + truncated = transition.get(TransitionKey.TRUNCATED, False) - if dataset is not None: - frame["task"] = cfg.dataset.task - dataset.add_frame(frame) + if cfg.mode == "record": + action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get( + "teleop_action", transition[TransitionKey.ACTION] + ) + frame = { + **observation, + ACTION: action_to_record.cpu(), + REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32), + DONE: np.array([terminated or truncated], dtype=bool), + } + if use_gripper: + discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get( + "discrete_penalty", 0.0 + ) + frame["complementary_info.discrete_penalty"] = np.array( + [discrete_penalty], dtype=np.float32 + ) - episode_step += 1 + if dataset is not None: + frame["task"] = cfg.dataset.task + dataset.add_frame(frame) - # Handle episode termination - if terminated or truncated: - episode_time = time.perf_counter() - episode_start_time - logging.info( - f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}" - ) - episode_step = 0 - episode_idx += 1 + episode_step += 1 - if dataset is not None: - if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False): - logging.info(f"Re-recording episode {episode_idx}") - dataset.clear_episode_buffer() - episode_idx -= 1 - else: - logging.info(f"Saving episode {episode_idx}") - dataset.save_episode() + # Handle episode termination + if terminated or truncated: + episode_time = time.perf_counter() - episode_start_time + logging.info( + f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}" + ) + episode_step = 0 + episode_idx += 1 - # Reset for new episode - obs, info = env.reset() - env_processor.reset() - action_processor.reset() + if dataset is not None: + if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False): + logging.info(f"Re-recording episode {episode_idx}") + dataset.clear_episode_buffer() + episode_idx -= 1 + else: + logging.info(f"Saving episode {episode_idx}") + dataset.save_episode() - transition = create_transition(observation=obs, info=info) - transition = env_processor(transition) + # Reset for new episode + transition = reset_and_build_transition(env, env_processor, action_processor) - # Maintain fps timing - precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) + # Maintain fps timing + precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) + finally: + if dataset is not None and dataset.writer is not None and dataset.writer.image_writer is not None: + logging.info("Waiting for image writer to finish...") + dataset.writer.image_writer.stop() if dataset is not None and cfg.dataset.push_to_hub: logging.info("Finalizing dataset before pushing to hub") diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 14542576d..41cfd8c03 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -51,9 +51,21 @@ import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from pprint import pformat +from typing import TYPE_CHECKING, Any + +from lerobot.utils.import_utils import _grpc_available, require_package + +if TYPE_CHECKING or _grpc_available: + import grpc + + from 
lerobot.transport import services_pb2_grpc +else: + grpc = None + services_pb2_grpc = None -import grpc import torch +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from safetensors.torch import load_file as load_safetensors from termcolor import colored from torch import nn from torch.multiprocessing import Queue @@ -68,14 +80,11 @@ from lerobot.common.train_utils import ( ) from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser -from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.datasets import LeRobotDataset, make_dataset -from lerobot.policies import make_policy -from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.policies import make_policy, make_pre_post_processors from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents -from lerobot.transport import services_pb2_grpc from lerobot.transport.utils import ( MAX_MESSAGE_SIZE, bytes_to_python_object, @@ -84,26 +93,35 @@ from lerobot.transport.utils import ( ) from lerobot.utils.constants import ( ACTION, + ALGORITHM_DIR, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR, + TRAINING_STEP, ) from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.io_utils import load_json, write_json from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed -from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.utils import ( format_big_number, init_logging, ) -from .buffer import ReplayBuffer, concatenate_batch_transitions +from .algorithms.base import RLAlgorithm +from .algorithms.factory import make_algorithm +from .buffer import ReplayBuffer +from .data_sources import OnlineOfflineMixer from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService +from .train_rl import TrainRLServerPipelineConfig +from .trainer import RLTrainer @parser.wrap() def train_cli(cfg: TrainRLServerPipelineConfig): + # Fail fast with a friendly error if the optional ``hilserl`` extra is missing. + require_package("grpcio", extra="hilserl", import_name="grpc") if not use_threads(cfg): import torch.multiprocessing as mp @@ -179,7 +197,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None): def start_learner_threads( cfg: TrainRLServerPipelineConfig, wandb_logger: WandBLogger | None, - shutdown_event: any, # Event, + shutdown_event: Any, # Event ) -> None: """ Start the learner threads for training. @@ -253,7 +271,7 @@ def start_learner_threads( def add_actor_information_and_train( cfg: TrainRLServerPipelineConfig, wandb_logger: WandBLogger | None, - shutdown_event: any, # Event, + shutdown_event: Any, # Event transition_queue: Queue, interaction_message_queue: Queue, parameters_queue: Queue, @@ -266,8 +284,8 @@ def add_actor_information_and_train( - Transfers transitions from the actor to the replay buffer. - Logs received interaction messages. - Ensures training begins only when the replay buffer has a sufficient number of transitions. - - Samples batches from the replay buffer and performs multiple critic updates. - - Periodically updates the actor, critic, and temperature optimizers. + - Delegates training updates to an ``RLAlgorithm``. + - Periodically pushes updated weights to actors. - Logs training statistics, including loss values and optimization frequency. 
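In rough outline, those responsibilities are now wired together as follows (a condensed, illustrative sketch of the code below, not an exact excerpt):

    policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
    algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
    data_mixer = OnlineOfflineMixer(
        online_buffer=replay_buffer,
        offline_buffer=offline_replay_buffer,
        online_ratio=cfg.online_ratio,
    )
    trainer = RLTrainer(
        algorithm=algorithm, data_mixer=data_mixer, batch_size=batch_size, preprocessor=preprocessor
    )
    stats = trainer.training_step()       # one optimization step, UTD loop included
    training_infos = stats.to_log_dict()  # flat dict for logging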
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions @@ -286,17 +304,13 @@ def add_actor_information_and_train( # of 7% device = get_safe_torch_device(try_device=cfg.policy.device, log=True) storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device) - clip_grad_norm_value = cfg.policy.grad_clip_norm online_step_before_learning = cfg.policy.online_step_before_learning - utd_ratio = cfg.policy.utd_ratio fps = cfg.env.fps log_freq = cfg.log_freq save_freq = cfg.save_freq - policy_update_freq = cfg.policy.policy_update_freq policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency saving_checkpoint = cfg.save_checkpoint online_steps = cfg.policy.online_steps - async_prefetch = cfg.policy.async_prefetch # Initialize logging for multiprocessing if not use_threads(cfg): @@ -308,7 +322,7 @@ def add_actor_information_and_train( logging.info("Initializing policy") - policy: SACPolicy = make_policy( + policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, ) @@ -317,15 +331,17 @@ def add_actor_information_and_train( policy.train() - push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy) + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + dataset_stats=cfg.policy.dataset_stats, + ) + + # Push initial policy weights to actors + push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm) last_time_policy_pushed = time.time() - optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) - - # If we are resuming, we need to load the training state - resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) - log_training_info(cfg=cfg, policy=policy) replay_buffer = initialize_replay_buffer(cfg, device, storage_device) @@ -338,21 +354,37 @@ def add_actor_information_and_train( device=device, storage_device=storage_device, ) - batch_size: int = batch_size // 2 # We will sample from both replay buffer + + # DataMixer: online-only or online/offline 50-50 mix + data_mixer = OnlineOfflineMixer( + online_buffer=replay_buffer, + offline_buffer=offline_replay_buffer, + online_ratio=cfg.online_ratio, + ) + # RLTrainer owns the iterator, preprocessor, and creates optimizers. 
+ trainer = RLTrainer( + algorithm=algorithm, + data_mixer=data_mixer, + batch_size=batch_size, + preprocessor=preprocessor, + ) + + # If we are resuming, we need to load the training state + optimizers = algorithm.get_optimizers() + resume_optimization_step, resume_interaction_step = load_training_state( + cfg=cfg, optimizers=optimizers, algorithm=algorithm, device=device + ) logging.info("Starting learner thread") interaction_message = None optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 + algorithm.optimization_step = optimization_step interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 dataset_repo_id = None if cfg.dataset is not None: dataset_repo_id = cfg.dataset.repo_id - # Initialize iterators - online_iterator = None - offline_iterator = None - # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER while True: # Exit the training loop if shutdown is requested @@ -365,7 +397,6 @@ def add_actor_information_and_train( transition_queue=transition_queue, replay_buffer=replay_buffer, offline_replay_buffer=offline_replay_buffer, - device=device, dataset_repo_id=dataset_repo_id, shutdown_event=shutdown_event, ) @@ -382,180 +413,20 @@ def add_actor_information_and_train( if len(replay_buffer) < online_step_before_learning: continue - if online_iterator is None: - online_iterator = replay_buffer.get_iterator( - batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 - ) - - if offline_replay_buffer is not None and offline_iterator is None: - offline_iterator = offline_replay_buffer.get_iterator( - batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 - ) - time_for_one_optimization_step = time.time() - for _ in range(utd_ratio - 1): - # Sample from the iterators - batch = next(online_iterator) - if dataset_repo_id is not None: - batch_offline = next(offline_iterator) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - - actions = batch[ACTION] - rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] - done = batch["done"] - check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) - - observation_features, next_observation_features = get_observation_features( - policy=policy, observations=observations, next_observations=next_observations - ) - - # Create a batch dictionary with all required elements for the forward method - forward_batch = { - ACTION: actions, - "reward": rewards, - "state": observations, - "next_state": next_observations, - "done": done, - "observation_feature": observation_features, - "next_observation_feature": next_observation_features, - "complementary_info": batch["complementary_info"], - } - - # Use the forward method for critic loss - critic_output = policy.forward(forward_batch, model="critic") - - # Main critic optimization - loss_critic = critic_output["loss_critic"] - optimizers["critic"].zero_grad() - loss_critic.backward() - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value - ) - optimizers["critic"].step() - - # Discrete critic optimization (if available) - if policy.config.num_discrete_actions is not None: - discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") - loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] - optimizers["discrete_critic"].zero_grad() - loss_discrete_critic.backward() 
- discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value - ) - optimizers["discrete_critic"].step() - - # Update target networks (main and discrete) - policy.update_target_networks() - - # Sample for the last update in the UTD ratio - batch = next(online_iterator) - - if dataset_repo_id is not None: - batch_offline = next(offline_iterator) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - - actions = batch[ACTION] - rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] - done = batch["done"] - - check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) - - observation_features, next_observation_features = get_observation_features( - policy=policy, observations=observations, next_observations=next_observations - ) - - # Create a batch dictionary with all required elements for the forward method - forward_batch = { - ACTION: actions, - "reward": rewards, - "state": observations, - "next_state": next_observations, - "done": done, - "observation_feature": observation_features, - "next_observation_feature": next_observation_features, - } - - critic_output = policy.forward(forward_batch, model="critic") - - loss_critic = critic_output["loss_critic"] - optimizers["critic"].zero_grad() - loss_critic.backward() - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["critic"].step() - - # Initialize training info dictionary - training_infos = { - "loss_critic": loss_critic.item(), - "critic_grad_norm": critic_grad_norm, - } - - # Discrete critic optimization (if available) - if policy.config.num_discrete_actions is not None: - discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") - loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] - optimizers["discrete_critic"].zero_grad() - loss_discrete_critic.backward() - discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["discrete_critic"].step() - - # Add discrete critic info to training info - training_infos["loss_discrete_critic"] = loss_discrete_critic.item() - training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm - - # Actor and temperature optimization (at specified frequency) - if optimization_step % policy_update_freq == 0: - for _ in range(policy_update_freq): - # Actor optimization - actor_output = policy.forward(forward_batch, model="actor") - loss_actor = actor_output["loss_actor"] - optimizers["actor"].zero_grad() - loss_actor.backward() - actor_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value - ).item() - optimizers["actor"].step() - - # Add actor info to training info - training_infos["loss_actor"] = loss_actor.item() - training_infos["actor_grad_norm"] = actor_grad_norm - - # Temperature optimization - temperature_output = policy.forward(forward_batch, model="temperature") - loss_temperature = temperature_output["loss_temperature"] - optimizers["temperature"].zero_grad() - loss_temperature.backward() - temp_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=[policy.log_alpha], max_norm=clip_grad_norm_value - ).item() - optimizers["temperature"].step() - - # Add 
temperature info to training info - training_infos["loss_temperature"] = loss_temperature.item() - training_infos["temperature_grad_norm"] = temp_grad_norm - training_infos["temperature"] = policy.temperature + # One training step (trainer owns data_mixer iterator; algorithm owns UTD loop) + stats = trainer.training_step() # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: - push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm) last_time_policy_pushed = time.time() - # Update target networks (main and discrete) - policy.update_target_networks() + training_infos = stats.to_log_dict() # Log training metrics at specified intervals + optimization_step = algorithm.optimization_step if optimization_step % log_freq == 0: training_infos["replay_buffer_size"] = len(replay_buffer) if offline_replay_buffer is not None: @@ -583,7 +454,6 @@ def add_actor_information_and_train( custom_step_key="Optimization step", ) - optimization_step += 1 if optimization_step % log_freq == 0: logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") @@ -597,9 +467,12 @@ def add_actor_information_and_train( policy=policy, optimizers=optimizers, replay_buffer=replay_buffer, + algorithm=algorithm, offline_replay_buffer=offline_replay_buffer, dataset_repo_id=dataset_repo_id, fps=fps, + preprocessor=preprocessor, + postprocessor=postprocessor, ) @@ -607,7 +480,7 @@ def start_learner( parameters_queue: Queue, transition_queue: Queue, interaction_message_queue: Queue, - shutdown_event: any, # Event, + shutdown_event: Any, # Event cfg: TrainRLServerPipelineConfig, ): """ @@ -681,9 +554,12 @@ def save_training_checkpoint( policy: nn.Module, optimizers: dict[str, Optimizer], replay_buffer: ReplayBuffer, + algorithm: RLAlgorithm | None = None, offline_replay_buffer: ReplayBuffer | None = None, dataset_repo_id: str | None = None, fps: int = 30, + preprocessor=None, + postprocessor=None, ) -> None: """ Save training checkpoint and associated data. @@ -707,6 +583,8 @@ def save_training_checkpoint( offline_replay_buffer: Optional offline replay buffer to save dataset_repo_id: Repository ID for dataset fps: Frames per second for dataset + preprocessor: Optional preprocessor pipeline to save + postprocessor: Optional postprocessor pipeline to save """ logging.info(f"Checkpoint policy after step {optimization_step}") _num_digits = max(6, len(str(online_steps))) @@ -715,7 +593,7 @@ def save_training_checkpoint( # Create checkpoint directory checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step) - # Save checkpoint + # Save policy artifacts (pretrained_model/) + Trainer scaffolding (training_state/). save_checkpoint( checkpoint_dir=checkpoint_dir, step=optimization_step, @@ -723,13 +601,22 @@ def save_training_checkpoint( policy=policy, optimizer=optimizers, scheduler=None, + preprocessor=preprocessor, + postprocessor=postprocessor, ) - # Save interaction step manually - training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR) - os.makedirs(training_state_dir, exist_ok=True) - training_state = {"step": optimization_step, "interaction_step": interaction_step} - torch.save(training_state, os.path.join(training_state_dir, "training_state.pt")) + # Algorithm-owned tensors live in their own component subfolder + # so they can be `push_to_hub`'d independently and don't bloat the inference artifact. 
+ if algorithm is not None: + algorithm.save_pretrained(checkpoint_dir / ALGORITHM_DIR) + + # Enrich training_step.json with the RL-specific interaction_step counter so + # both can be restored from a single file. + training_state_dir = checkpoint_dir / TRAINING_STATE_DIR + write_json( + {"step": optimization_step, "interaction_step": interaction_step}, + training_state_dir / TRAINING_STEP, + ) # Update the "last" symlink update_last_checkpoint(checkpoint_dir) @@ -760,58 +647,6 @@ def save_training_checkpoint( logging.info("Resume training") -def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module): - """ - Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. - - This function sets up Adam optimizers for: - - The **actor network**, ensuring that only relevant parameters are optimized. - - The **critic ensemble**, which evaluates the value function. - - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods. - - It also initializes a learning rate scheduler, though currently, it is set to `None`. - - NOTE: - - If the encoder is shared, its parameters are excluded from the actor's optimization process. - - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor. - - Args: - cfg: Configuration object containing hyperparameters. - policy (nn.Module): The policy model containing the actor, critic, and temperature components. - - Returns: - Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]: - A tuple containing: - - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers. - - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling. - - """ - optimizer_actor = torch.optim.Adam( - params=[ - p - for n, p in policy.actor.named_parameters() - if not policy.config.shared_encoder or not n.startswith("encoder") - ], - lr=cfg.policy.actor_lr, - ) - optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - - if cfg.policy.num_discrete_actions is not None: - optimizer_discrete_critic = torch.optim.Adam( - params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr - ) - optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) - lr_scheduler = None - optimizers = { - "actor": optimizer_actor, - "critic": optimizer_critic, - "temperature": optimizer_temperature, - } - if cfg.policy.num_discrete_actions is not None: - optimizers["discrete_critic"] = optimizer_discrete_critic - return optimizers, lr_scheduler - - # Training setup functions @@ -875,13 +710,20 @@ def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipeli def load_training_state( cfg: TrainRLServerPipelineConfig, optimizers: Optimizer | dict[str, Optimizer], + algorithm: RLAlgorithm | None = None, + device: str | torch.device = "cpu", ): """ - Loads the training state (optimizers, step count, etc.) from a checkpoint. + Loads the training state (optimizers, RNG, step + interaction step, and + algorithm-owned tensors) from the most recent checkpoint. Args: - cfg (TrainRLServerPipelineConfig): Training configuration - optimizers (Optimizer | dict): Optimizers to load state into + cfg: Training configuration. + optimizers: Optimizers to load state into. 
+ algorithm: Algorithm whose state dict should be restored. + Required for full main-equivalent resume; + the policy itself is restored separately via ``make_policy``. + device: Device on which to place loaded algorithm tensors. Returns: tuple: (optimization_step, interaction_step) or (None, None) if not resuming @@ -890,20 +732,31 @@ def load_training_state( return None, None # Construct path to the last checkpoint directory - checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + checkpoint_dir = Path(cfg.output_dir) / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK logging.info(f"Loading training state from {checkpoint_dir}") try: - # Use the utility function from train_utils which loads the optimizer state - step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) + # Restore optimizers + RNG + step from the standard `training_state/` folder + step, optimizers, _ = utils_load_training_state(checkpoint_dir, optimizers, None) - # Load interaction step separately from training_state.pt - training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") - interaction_step = 0 - if os.path.exists(training_state_path): - training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load - interaction_step = training_state.get("interaction_step", 0) + # Restore algorithm-owned tensors + if algorithm is not None: + algo_dir = checkpoint_dir / ALGORITHM_DIR + if algo_dir.is_dir(): + tensors = load_safetensors(str(algo_dir / SAFETENSORS_SINGLE_FILE)) + algorithm.load_state_dict(tensors, device=device) + logging.info(f"Loaded algorithm state from {algo_dir}") + else: + logging.warning( + f"No algorithm state found at {algo_dir}; " + "will keep their freshly-initialised values. Adam moments restored from the " + "old optimizer state may not match these reset parameters." + ) + + # Read interaction_step from the enriched training_step.json + training_step_path = checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP + interaction_step = int(load_json(training_step_path).get("interaction_step", 0)) logging.info(f"Resuming from step {step}, interaction step {interaction_step}") return step, interaction_step @@ -1016,33 +869,6 @@ def initialize_offline_replay_buffer( # Utilities/Helpers functions -def get_observation_features( - policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor -) -> tuple[torch.Tensor | None, torch.Tensor | None]: - """ - Get observation features from the policy encoder. It act as cache for the observation features. - when the encoder is frozen, the observation features are not updated. - We can save compute by caching the observation features. 
- - Args: - policy: The policy model - observations: The current observations - next_observations: The next observations - - Returns: - tuple: observation_features, next_observation_features - """ - - if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: - return None, None - - with torch.no_grad(): - observation_features = policy.actor.encoder.get_cached_image_features(observations) - next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations) - - return observation_features, next_observation_features - - def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: return cfg.policy.concurrency.learner == "threads" @@ -1093,19 +919,11 @@ def check_nan_in_transition( return nan_detected -def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): +def push_actor_policy_to_queue(parameters_queue: Queue, algorithm: RLAlgorithm) -> None: logging.debug("[LEARNER] Pushing actor policy to the queue") # Create a dictionary to hold all the state dicts - state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")} - - # Add discrete critic if it exists - if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None: - state_dicts["discrete_critic"] = move_state_dict_to_device( - policy.discrete_critic.state_dict(), device="cpu" - ) - logging.debug("[LEARNER] Including discrete critic in state dict push") - + state_dicts = algorithm.get_weights() state_bytes = state_to_bytes(state_dicts) parameters_queue.put(state_bytes) @@ -1129,9 +947,8 @@ def process_transitions( transition_queue: Queue, replay_buffer: ReplayBuffer, offline_replay_buffer: ReplayBuffer, - device: str, dataset_repo_id: str | None, - shutdown_event: any, + shutdown_event: Any, # Event ): """Process all available transitions from the queue. @@ -1139,7 +956,6 @@ def process_transitions( transition_queue: Queue for receiving transitions from the actor replay_buffer: Replay buffer to add transitions to offline_replay_buffer: Offline replay buffer to add transitions to - device: Device to move transitions to dataset_repo_id: Repository ID for dataset shutdown_event: Event to signal shutdown """ @@ -1148,8 +964,6 @@ def process_transitions( transition_list = bytes_to_transitions(buffer=transition_list) for transition in transition_list: - transition = move_transition_to_device(transition=transition, device=device) - # Skip transitions with NaN values if check_nan_in_transition( observations=transition["state"], @@ -1163,7 +977,7 @@ def process_transitions( # Add to offline buffer if it's an intervention if dataset_repo_id is not None and transition.get("complementary_info", {}).get( - TeleopEvents.IS_INTERVENTION + TeleopEvents.IS_INTERVENTION.value ): offline_replay_buffer.add(**transition) @@ -1172,7 +986,7 @@ def process_interaction_messages( interaction_message_queue: Queue, interaction_step_shift: int, wandb_logger: WandBLogger | None, - shutdown_event: any, + shutdown_event: Any, # Event ) -> dict | None: """Process all available interaction messages from the queue. 
diff --git a/src/lerobot/rl/learner_service.py b/src/lerobot/rl/learner_service.py index 4128cdf55..7a4df7136 100644 --- a/src/lerobot/rl/learner_service.py +++ b/src/lerobot/rl/learner_service.py @@ -18,17 +18,32 @@ import logging import time from multiprocessing import Event, Queue +from typing import TYPE_CHECKING -from lerobot.transport import services_pb2, services_pb2_grpc -from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks +from lerobot.utils.import_utils import _grpc_available from .queue import get_last_item_from_queue +if TYPE_CHECKING or _grpc_available: + import grpc + + from lerobot.transport import services_pb2, services_pb2_grpc + from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks + + _ServicerBase = services_pb2_grpc.LearnerServiceServicer +else: + grpc = None + services_pb2 = None + services_pb2_grpc = None + receive_bytes_in_chunks = None + send_bytes_in_chunks = None + _ServicerBase = object + MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 -class LearnerService(services_pb2_grpc.LearnerServiceServicer): +class LearnerService(_ServicerBase): """ Implementation of the LearnerService gRPC service This service is used to send parameters to the Actor and receive transitions and interactions from the Actor @@ -51,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): self.interaction_message_queue = interaction_message_queue self.queue_get_timeout = queue_get_timeout - def StreamParameters(self, request, context): # noqa: N802 + def StreamParameters( # noqa: N802 + self, request: "services_pb2.Empty", context: "grpc.ServicerContext" + ): # TODO: authorize the request logging.info("[LEARNER] Received request to stream parameters from the Actor") @@ -86,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): logging.info("[LEARNER] Stream parameters finished") return services_pb2.Empty() - def SendTransitions(self, request_iterator, _context): # noqa: N802 + def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802 # TODO: authorize the request logging.info("[LEARNER] Received request to receive transitions from the Actor") @@ -100,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): logging.debug("[LEARNER] Finished receiving transitions") return services_pb2.Empty() - def SendInteractions(self, request_iterator, _context): # noqa: N802 + def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802 # TODO: authorize the request logging.info("[LEARNER] Received request to receive interactions from the Actor") @@ -114,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): logging.debug("[LEARNER] Finished receiving interactions") return services_pb2.Empty() - def Ready(self, request, context): # noqa: N802 + def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802 return services_pb2.Empty() diff --git a/src/lerobot/rl/train_rl.py b/src/lerobot/rl/train_rl.py new file mode 100644 index 000000000..e5ae0f9f5 --- /dev/null +++ b/src/lerobot/rl/train_rl.py @@ -0,0 +1,50 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Top-level pipeline config for distributed RL training (actor / learner)."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from lerobot.configs.default import DatasetConfig
+from lerobot.configs.train import TrainPipelineConfig
+
+from .algorithms.configs import RLAlgorithmConfig
+from .algorithms.factory import make_algorithm_config
+from .algorithms.sac import SACAlgorithmConfig  # noqa: F401
+
+
+@dataclass(kw_only=True)
+class TrainRLServerPipelineConfig(TrainPipelineConfig):
+    # NOTE: In RL, we don't need an offline dataset
+    # TODO: Make `TrainPipelineConfig.dataset` optional
+    dataset: DatasetConfig | None = None  # type: ignore[assignment] # because the parent class has made its type non-optional
+
+    # Algorithm config.
+    algorithm: RLAlgorithmConfig | None = None
+
+    # Data mixer strategy name. Currently supports "online_offline".
+    mixer: str = "online_offline"
+    # Fraction sampled from online replay when using OnlineOfflineMixer.
+    online_ratio: float = 0.5
+
+    def validate(self) -> None:
+        super().validate()
+
+        if self.algorithm is None:
+            self.algorithm = make_algorithm_config("sac")
+
+        if getattr(self.algorithm, "policy_config", None) is None:
+            self.algorithm.policy_config = self.policy
diff --git a/src/lerobot/rl/trainer.py b/src/lerobot/rl/trainer.py
new file mode 100644
index 000000000..65f00568e
--- /dev/null
+++ b/src/lerobot/rl/trainer.py
@@ -0,0 +1,101 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from typing import Any
+
+from lerobot.types import BatchType
+
+from .algorithms.base import RLAlgorithm
+from .algorithms.configs import TrainingStats
+from .data_sources.data_mixer import DataMixer
+
+
+class RLTrainer:
+    """Unified training step orchestrator.
+
+    Holds the algorithm, a DataMixer, and an optional preprocessor.
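+
+    A minimal usage sketch (illustrative only; ``algorithm``, ``data_mixer`` and
+    ``num_updates`` are assumed to be provided by the caller, e.g. via the algorithm
+    factory and a DataMixer implementation)::
+
+        trainer = RLTrainer(algorithm=algorithm, data_mixer=data_mixer, batch_size=256)
+        for _ in range(num_updates):
+            stats = trainer.training_step()  # samples a batch and runs one update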
+ """ + + def __init__( + self, + algorithm: RLAlgorithm, + data_mixer: DataMixer, + batch_size: int, + *, + preprocessor: Any | None = None, + ): + self.algorithm = algorithm + self.data_mixer = data_mixer + self.batch_size = batch_size + self._preprocessor = preprocessor + + self._iterator: Iterator[BatchType] | None = None + + self.algorithm.make_optimizers_and_scheduler() + + def _build_data_iterator(self) -> Iterator[BatchType]: + """Create a fresh algorithm-configured iterator (optionally preprocessed).""" + raw = self.algorithm.configure_data_iterator( + data_mixer=self.data_mixer, + batch_size=self.batch_size, + ) + if self._preprocessor is not None: + return _PreprocessedIterator(raw, self._preprocessor) + return raw + + def reset_data_iterator(self) -> None: + """Discard the current iterator so it will be rebuilt lazily next step.""" + self._iterator = None + + def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None: + """Swap the active data mixer, optionally resetting the iterator.""" + self.data_mixer = data_mixer + if reset: + self.reset_data_iterator() + + def training_step(self) -> TrainingStats: + """Run one training step (algorithm-agnostic).""" + if self._iterator is None: + self._iterator = self._build_data_iterator() + return self.algorithm.update(self._iterator) + + +def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType: + """Apply policy preprocessing to RL observations only.""" + observations = batch["state"] + next_observations = batch["next_state"] + batch["state"] = preprocessor.process_observation(observations) + batch["next_state"] = preprocessor.process_observation(next_observations) + + return batch + + +class _PreprocessedIterator: + """Iterator wrapper that preprocesses each sampled RL batch.""" + + __slots__ = ("_raw", "_preprocessor") + + def __init__(self, raw_iterator: Iterator[BatchType], preprocessor: Any) -> None: + self._raw = raw_iterator + self._preprocessor = preprocessor + + def __iter__(self) -> _PreprocessedIterator: + return self + + def __next__(self) -> BatchType: + batch = next(self._raw) + return preprocess_rl_batch(self._preprocessor, batch) diff --git a/src/lerobot/robots/so_follower/robot_kinematic_processor.py b/src/lerobot/robots/so_follower/robot_kinematic_processor.py index 8114fdc2c..a95343b2d 100644 --- a/src/lerobot/robots/so_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so_follower/robot_kinematic_processor.py @@ -353,7 +353,8 @@ class GripperVelocityToJoint(RobotActionProcessorStep): speed_factor: A scaling factor to convert the normalized velocity command to a position change. clip_min: The minimum allowed gripper joint position. clip_max: The maximum allowed gripper joint position. - discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay). + discrete_gripper: If True, interpret the input as a discrete class index + {0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`. """ speed_factor: float = 20.0 @@ -377,10 +378,10 @@ class GripperVelocityToJoint(RobotActionProcessorStep): raise ValueError("Joints observation is require for computing robot kinematics") if self.discrete_gripper: - # Discrete gripper actions are in [0, 1, 2] - # 0: open, 1: close, 2: stay - # We need to shift them to [-1, 0, 1] and then scale them to clip_max - gripper_vel = (gripper_vel - 1) * self.clip_max + # Map discrete command {0=close, 1=stay, 2=open} -> signed velocity. 
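+            # Worked example: a command of 0 (close) gives -(0 - 1) * clip_max = +clip_max,
+            # which speed_factor below scales into a positive (closing) position delta.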
+ # Negation accounts for SO100 sign (joint position increases on close). + # 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open) + gripper_vel = -(gripper_vel - 1) * self.clip_max # Compute desired gripper position delta = gripper_vel * float(self.speed_factor) diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 6fc553d38..801789bcb 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -104,11 +104,14 @@ class KeyboardTeleop(Teleoperator): def _on_press(self, key): if hasattr(key, "char"): - self.event_queue.put((key.char, True)) + key = key.char + self.event_queue.put((key, True)) def _on_release(self, key): if hasattr(key, "char"): - self.event_queue.put((key.char, False)) + key = key.char + self.event_queue.put((key, False)) + if key == keyboard.Key.esc: logging.info("ESC pressed, disconnecting.") self.disconnect() @@ -204,8 +207,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): # this is useful for retrieving other events like interventions for RL, episode success, etc. self.misc_keys_queue.put(key) - self.current_pressed.clear() - action_dict = { "delta_x": delta_x, "delta_y": delta_y, @@ -256,6 +257,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): ] is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys) + self.current_pressed.clear() + # Check for episode control commands from misc_keys_queue terminate_episode = False success = False diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md index c59cf4183..f0dd0da07 100644 --- a/src/lerobot/templates/lerobot_modelcard_template.md +++ b/src/lerobot/templates/lerobot_modelcard_template.md @@ -39,8 +39,8 @@ For more details, see the [Physical Intelligence π₀ blog post](https://www.ph π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training. For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05). -{% elif model_name == "sac" %} -[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments. +{% elif model_name == "gaussian_actor" %} +This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms. {% elif model_name == "reward_classifier" %} A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable. 
{% else %} diff --git a/src/lerobot/types.py b/src/lerobot/types.py index d9b8166c5..9de504870 100644 --- a/src/lerobot/types.py +++ b/src/lerobot/types.py @@ -40,6 +40,7 @@ PolicyAction = torch.Tensor RobotAction = dict[str, Any] EnvAction = np.ndarray RobotObservation = dict[str, Any] +BatchType = dict[str, Any] EnvTransition = TypedDict( diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 43869228d..482394ff6 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -47,6 +47,7 @@ CHECKPOINTS_DIR = "checkpoints" LAST_CHECKPOINT_LINK = "last" PRETRAINED_MODEL_DIR = "pretrained_model" TRAINING_STATE_DIR = "training_state" +ALGORITHM_DIR = "algorithm" RNG_STATE = "rng_state.safetensors" TRAINING_STEP = "training_step.json" OPTIMIZER_STATE = "optimizer_state.safetensors" diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index bfa87fb86..6ba912bf5 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -132,6 +132,7 @@ _faker_available = is_package_available("faker") _pynput_available = is_package_available("pynput") _pygame_available = is_package_available("pygame") _qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils") +_grpc_available = is_package_available("grpcio", import_name="grpc") _wallx_deps_available = ( _transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available ) diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_gaussian_actor_config.py similarity index 81% rename from tests/policies/test_sac_config.py rename to tests/policies/test_gaussian_actor_config.py index 724c331ff..004612374 100644 --- a/tests/policies/test_sac_config.py +++ b/tests/policies/test_gaussian_actor_config.py @@ -17,19 +17,19 @@ import pytest from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.policies.sac.configuration_sac import ( +from lerobot.policies.gaussian_actor.configuration_gaussian_actor import ( ActorLearnerConfig, ActorNetworkConfig, ConcurrencyConfig, CriticNetworkConfig, + GaussianActorConfig, PolicyConfig, - SACConfig, ) from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE -def test_sac_config_default_initialization(): - config = SACConfig() +def test_gaussian_actor_config_default_initialization(): + config = GaussianActorConfig() assert config.normalization_mapping == { "VISUAL": NormalizationMode.MEAN_STD, @@ -55,9 +55,6 @@ def test_sac_config_default_initialization(): # Basic parameters assert config.device == "cpu" assert config.storage_device == "cpu" - assert config.discount == 0.99 - assert config.temperature_init == 1.0 - assert config.num_critics == 2 # Architecture specifics assert config.vision_encoder_name is None @@ -66,6 +63,8 @@ def test_sac_config_default_initialization(): assert config.shared_encoder is True assert config.num_discrete_actions is None assert config.image_embedding_pooling_dim == 8 + assert config.state_encoder_hidden_dim == 256 + assert config.latent_dim == 256 # Training parameters assert config.online_steps == 1000000 @@ -73,20 +72,6 @@ def test_sac_config_default_initialization(): assert config.offline_buffer_capacity == 100000 assert config.async_prefetch is False assert config.online_step_before_learning == 100 - assert config.policy_update_freq == 1 - - # SAC algorithm parameters - assert config.num_subsample_critics is None - assert config.critic_lr == 3e-4 - assert config.actor_lr 
== 3e-4 - assert config.temperature_lr == 3e-4 - assert config.critic_target_update_weight == 0.005 - assert config.utd_ratio == 1 - assert config.state_encoder_hidden_dim == 256 - assert config.latent_dim == 256 - assert config.target_entropy is None - assert config.use_backup_entropy is True - assert config.grad_clip_norm == 40.0 # Dataset stats defaults expected_dataset_stats = { @@ -105,11 +90,6 @@ def test_sac_config_default_initialization(): } assert config.dataset_stats == expected_dataset_stats - # Critic network configuration - assert config.critic_network_kwargs.hidden_dims == [256, 256] - assert config.critic_network_kwargs.activate_final is True - assert config.critic_network_kwargs.final_activation is None - # Actor network configuration assert config.actor_network_kwargs.hidden_dims == [256, 256] assert config.actor_network_kwargs.activate_final is True @@ -135,7 +115,6 @@ def test_sac_config_default_initialization(): assert config.concurrency.learner == "threads" assert isinstance(config.actor_network_kwargs, ActorNetworkConfig) - assert isinstance(config.critic_network_kwargs, CriticNetworkConfig) assert isinstance(config.policy_kwargs, PolicyConfig) assert isinstance(config.actor_learner_config, ActorLearnerConfig) assert isinstance(config.concurrency, ConcurrencyConfig) @@ -175,22 +154,22 @@ def test_concurrency_config(): assert config.learner == "threads" -def test_sac_config_custom_initialization(): - config = SACConfig( +def test_gaussian_actor_config_custom_initialization(): + config = GaussianActorConfig( device="cpu", - discount=0.95, - temperature_init=0.5, - num_critics=3, + latent_dim=128, + state_encoder_hidden_dim=128, + num_discrete_actions=3, ) assert config.device == "cpu" - assert config.discount == 0.95 - assert config.temperature_init == 0.5 - assert config.num_critics == 3 + assert config.latent_dim == 128 + assert config.state_encoder_hidden_dim == 128 + assert config.num_discrete_actions == 3 def test_validate_features(): - config = SACConfig( + config = GaussianActorConfig( input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) @@ -198,7 +177,7 @@ def test_validate_features(): def test_validate_features_missing_observation(): - config = SACConfig( + config = GaussianActorConfig( input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) @@ -209,7 +188,7 @@ def test_validate_features_missing_observation(): def test_validate_features_missing_action(): - config = SACConfig( + config = GaussianActorConfig( input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) diff --git a/tests/policies/test_gaussian_actor_policy.py b/tests/policies/test_gaussian_actor_policy.py new file mode 100644 index 000000000..af802d26f --- /dev/null +++ b/tests/policies/test_gaussian_actor_policy.py @@ -0,0 +1,528 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 +from torch import Tensor, nn # noqa: E402 + +from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402 +from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402 +from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy # noqa: E402 +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402 +from lerobot.utils.random_utils import seeded_context, set_seed # noqa: E402 + +try: + import transformers # noqa: F401 + + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + + +@pytest.fixture(autouse=True) +def set_random_seed(): + seed = 42 + set_seed(seed) + + +def test_mlp_with_default_args(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256]) + + x = torch.randn(10) + y = mlp(x) + assert y.shape == (256,) + + +def test_mlp_with_batch_dim(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256]) + x = torch.randn(2, 10) + y = mlp(x) + assert y.shape == (2, 256) + + +def test_forward_with_empty_hidden_dims(): + mlp = MLP(input_dim=10, hidden_dims=[]) + x = torch.randn(1, 10) + assert mlp(x).shape == (1, 10) + + +def test_mlp_with_dropout(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1) + x = torch.randn(1, 10) + y = mlp(x) + assert y.shape == (1, 11) + + drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net) + assert drop_out_layers_count == 2 + + +def test_mlp_with_custom_final_activation(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh()) + x = torch.randn(1, 10) + y = mlp(x) + assert y.shape == (1, 256) + assert (y >= -1).all() and (y <= 1).all() + + +def test_gaussian_actor_policy_with_default_args(): + with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"): + GaussianActorPolicy() + + +def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: + return { + OBS_STATE: torch.randn(batch_size, state_dim), + } + + +def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: + return { + OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), + OBS_STATE: torch.randn(batch_size, state_dim), + } + + +def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor: + return torch.randn(batch_size, action_dim) + + +def create_default_train_batch( + batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 +) -> dict[str, Tensor]: + return { + ACTION: create_dummy_action(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": create_dummy_state(batch_size, state_dim), + "next_state": create_dummy_state(batch_size, state_dim), + "done": torch.randn(batch_size), + } + + +def create_train_batch_with_visual_input( + batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 +) -> dict[str, Tensor]: + return { + ACTION: create_dummy_action(batch_size, 
action_dim), + "reward": torch.randn(batch_size), + "state": create_dummy_with_visual_input(batch_size, state_dim), + "next_state": create_dummy_with_visual_input(batch_size, state_dim), + "done": torch.randn(batch_size), + } + + +def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: + return { + OBS_STATE: torch.randn(batch_size, state_dim), + } + + +def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: + return { + OBS_STATE: torch.randn(batch_size, state_dim), + OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), + } + + +def create_default_config( + state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False +) -> GaussianActorConfig: + action_dim = continuous_action_dim + if has_discrete_action: + action_dim += 1 + + config = GaussianActorConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, + dataset_stats={ + OBS_STATE: { + "min": [0.0] * state_dim, + "max": [1.0] * state_dim, + }, + ACTION: { + "min": [0.0] * continuous_action_dim, + "max": [1.0] * continuous_action_dim, + }, + }, + ) + config.validate_features() + return config + + +def create_config_with_visual_input( + state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False +) -> GaussianActorConfig: + config = create_default_config( + state_dim=state_dim, + continuous_action_dim=continuous_action_dim, + has_discrete_action=has_discrete_action, + ) + config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats[OBS_IMAGE] = { + "mean": torch.randn(3, 1, 1), + "std": torch.randn(3, 1, 1), + } + + config.state_encoder_hidden_dim = 32 + config.latent_dim = 32 + + config.validate_features() + return config + + +def _make_algorithm(config: GaussianActorConfig) -> tuple[SACAlgorithm, GaussianActorPolicy]: + """Helper to create policy + algorithm pair for tests that need critics.""" + policy = GaussianActorPolicy(config=config) + policy.train() + algo_config = SACAlgorithmConfig.from_policy_config(config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm.make_optimizers_and_scheduler() + return algorithm, policy + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_gaussian_actor_policy_select_action(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + policy = GaussianActorPolicy(config=config) + policy.eval() + + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + selected_action = policy.select_action(observation_batch) + # squeeze(0) removes batch dim when batch_size==1 + assert selected_action.shape[-1] == action_dim + + +def test_gaussian_actor_policy_select_action_with_discrete(): + """select_action should return continuous + discrete actions.""" + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.num_discrete_actions = 3 + policy = GaussianActorPolicy(config=config) + policy.eval() + + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=1, state_dim=10) + # Squeeze to unbatched (single observation) + observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()} + selected_action = 
policy.select_action(observation_batch) + assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + policy = GaussianActorPolicy(config=config) + policy.eval() + + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + with torch.no_grad(): + output = policy.forward(batch) + assert "action" in output + assert "log_prob" in output + assert "action_mean" in output + assert output["action"].shape == (batch_size, action_dim) + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_gaussian_actor_training_through_sac(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + algorithm, policy = _make_algorithm(config) + + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + forward_batch = algorithm._prepare_forward_batch(batch) + + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() + + actor_loss = algorithm._compute_loss_actor(forward_batch) + assert actor_loss.item() is not None + assert actor_loss.shape == () + algorithm.optimizers["actor"].zero_grad() + actor_loss.backward() + algorithm.optimizers["actor"].step() + + temp_loss = algorithm._compute_loss_temperature(forward_batch) + assert temp_loss.item() is not None + assert temp_loss.shape == () + algorithm.optimizers["temperature"].zero_grad() + temp_loss.backward() + algorithm.optimizers["temperature"].step() + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_gaussian_actor_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + algorithm, policy = _make_algorithm(config) + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + forward_batch = algorithm._prepare_forward_batch(batch) + + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() + + actor_loss = algorithm._compute_loss_actor(forward_batch) + assert actor_loss.item() is not None + assert actor_loss.shape == () + algorithm.optimizers["actor"].zero_grad() + actor_loss.backward() + algorithm.optimizers["actor"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim + ) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape[-1] == action_dim + + +@pytest.mark.parametrize( + "batch_size,state_dim,action_dim,vision_encoder_name", + [(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], +) +@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") +def 
test_gaussian_actor_policy_with_pretrained_encoder( + batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str +): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.vision_encoder_name = vision_encoder_name + algorithm, policy = _make_algorithm(config) + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + forward_batch = algorithm._prepare_forward_batch(batch) + + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() + + actor_loss = algorithm._compute_loss_actor(forward_batch) + assert actor_loss.item() is not None + assert actor_loss.shape == () + + +def test_gaussian_actor_training_with_shared_encoder(): + batch_size = 2 + action_dim = 10 + state_dim = 10 + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.shared_encoder = True + + algorithm, policy = _make_algorithm(config) + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + forward_batch = algorithm._prepare_forward_batch(batch) + + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() + + actor_loss = algorithm._compute_loss_actor(forward_batch) + assert actor_loss.shape == () + algorithm.optimizers["actor"].zero_grad() + actor_loss.backward() + algorithm.optimizers["actor"].step() + + +def test_gaussian_actor_training_with_discrete_critic(): + batch_size = 2 + continuous_action_dim = 9 + full_action_dim = continuous_action_dim + 1 + state_dim = 10 + config = create_config_with_visual_input( + state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True + ) + config.num_discrete_actions = 5 + + algorithm, policy = _make_algorithm(config) + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim + ) + forward_batch = algorithm._prepare_forward_batch(batch) + + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() + + discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch) + assert discrete_critic_loss.shape == () + algorithm.optimizers["discrete_critic"].zero_grad() + discrete_critic_loss.backward() + algorithm.optimizers["discrete_critic"].step() + + actor_loss = algorithm._compute_loss_actor(forward_batch) + assert actor_loss.shape == () + algorithm.optimizers["actor"].zero_grad() + actor_loss.backward() + algorithm.optimizers["actor"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim + ) + # Policy.select_action now handles both continuous + discrete + selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()}) + assert selected_action.shape[-1] == continuous_action_dim + 1 + + +def test_sac_algorithm_target_entropy(): + """Target entropy is an SAC hyperparameter and lives on the algorithm.""" + config = 
create_default_config(continuous_action_dim=10, state_dim=10) + algorithm, _ = _make_algorithm(config) + assert algorithm.target_entropy == -5.0 + + +def test_sac_algorithm_target_entropy_with_discrete_action(): + config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) + config.num_discrete_actions = 5 + algorithm, _ = _make_algorithm(config) + assert algorithm.target_entropy == -3.5 + + +def test_sac_algorithm_temperature(): + import math + + config = create_default_config(continuous_action_dim=10, state_dim=10) + algo_config = SACAlgorithmConfig.from_policy_config(config) + policy = GaussianActorPolicy(config=config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + + assert algorithm.temperature == pytest.approx(1.0) + algorithm.log_alpha.data = torch.tensor([math.log(0.1)]) + assert algorithm.temperature == pytest.approx(0.1) + + +def test_sac_algorithm_update_target_network(): + config = create_default_config(state_dim=10, continuous_action_dim=6) + algo_config = SACAlgorithmConfig.from_policy_config(config) + algo_config.critic_target_update_weight = 1.0 + policy = GaussianActorPolicy(config=config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + + for p in algorithm.critic_ensemble.parameters(): + p.data = torch.ones_like(p.data) + + algorithm._update_target_networks() + for p in algorithm.critic_target.parameters(): + assert torch.allclose(p.data, torch.ones_like(p.data)) + + +@pytest.mark.parametrize("num_critics", [1, 3]) +def test_sac_algorithm_with_critics_number_of_heads(num_critics: int): + batch_size = 2 + action_dim = 10 + state_dim = 10 + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + + policy = GaussianActorPolicy(config=config) + policy.train() + algo_config = SACAlgorithmConfig.from_policy_config(config) + algo_config.num_critics = num_critics + algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm.make_optimizers_and_scheduler() + + assert len(algorithm.critic_ensemble.critics) == num_critics + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + forward_batch = algorithm._prepare_forward_batch(batch) + + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() + + +def test_gaussian_actor_policy_save_and_load(tmp_path): + """Test that the policy can be saved and loaded from pretrained.""" + root = tmp_path / "test_gaussian_actor_save_and_load" + + state_dim = 10 + action_dim = 10 + batch_size = 2 + + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + policy = GaussianActorPolicy(config=config) + policy.eval() + policy.save_pretrained(root) + loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config) + loaded_policy.eval() + + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) + + with torch.no_grad(): + with seeded_context(12): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + actions = policy.select_action(observation_batch) + + with seeded_context(12): + loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + loaded_actions = 
loaded_policy.select_action(loaded_observation_batch) + + assert torch.allclose(actions, loaded_actions) + + +def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path): + """Discrete critic should be saved/loaded as part of the policy.""" + root = tmp_path / "test_gaussian_actor_save_and_load_discrete" + + state_dim = 10 + action_dim = 6 + + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + config.num_discrete_actions = 3 + policy = GaussianActorPolicy(config=config) + policy.eval() + policy.save_pretrained(root) + + loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config) + loaded_policy.eval() + + assert loaded_policy.discrete_critic is not None + dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")] + assert len(dc_keys) > 0 + + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py deleted file mode 100644 index 11499ce30..000000000 --- a/tests/policies/test_sac_policy.py +++ /dev/null @@ -1,546 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math - -import pytest -import torch -from torch import Tensor, nn - -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.policies.sac.modeling_sac import MLP, SACPolicy -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE -from lerobot.utils.random_utils import seeded_context, set_seed - -try: - import transformers # noqa: F401 - - TRANSFORMERS_AVAILABLE = True -except ImportError: - TRANSFORMERS_AVAILABLE = False - - -@pytest.fixture(autouse=True) -def set_random_seed(): - seed = 42 - set_seed(seed) - - -def test_mlp_with_default_args(): - mlp = MLP(input_dim=10, hidden_dims=[256, 256]) - - x = torch.randn(10) - y = mlp(x) - assert y.shape == (256,) - - -def test_mlp_with_batch_dim(): - mlp = MLP(input_dim=10, hidden_dims=[256, 256]) - x = torch.randn(2, 10) - y = mlp(x) - assert y.shape == (2, 256) - - -def test_forward_with_empty_hidden_dims(): - mlp = MLP(input_dim=10, hidden_dims=[]) - x = torch.randn(1, 10) - assert mlp(x).shape == (1, 10) - - -def test_mlp_with_dropout(): - mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1) - x = torch.randn(1, 10) - y = mlp(x) - assert y.shape == (1, 11) - - drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net) - assert drop_out_layers_count == 2 - - -def test_mlp_with_custom_final_activation(): - mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh()) - x = torch.randn(1, 10) - y = mlp(x) - assert y.shape == (1, 256) - assert (y >= -1).all() and (y <= 1).all() - - -def test_sac_policy_with_default_args(): - with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"): - SACPolicy() - - -def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: - return { - OBS_STATE: torch.randn(batch_size, state_dim), - } - - -def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: - return { - OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), - OBS_STATE: torch.randn(batch_size, state_dim), - } - - -def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor: - return torch.randn(batch_size, action_dim) - - -def create_default_train_batch( - batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 -) -> dict[str, Tensor]: - return { - ACTION: create_dummy_action(batch_size, action_dim), - "reward": torch.randn(batch_size), - "state": create_dummy_state(batch_size, state_dim), - "next_state": create_dummy_state(batch_size, state_dim), - "done": torch.randn(batch_size), - } - - -def create_train_batch_with_visual_input( - batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 -) -> dict[str, Tensor]: - return { - ACTION: create_dummy_action(batch_size, action_dim), - "reward": torch.randn(batch_size), - "state": create_dummy_with_visual_input(batch_size, state_dim), - "next_state": create_dummy_with_visual_input(batch_size, state_dim), - "done": torch.randn(batch_size), - } - - -def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: - return { - OBS_STATE: torch.randn(batch_size, state_dim), - } - - -def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: - return { - OBS_STATE: torch.randn(batch_size, state_dim), - OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), - } - - -def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]: - """Create optimizers 
for the SAC policy.""" - optimizer_actor = torch.optim.Adam( - # Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient - params=[ - p - for n, p in policy.actor.named_parameters() - if not policy.config.shared_encoder or not n.startswith("encoder") - ], - lr=policy.config.actor_lr, - ) - optimizer_critic = torch.optim.Adam( - params=policy.critic_ensemble.parameters(), - lr=policy.config.critic_lr, - ) - optimizer_temperature = torch.optim.Adam( - params=[policy.log_alpha], - lr=policy.config.critic_lr, - ) - - optimizers = { - "actor": optimizer_actor, - "critic": optimizer_critic, - "temperature": optimizer_temperature, - } - - if has_discrete_action: - optimizers["discrete_critic"] = torch.optim.Adam( - params=policy.discrete_critic.parameters(), - lr=policy.config.critic_lr, - ) - - return optimizers - - -def create_default_config( - state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False -) -> SACConfig: - action_dim = continuous_action_dim - if has_discrete_action: - action_dim += 1 - - config = SACConfig( - input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, - output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, - dataset_stats={ - OBS_STATE: { - "min": [0.0] * state_dim, - "max": [1.0] * state_dim, - }, - ACTION: { - "min": [0.0] * continuous_action_dim, - "max": [1.0] * continuous_action_dim, - }, - }, - ) - config.validate_features() - return config - - -def create_config_with_visual_input( - state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False -) -> SACConfig: - config = create_default_config( - state_dim=state_dim, - continuous_action_dim=continuous_action_dim, - has_discrete_action=has_discrete_action, - ) - config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) - config.dataset_stats[OBS_IMAGE] = { - "mean": torch.randn(3, 1, 1), - "std": torch.randn(3, 1, 1), - } - - # Let make tests a little bit faster - config.state_encoder_hidden_dim = 32 - config.latent_dim = 32 - - config.validate_features() - return config - - -@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) -def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int): - batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) - config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) - - policy = SACPolicy(config=config) - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - assert temperature_loss.item() is not None - assert temperature_loss.shape == () - - temperature_loss.backward() - optimizers["temperature"].step() - - policy.eval() - with torch.no_grad(): - observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) - - 
-@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) -def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int): - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - policy = SACPolicy(config=config) - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=action_dim - ) - - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - assert temperature_loss.item() is not None - assert temperature_loss.shape == () - - temperature_loss.backward() - optimizers["temperature"].step() - - policy.eval() - with torch.no_grad(): - observation_batch = create_observation_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim - ) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) - - -# Let's check best candidates for pretrained encoders -@pytest.mark.parametrize( - "batch_size,state_dim,action_dim,vision_encoder_name", - [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], -) -@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) -def test_sac_policy_with_pretrained_encoder( - batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str -): - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - config.vision_encoder_name = vision_encoder_name - policy = SACPolicy(config=config) - policy.train() - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=action_dim - ) - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - -def test_sac_policy_with_shared_encoder(): - batch_size = 2 - action_dim = 10 - state_dim = 10 - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - config.shared_encoder = True - - policy = SACPolicy(config=config) - policy.train() - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=action_dim - ) - - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - -def 
test_sac_policy_with_discrete_critic(): - batch_size = 2 - continuous_action_dim = 9 - full_action_dim = continuous_action_dim + 1 # the last action is discrete - state_dim = 10 - config = create_config_with_visual_input( - state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True - ) - - num_discrete_actions = 5 - config.num_discrete_actions = num_discrete_actions - - policy = SACPolicy(config=config) - policy.train() - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim - ) - - policy.train() - - optimizers = make_optimizers(policy, has_discrete_action=True) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"] - assert discrete_critic_loss.item() is not None - assert discrete_critic_loss.shape == () - discrete_critic_loss.backward() - optimizers["discrete_critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - policy.eval() - with torch.no_grad(): - observation_batch = create_observation_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim - ) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, full_action_dim) - - discrete_actions = selected_action[:, -1].long() - discrete_action_values = set(discrete_actions.tolist()) - - assert all(action in range(num_discrete_actions) for action in discrete_action_values), ( - f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})" - ) - - -def test_sac_policy_with_default_entropy(): - config = create_default_config(continuous_action_dim=10, state_dim=10) - policy = SACPolicy(config=config) - assert policy.target_entropy == -5.0 - - -def test_sac_policy_default_target_entropy_with_discrete_action(): - config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) - policy = SACPolicy(config=config) - assert policy.target_entropy == -3.0 - - -def test_sac_policy_with_predefined_entropy(): - config = create_default_config(state_dim=10, continuous_action_dim=6) - config.target_entropy = -3.5 - - policy = SACPolicy(config=config) - assert policy.target_entropy == pytest.approx(-3.5) - - -def test_sac_policy_update_temperature(): - """Test that temperature property is always in sync with log_alpha.""" - config = create_default_config(continuous_action_dim=10, state_dim=10) - policy = SACPolicy(config=config) - - assert policy.temperature == pytest.approx(1.0) - policy.log_alpha.data = torch.tensor([math.log(0.1)]) - # Temperature property automatically reflects log_alpha changes - assert policy.temperature == pytest.approx(0.1) - - -def test_sac_policy_update_target_network(): - config = create_default_config(state_dim=10, continuous_action_dim=6) - config.critic_target_update_weight = 1.0 - - policy = SACPolicy(config=config) - policy.train() - - for p in policy.critic_ensemble.parameters(): - p.data = torch.ones_like(p.data) - - policy.update_target_networks() - for p in policy.critic_target.parameters(): - assert torch.allclose(p.data, torch.ones_like(p.data)), ( - f"Target network {p.data} is not equal to 
{torch.ones_like(p.data)}" - ) - - -@pytest.mark.parametrize("num_critics", [1, 3]) -def test_sac_policy_with_critics_number_of_heads(num_critics: int): - batch_size = 2 - action_dim = 10 - state_dim = 10 - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - config.num_critics = num_critics - - policy = SACPolicy(config=config) - policy.train() - - assert len(policy.critic_ensemble.critics) == num_critics - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=action_dim - ) - - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - -def test_sac_policy_save_and_load(tmp_path): - root = tmp_path / "test_sac_save_and_load" - - state_dim = 10 - action_dim = 10 - batch_size = 2 - - config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) - policy = SACPolicy(config=config) - policy.eval() - policy.save_pretrained(root) - loaded_policy = SACPolicy.from_pretrained(root, config=config) - loaded_policy.eval() - - batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10) - - with torch.no_grad(): - with seeded_context(12): - # Collect policy values before saving - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - - observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - actions = policy.select_action(observation_batch) - - with seeded_context(12): - # Collect policy values after loading - loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"] - loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"] - loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"] - - loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - loaded_actions = loaded_policy.select_action(loaded_observation_batch) - - assert policy.state_dict().keys() == loaded_policy.state_dict().keys() - for k in policy.state_dict(): - assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) - - # Compare values before and after saving and loading - # They should be the same - assert torch.allclose(cirtic_loss, loaded_cirtic_loss) - assert torch.allclose(actor_loss, loaded_actor_loss) - assert torch.allclose(temperature_loss, loaded_temperature_loss) - assert torch.allclose(actions, loaded_actions) diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_gaussian_actor_processor.py similarity index 89% rename from tests/processor/test_sac_processor.py rename to tests/processor/test_gaussian_actor_processor.py index a1a4b285d..2429bc23a 100644 --- a/tests/processor/test_sac_processor.py +++ b/tests/processor/test_gaussian_actor_processor.py @@ -21,8 +21,8 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors +from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig +from 
lerobot.policies.gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors from lerobot.processor import ( AddBatchDimensionProcessorStep, DataProcessorPipeline, @@ -38,7 +38,7 @@ from lerobot.utils.constants import ACTION, OBS_STATE def create_default_config(): """Create a default SAC configuration for testing.""" - config = SACConfig() + config = GaussianActorConfig() config.input_features = { OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), } @@ -66,7 +66,7 @@ def test_make_sac_processor_basic(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -88,12 +88,12 @@ def test_make_sac_processor_basic(): assert isinstance(postprocessor.steps[1], DeviceProcessorStep) -def test_sac_processor_normalization_modes(): +def test_gaussian_actor_processor_normalization_modes(): """Test that SAC processor correctly handles different normalization modes.""" config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -121,13 +121,13 @@ def test_sac_processor_normalization_modes(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_sac_processor_cuda(): +def test_gaussian_actor_processor_cuda(): """Test SAC processor with CUDA device.""" config = create_default_config() config.device = "cuda" stats = create_default_stats() - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -153,13 +153,13 @@ def test_sac_processor_cuda(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_sac_processor_accelerate_scenario(): +def test_gaussian_actor_processor_accelerate_scenario(): """Test SAC processor in simulated Accelerate scenario.""" config = create_default_config() config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -180,13 +180,13 @@ def test_sac_processor_accelerate_scenario(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") -def test_sac_processor_multi_gpu(): +def test_gaussian_actor_processor_multi_gpu(): """Test SAC processor with multi-GPU setup.""" config = create_default_config() config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -206,11 +206,11 @@ def test_sac_processor_multi_gpu(): assert processed[TransitionKey.ACTION.value].device == device -def test_sac_processor_without_stats(): +def test_gaussian_actor_processor_without_stats(): """Test SAC processor creation without dataset statistics.""" config = create_default_config() - preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None) + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None @@ -226,12 +226,12 @@ def test_sac_processor_without_stats(): assert processed is not None -def test_sac_processor_save_and_load(): +def 
test_gaussian_actor_processor_save_and_load(): """Test saving and loading SAC processor.""" config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -257,14 +257,14 @@ def test_sac_processor_save_and_load(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_sac_processor_mixed_precision(): +def test_gaussian_actor_processor_mixed_precision(): """Test SAC processor with mixed precision.""" config = create_default_config() config.device = "cuda" stats = create_default_stats() # Create processor - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -304,12 +304,12 @@ def test_sac_processor_mixed_precision(): assert processed[TransitionKey.ACTION.value].dtype == torch.float16 -def test_sac_processor_batch_data(): +def test_gaussian_actor_processor_batch_data(): """Test SAC processor with batched data.""" config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -329,12 +329,12 @@ def test_sac_processor_batch_data(): assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5) -def test_sac_processor_edge_cases(): +def test_gaussian_actor_processor_edge_cases(): """Test SAC processor with edge cases.""" config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_pre_post_processors( + preprocessor, postprocessor = make_gaussian_actor_pre_post_processors( config, stats, ) @@ -358,13 +358,13 @@ def test_sac_processor_edge_cases(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_sac_processor_bfloat16_device_float32_normalizer(): +def test_gaussian_actor_processor_bfloat16_device_float32_normalizer(): """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" config = create_default_config() config.device = "cuda" stats = create_default_stats() - preprocessor, _ = make_sac_pre_post_processors( + preprocessor, _ = make_gaussian_actor_pre_post_processors( config, stats, ) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index cd5c75005..e046adb0d 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -1804,13 +1804,15 @@ def test_stats_override_preservation_in_load_state_dict(): override_normalizer.stats[key][stat_name], original_stats[key][stat_name] ), f"Stats for {key}.{stat_name} should not match original stats" - # Verify that _tensor_stats are also correctly set to match the override stats + # Verify that _tensor_stats values match the override stats + # Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats expected_tensor_stats = to_tensor(override_stats) for key in expected_tensor_stats: for stat_name in expected_tensor_stats[key]: if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor): torch.testing.assert_close( - override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + override_normalizer._tensor_stats[key][stat_name].squeeze(), + expected_tensor_stats[key][stat_name].squeeze(), ) @@ -1849,12 +1851,16 @@ def 
test_stats_without_override_loads_normally(): # Stats should now match the original stats (normal behavior) # Check that all keys and values match assert set(new_normalizer.stats.keys()) == set(original_stats.keys()) + # Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats, + # so we squeeze before comparing values. for key in original_stats: assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys()) for stat_name in original_stats[key]: - np.testing.assert_allclose( - new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6 - ) + actual = new_normalizer.stats[key][stat_name] + expected = original_stats[key][stat_name] + if hasattr(actual, "squeeze"): + actual = actual.squeeze() + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) def test_stats_explicit_provided_flag_detection(): @@ -2075,8 +2081,9 @@ def test_stats_reconstruction_after_load_state_dict(): assert ACTION in new_normalizer.stats # Check that values are correct (converted back from tensors) - np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5]) - np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2]) + # Note: visual stats are reshaped to (C,1,1), so we squeeze before comparing + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"].squeeze(), [0.5, 0.5, 0.5]) + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"].squeeze(), [0.2, 0.2, 0.2]) np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0]) np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0]) np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0]) diff --git a/tests/rewards/test_modeling_classifier.py b/tests/rewards/test_modeling_classifier.py index 08f6121a1..043dbb660 100644 --- a/tests/rewards/test_modeling_classifier.py +++ b/tests/rewards/test_modeling_classifier.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature @@ -36,9 +35,6 @@ def test_classifier_output(): @skip_if_package_missing("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_binary_classifier_with_default_params(): from lerobot.rewards.classifier.modeling_classifier import Classifier @@ -80,9 +76,6 @@ def test_binary_classifier_with_default_params(): @skip_if_package_missing("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_multiclass_classifier(): from lerobot.rewards.classifier.modeling_classifier import Classifier @@ -122,9 +115,6 @@ def test_multiclass_classifier(): @skip_if_package_missing("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_default_device(): from lerobot.rewards.classifier.modeling_classifier import Classifier @@ -141,9 +131,6 @@ def test_default_device(): @skip_if_package_missing("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_explicit_device_setup(): from lerobot.rewards.classifier.modeling_classifier import Classifier diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 3978dfffd..e0df14e62 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -22,12 +22,14 @@ import pytest import torch pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") +pytest.importorskip("grpc") from torch.multiprocessing import Event, Queue -from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.utils.constants import OBS_STR +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig +from lerobot.rl.train_rl import TrainRLServerPipelineConfig +from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR from lerobot.utils.transition import Transition from tests.utils import skip_if_package_missing @@ -79,7 +81,7 @@ def cfg(): port = find_free_port() - policy_cfg = SACConfig() + policy_cfg = GaussianActorConfig() policy_cfg.actor_learner_config.learner_host = "127.0.0.1" policy_cfg.actor_learner_config.learner_port = port policy_cfg.concurrency.actor = "threads" @@ -299,3 +301,164 @@ def test_end_to_end_parameters_flow(cfg, data_size): assert received_params.keys() == input_params.keys() for key in input_params: assert torch.allclose(received_params[key], input_params[key]) + + +def test_learner_algorithm_wiring(): + """Verify that make_algorithm constructs an SACAlgorithm from config, + make_optimizers_and_scheduler() creates the right optimizers, update() works, and + get_weights() output is serializable.""" + from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy + from lerobot.rl.algorithms.factory import make_algorithm + from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig + from lerobot.transport.utils import state_to_bytes + + state_dim = 10 + action_dim = 6 + + sac_cfg = GaussianActorConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: 
PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + ) + sac_cfg.validate_features() + + policy = GaussianActorPolicy(config=sac_cfg) + policy.train() + + algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy) + assert isinstance(algorithm, SACAlgorithm) + + optimizers = algorithm.make_optimizers_and_scheduler() + assert "actor" in optimizers + assert "critic" in optimizers + assert "temperature" in optimizers + + batch_size = 4 + + def batch_iterator(): + while True: + yield { + ACTION: torch.randn(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": {OBS_STATE: torch.randn(batch_size, state_dim)}, + "next_state": {OBS_STATE: torch.randn(batch_size, state_dim)}, + "done": torch.zeros(batch_size), + "complementary_info": {}, + } + + stats = algorithm.update(batch_iterator()) + assert "loss_critic" in stats.losses + + # get_weights -> state_to_bytes round-trip + weights = algorithm.get_weights() + assert len(weights) > 0 + serialized = state_to_bytes(weights) + assert isinstance(serialized, bytes) + assert len(serialized) > 0 + + # RLTrainer with DataMixer + from lerobot.rl.buffer import ReplayBuffer + from lerobot.rl.data_sources import OnlineOfflineMixer + from lerobot.rl.trainer import RLTrainer + + replay_buffer = ReplayBuffer( + capacity=50, + device="cpu", + state_keys=[OBS_STATE], + storage_device="cpu", + use_drq=False, + ) + for _ in range(50): + replay_buffer.add( + state={OBS_STATE: torch.randn(state_dim)}, + action=torch.randn(action_dim), + reward=1.0, + next_state={OBS_STATE: torch.randn(state_dim)}, + done=False, + truncated=False, + ) + data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None) + trainer = RLTrainer( + algorithm=algorithm, + data_mixer=data_mixer, + batch_size=batch_size, + ) + trainer_stats = trainer.training_step() + assert "loss_critic" in trainer_stats.losses + + +def test_initial_and_periodic_weight_push_consistency(): + """Both initial and periodic weight pushes should use algorithm.get_weights() + and produce identical structures.""" + from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy + from lerobot.rl.algorithms.factory import make_algorithm + from lerobot.rl.algorithms.sac import SACAlgorithmConfig + from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes + + state_dim = 10 + action_dim = 6 + sac_cfg = GaussianActorConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + ) + sac_cfg.validate_features() + + policy = GaussianActorPolicy(config=sac_cfg) + policy.train() + algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy) + algorithm.make_optimizers_and_scheduler() + + # Simulate initial push (same code path the learner now uses) + initial_weights = algorithm.get_weights() + initial_bytes = state_to_bytes(initial_weights) + + # Simulate periodic push + periodic_weights = algorithm.get_weights() + periodic_bytes = state_to_bytes(periodic_weights) + + initial_decoded = bytes_to_state_dict(initial_bytes) + periodic_decoded = 
bytes_to_state_dict(periodic_bytes) + + assert initial_decoded.keys() == periodic_decoded.keys() + + +def test_actor_side_algorithm_select_action_and_load_weights(): + """Simulate actor: create algorithm without optimizers, select_action, load_weights.""" + from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy + from lerobot.rl.algorithms.factory import make_algorithm + from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig + + state_dim = 10 + action_dim = 6 + sac_cfg = GaussianActorConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + ) + sac_cfg.validate_features() + + # Actor side: no optimizers + policy = GaussianActorPolicy(config=sac_cfg) + policy.eval() + algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy) + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + # select_action should work + obs = {OBS_STATE: torch.randn(state_dim)} + action = policy.select_action(obs) + assert action.shape == (action_dim,) + + # Simulate receiving weights from learner + fake_weights = algorithm.get_weights() + algorithm.load_weights(fake_weights, device="cpu") diff --git a/tests/rl/test_data_mixer.py b/tests/rl/test_data_mixer.py new file mode 100644 index 000000000..b153498d7 --- /dev/null +++ b/tests/rl/test_data_mixer.py @@ -0,0 +1,89 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer).""" + +import pytest + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 + +from lerobot.rl.buffer import ReplayBuffer # noqa: E402 +from lerobot.rl.data_sources import OnlineOfflineMixer # noqa: E402 +from lerobot.utils.constants import OBS_STATE # noqa: E402 + + +def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer: + buf = ReplayBuffer( + capacity=capacity, + device="cpu", + state_keys=[OBS_STATE], + storage_device="cpu", + use_drq=False, + ) + for i in range(capacity): + buf.add( + state={OBS_STATE: torch.randn(state_dim)}, + action=torch.randn(2), + reward=1.0, + next_state={OBS_STATE: torch.randn(state_dim)}, + done=bool(i % 10 == 9), + truncated=False, + ) + return buf + + +def test_online_only_mixer_sample(): + """OnlineOfflineMixer with no offline buffer returns online-only batches.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5) + batch = mixer.sample(batch_size=8) + assert batch["state"][OBS_STATE].shape[0] == 8 + assert batch["action"].shape[0] == 8 + assert batch["reward"].shape[0] == 8 + + +def test_online_only_mixer_ratio_one(): + """OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0) + batch = mixer.sample(batch_size=10) + assert batch["state"][OBS_STATE].shape[0] == 10 + + +def test_online_offline_mixer_sample(): + """OnlineOfflineMixer with two buffers returns concatenated batches.""" + online = _make_buffer(capacity=50) + offline = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer( + online_buffer=online, + offline_buffer=offline, + online_ratio=0.5, + ) + batch = mixer.sample(batch_size=10) + assert batch["state"][OBS_STATE].shape[0] == 10 + assert batch["action"].shape[0] == 10 + # 5 from online, 5 from offline (approx) + assert batch["reward"].shape[0] == 10 + + +def test_online_offline_mixer_iterator(): + """get_iterator yields batches of the requested size.""" + buf = _make_buffer(capacity=50) + mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None) + it = mixer.get_iterator(batch_size=4, async_prefetch=False) + batch1 = next(it) + batch2 = next(it) + assert batch1["state"][OBS_STATE].shape[0] == 4 + assert batch2["state"][OBS_STATE].shape[0] == 4 diff --git a/tests/rl/test_queue.py b/tests/rl/test_queue.py index cf3d6cdca..77936d269 100644 --- a/tests/rl/test_queue.py +++ b/tests/rl/test_queue.py @@ -20,7 +20,7 @@ from queue import Queue import pytest -pytest.importorskip("grpc") +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402 diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py new file mode 100644 index 000000000..2d77ae9ba --- /dev/null +++ b/tests/rl/test_sac_algorithm.py @@ -0,0 +1,606 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the RL algorithm abstraction and SACAlgorithm implementation.""" + +import pytest + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 + +from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402 +from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402 +from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy # noqa: E402 +from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats # noqa: E402 +from lerobot.rl.algorithms.factory import make_algorithm # noqa: E402 +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402 +from lerobot.utils.random_utils import set_seed # noqa: E402 + +# --------------------------------------------------------------------------- +# Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py) +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def set_random_seed(): + set_seed(42) + + +def _make_sac_config( + state_dim: int = 10, + action_dim: int = 6, + num_discrete_actions: int | None = None, + with_images: bool = False, +) -> GaussianActorConfig: + config = GaussianActorConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + dataset_stats={ + OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, + ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, + }, + num_discrete_actions=num_discrete_actions, + ) + if with_images: + config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats[OBS_IMAGE] = { + "mean": torch.randn(3, 1, 1).tolist(), + "std": torch.randn(3, 1, 1).abs().tolist(), + } + config.latent_dim = 32 + config.state_encoder_hidden_dim = 32 + config.validate_features() + return config + + +def _make_algorithm( + state_dim: int = 10, + action_dim: int = 6, + utd_ratio: int = 1, + policy_update_freq: int = 1, + num_discrete_actions: int | None = None, + with_images: bool = False, +) -> tuple[SACAlgorithm, GaussianActorPolicy]: + sac_cfg = _make_sac_config( + state_dim=state_dim, + action_dim=action_dim, + num_discrete_actions=num_discrete_actions, + with_images=with_images, + ) + policy = GaussianActorPolicy(config=sac_cfg) + policy.train() + algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) + algo_config.utd_ratio = utd_ratio + algo_config.policy_update_freq = policy_update_freq + algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm.make_optimizers_and_scheduler() + return algorithm, policy + + +def _make_batch( + batch_size: int = 4, + state_dim: int = 10, + action_dim: int = 6, + with_images: bool = False, +) -> dict: + obs = {OBS_STATE: torch.randn(batch_size, state_dim)} + next_obs = {OBS_STATE: torch.randn(batch_size, state_dim)} 
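+ # Batch layout mirrors what the DataMixer iterator yields: nested "state"/"next_state" dicts keyed by observation name, plus flat action/reward/done tensors.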
+ if with_images: + obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84) + next_obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84) + return { + ACTION: torch.randn(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": obs, + "next_state": next_obs, + "done": torch.zeros(batch_size), + "complementary_info": {}, + } + + +def _batch_iterator(**batch_kwargs): + """Infinite iterator that yields fresh batches (mirrors a real DataMixer iterator).""" + while True: + yield _make_batch(**batch_kwargs) + + +# =========================================================================== +# Registry / config tests +# =========================================================================== + + +def test_sac_algorithm_config_registered(): + """SACAlgorithmConfig should be discoverable through the registry.""" + assert "sac" in RLAlgorithmConfig.get_known_choices() + cls = RLAlgorithmConfig.get_choice_class("sac") + assert cls is SACAlgorithmConfig + + +def test_sac_algorithm_config_from_policy_config(): + """from_policy_config embeds the policy config and uses SAC defaults.""" + sac_cfg = _make_sac_config() + algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg) + assert algo_cfg.policy_config is sac_cfg + assert algo_cfg.discrete_critic_network_kwargs is sac_cfg.discrete_critic_network_kwargs + # Defaults come from SACAlgorithmConfig, not from the policy config. + assert algo_cfg.utd_ratio == 1 + assert algo_cfg.policy_update_freq == 1 + assert algo_cfg.grad_clip_norm == 40.0 + assert algo_cfg.actor_lr == 3e-4 + + +# =========================================================================== +# TrainingStats tests +# =========================================================================== + + +def test_training_stats_defaults(): + stats = TrainingStats() + assert stats.losses == {} + assert stats.grad_norms == {} + assert stats.extra == {} + + +# =========================================================================== +# get_weights +# =========================================================================== + + +def test_get_weights_returns_policy_state_dict(): + algorithm, policy = _make_algorithm() + weights = algorithm.get_weights() + assert "policy" in weights + actor_state_dict = policy.actor.state_dict() + for key in actor_state_dict: + assert key in weights["policy"] + assert torch.equal(weights["policy"][key].cpu(), actor_state_dict[key].cpu()) + + +def test_get_weights_includes_discrete_critic_when_present(): + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + weights = algorithm.get_weights() + assert "discrete_critic" in weights + assert len(weights["discrete_critic"]) > 0 + + +def test_get_weights_excludes_discrete_critic_when_absent(): + algorithm, _ = _make_algorithm() + weights = algorithm.get_weights() + assert "discrete_critic" not in weights + + +def test_get_weights_are_on_cpu(): + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + weights = algorithm.get_weights() + for group_name, state_dict in weights.items(): + for key, tensor in state_dict.items(): + assert tensor.device == torch.device("cpu"), f"{group_name}/{key} is not on CPU" + + +# =========================================================================== +# select_action (lives on the policy, not the algorithm) +# =========================================================================== + + +def test_select_action_returns_correct_shape(): + action_dim = 6 + _, policy = _make_algorithm(state_dim=10, action_dim=action_dim) + policy.eval() + obs = 
{OBS_STATE: torch.randn(10)} + action = policy.select_action(obs) + assert action.shape == (action_dim,) + + +def test_select_action_with_discrete_critic(): + continuous_dim = 5 + _, policy = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3) + policy.eval() + obs = {OBS_STATE: torch.randn(10)} + action = policy.select_action(obs) + assert action.shape == (continuous_dim + 1,) + + +# =========================================================================== +# update (single batch, utd_ratio=1) +# =========================================================================== + + +def test_update_returns_training_stats(): + algorithm, _ = _make_algorithm() + stats = algorithm.update(_batch_iterator()) + assert isinstance(stats, TrainingStats) + assert "loss_critic" in stats.losses + assert isinstance(stats.losses["loss_critic"], float) + + +def test_update_populates_actor_and_temperature_losses(): + """With policy_update_freq=1 and step 0, actor/temperature should be updated.""" + algorithm, _ = _make_algorithm(policy_update_freq=1) + stats = algorithm.update(_batch_iterator()) + assert "loss_actor" in stats.losses + assert "loss_temperature" in stats.losses + assert "temperature" in stats.extra + + +@pytest.mark.parametrize("policy_update_freq", [2, 3]) +def test_update_skips_actor_at_non_update_steps(policy_update_freq): + """Actor/temperature should only update when optimization_step % freq == 0.""" + algorithm, _ = _make_algorithm(policy_update_freq=policy_update_freq) + it = _batch_iterator() + + # Step 0: should update actor + stats_0 = algorithm.update(it) + assert "loss_actor" in stats_0.losses + + # Step 1: should NOT update actor + stats_1 = algorithm.update(it) + assert "loss_actor" not in stats_1.losses + + +def test_update_increments_optimization_step(): + algorithm, _ = _make_algorithm() + it = _batch_iterator() + assert algorithm.optimization_step == 0 + algorithm.update(it) + assert algorithm.optimization_step == 1 + algorithm.update(it) + assert algorithm.optimization_step == 2 + + +def test_update_with_discrete_critic(): + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete + assert "loss_discrete_critic" in stats.losses + assert "discrete_critic" in stats.grad_norms + + +# =========================================================================== +# update with UTD ratio > 1 +# =========================================================================== + + +@pytest.mark.parametrize("utd_ratio", [2, 4]) +def test_update_with_utd_ratio(utd_ratio): + algorithm, _ = _make_algorithm(utd_ratio=utd_ratio) + stats = algorithm.update(_batch_iterator()) + assert isinstance(stats, TrainingStats) + assert "loss_critic" in stats.losses + assert algorithm.optimization_step == 1 + + +def test_update_utd_ratio_pulls_utd_batches(): + """next(batch_iterator) should be called exactly utd_ratio times.""" + utd_ratio = 3 + algorithm, _ = _make_algorithm(utd_ratio=utd_ratio) + + call_count = 0 + + def counting_iterator(): + nonlocal call_count + while True: + call_count += 1 + yield _make_batch() + + algorithm.update(counting_iterator()) + assert call_count == utd_ratio + + +def test_update_utd_ratio_3_critic_warmup_changes_weights(): + """With utd_ratio=3, critic weights should change after update (3 critic steps).""" + algorithm, policy = _make_algorithm(utd_ratio=3) + + critic_params_before = {n: p.clone() for n, p in algorithm.critic_ensemble.named_parameters()} 
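+ # A single update() call pulls utd_ratio batches from the iterator and runs one critic step per batch, so critic parameters should move.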
+ + algorithm.update(_batch_iterator()) + + changed = False + for n, p in algorithm.critic_ensemble.named_parameters(): + if not torch.equal(p, critic_params_before[n]): + changed = True + break + assert changed, "Critic weights should have changed after UTD update" + + +# =========================================================================== +# get_observation_features +# =========================================================================== + + +def test_get_observation_features_returns_none_without_frozen_encoder(): + algorithm, _ = _make_algorithm(with_images=False) + obs = {OBS_STATE: torch.randn(4, 10)} + next_obs = {OBS_STATE: torch.randn(4, 10)} + feat, next_feat = algorithm.get_observation_features(obs, next_obs) + assert feat is None + assert next_feat is None + + +# =========================================================================== +# optimization_step setter +# =========================================================================== + + +def test_optimization_step_can_be_set_for_resume(): + algorithm, _ = _make_algorithm() + algorithm.optimization_step = 100 + assert algorithm.optimization_step == 100 + + +# =========================================================================== +# make_algorithm factory +# =========================================================================== + + +def test_make_algorithm_returns_sac_for_sac_policy(): + sac_cfg = _make_sac_config() + policy = GaussianActorPolicy(config=sac_cfg) + algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy) + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + +def test_make_optimizers_creates_expected_keys(): + """make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers.""" + sac_cfg = _make_sac_config() + policy = GaussianActorPolicy(config=sac_cfg) + algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy) + optimizers = algorithm.make_optimizers_and_scheduler() + assert "actor" in optimizers + assert "critic" in optimizers + assert "temperature" in optimizers + assert all(isinstance(v, torch.optim.Adam) for v in optimizers.values()) + assert algorithm.get_optimizers() is optimizers + + +def test_actor_side_no_optimizers(): + """Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called.""" + sac_cfg = _make_sac_config() + policy = GaussianActorPolicy(config=sac_cfg) + algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy) + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.optimizers == {} + + +def test_make_algorithm_uses_sac_algorithm_defaults(): + """make_algorithm populates SACAlgorithmConfig with its own defaults.""" + sac_cfg = _make_sac_config() + policy = GaussianActorPolicy(config=sac_cfg) + algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy) + assert algorithm.config.utd_ratio == 1 + assert algorithm.config.policy_update_freq == 1 + assert algorithm.config.grad_clip_norm == 40.0 + + +def test_unknown_algorithm_name_raises_in_registry(): + """The ChoiceRegistry is the source of truth for unknown algorithm names.""" + with pytest.raises(KeyError): + RLAlgorithmConfig.get_choice_class("unknown_algo") + + +# =========================================================================== +# load_weights (round-trip with get_weights) +# =========================================================================== + + +def 
test_load_weights_round_trip(): + """get_weights -> load_weights should restore identical parameters on a fresh policy.""" + algo_src, _ = _make_algorithm(state_dim=10, action_dim=6) + algo_src.update(_batch_iterator()) + + sac_cfg = _make_sac_config(state_dim=10, action_dim=6) + policy_dst = GaussianActorPolicy(config=sac_cfg) + algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config) + + weights = algo_src.get_weights() + algo_dst.load_weights(weights, device="cpu") + + dst_actor_state_dict = algo_dst.policy.actor.state_dict() + for key, tensor in weights["policy"].items(): + assert torch.equal( + dst_actor_state_dict[key].cpu(), + tensor.cpu(), + ), f"Policy param '{key}' mismatch after load_weights" + + +def test_load_weights_round_trip_with_discrete_critic(): + algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + algo_src.update(_batch_iterator(action_dim=7)) + + sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) + policy_dst = GaussianActorPolicy(config=sac_cfg) + algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config) + + weights = algo_src.get_weights() + algo_dst.load_weights(weights, device="cpu") + + assert "discrete_critic" in weights + assert len(weights["discrete_critic"]) > 0 + dst_discrete_critic_state_dict = algo_dst.policy.discrete_critic.state_dict() + for key, tensor in weights["discrete_critic"].items(): + assert torch.equal( + dst_discrete_critic_state_dict[key].cpu(), + tensor.cpu(), + ), f"Discrete critic param '{key}' mismatch after load_weights" + + +def test_load_weights_ignores_missing_discrete_critic(): + """load_weights should not fail when weights lack discrete_critic on a non-discrete policy.""" + algorithm, _ = _make_algorithm() + weights = algorithm.get_weights() + algorithm.load_weights(weights, device="cpu") + + +def test_actor_side_weight_sync_with_discrete_critic(): + """End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``.""" + # Learner side: train the source algorithm so its weights diverge from init. + algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + algo_src.update(_batch_iterator(action_dim=7)) + weights = algo_src.get_weights() + + # Actor side: fresh policy + fresh algorithm holding it. + sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) + policy_actor = GaussianActorPolicy(config=sac_cfg) + algo_actor = SACAlgorithm( + policy=policy_actor, + config=SACAlgorithmConfig.from_policy_config(sac_cfg), + ) + + # Snapshot initial actor state for the "did it change?" assertion below. + initial_discrete_critic_state_dict = { + k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items() + } + + algo_actor.load_weights(weights, device="cpu") + + # Actor weights match the learner's exported actor state dict. + actor_state_dict = policy_actor.actor.state_dict() + for key, tensor in weights["policy"].items(): + assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), ( + f"Actor param '{key}' not synced by algorithm.load_weights" + ) + + # Discrete critic weights match the learner's exported discrete critic. + discrete_critic_state_dict = policy_actor.discrete_critic.state_dict() + for key, tensor in weights["discrete_critic"].items(): + assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), ( + f"Discrete critic param '{key}' not synced by algorithm.load_weights" + ) + + # Sanity: the discrete critic actually changed (otherwise the sync is trivial). 
+ changed = any( + not torch.equal(initial_discrete_critic_state_dict[key], discrete_critic_state_dict[key]) + for key in initial_discrete_critic_state_dict + if key in discrete_critic_state_dict + ) + assert changed, "Discrete critic weights did not change between init and after sync" + + +# =========================================================================== +# TrainingStats generic losses dict +# =========================================================================== + + +def test_training_stats_generic_losses(): + stats = TrainingStats( + losses={"loss_bc": 0.5, "loss_q": 1.2}, + extra={"temperature": 0.1}, + ) + assert stats.losses["loss_bc"] == 0.5 + assert stats.losses["loss_q"] == 1.2 + assert stats.extra["temperature"] == 0.1 + + +# =========================================================================== +# Registry-driven make_algorithm +# =========================================================================== + + +def test_make_algorithm_builds_sac(): + """make_algorithm should look up the SAC class from the registry and instantiate it.""" + sac_cfg = _make_sac_config() + algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) + algo_config.utd_ratio = 2 + policy = GaussianActorPolicy(config=sac_cfg) + + algorithm = make_algorithm(cfg=algo_config, policy=policy) + assert isinstance(algorithm, SACAlgorithm) + assert algorithm.config.utd_ratio == 2 + + +# =========================================================================== +# state_dict / load_state_dict (algorithm-side resume) +# =========================================================================== + + +def test_state_dict_contains_algorithm_owned_tensors(): + """state_dict should pack critics, target networks, and log_alpha (no encoder bloat).""" + algorithm, _ = _make_algorithm() + sd = algorithm.state_dict() + + assert "log_alpha" in sd + assert any(k.startswith("critic_ensemble.") for k in sd) + assert any(k.startswith("critic_target.") for k in sd) + # encoder weights live on the policy and must not be duplicated here. + assert not any(".encoder." in k for k in sd) + + +def test_state_dict_includes_discrete_critic_target_when_present(): + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + sd = algorithm.state_dict() + assert any(k.startswith("discrete_critic_target.") for k in sd) + + +def test_load_state_dict_round_trip_restores_critics_and_log_alpha(): + """state_dict -> load_state_dict on a fresh algorithm restores all bytes exactly.""" + sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) + src_policy = GaussianActorPolicy(config=sac_cfg) + src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg)) + src.make_optimizers_and_scheduler() + # Train a few steps so weights diverge from init (action_dim=7 = 6 continuous + 1 discrete). 
+ src.update(_batch_iterator(action_dim=7)) + src.update(_batch_iterator(action_dim=7)) + + dst_policy = GaussianActorPolicy(config=sac_cfg) + dst = SACAlgorithm(policy=dst_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg)) + dst.make_optimizers_and_scheduler() + + src_sd = src.state_dict() + dst.load_state_dict(src_sd) + dst_sd = dst.state_dict() + + assert set(dst_sd) == set(src_sd) + for key in src_sd: + assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after round-trip" + + +def test_load_state_dict_preserves_log_alpha_parameter_identity(): + """The temperature optimizer holds a reference to log_alpha; identity must survive load.""" + algorithm, _ = _make_algorithm() + log_alpha_id_before = id(algorithm.log_alpha) + optimizer_param_id = id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) + assert log_alpha_id_before == optimizer_param_id + + new_state = algorithm.state_dict() + new_state["log_alpha"] = torch.tensor([0.42]) + algorithm.load_state_dict(new_state) + + assert id(algorithm.log_alpha) == log_alpha_id_before + assert id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) == log_alpha_id_before + assert torch.allclose(algorithm.log_alpha.detach().cpu(), torch.tensor([0.42])) + + +def test_save_pretrained_round_trip_via_disk(tmp_path): + """End-to-end: save_pretrained -> from_pretrained restores tensors and config.""" + sac_cfg = _make_sac_config() + src_policy = GaussianActorPolicy(config=sac_cfg) + src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg)) + src.make_optimizers_and_scheduler() + src.update(_batch_iterator()) + + save_dir = tmp_path / "algorithm" + src.save_pretrained(save_dir) + assert (save_dir / "model.safetensors").is_file() + assert (save_dir / "config.json").is_file() + + dst_policy = GaussianActorPolicy(config=sac_cfg) + dst = SACAlgorithm.from_pretrained(save_dir, policy=dst_policy) + + src_sd = src.state_dict() + dst_sd = dst.state_dict() + assert set(src_sd) == set(dst_sd) + for key in src_sd: + assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after disk round-trip" diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py new file mode 100644 index 000000000..b15d4393b --- /dev/null +++ b/tests/rl/test_trainer.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
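+"""Tests for RLTrainer: lazy data-iterator lifecycle, data-mixer swapping, and the RLAlgorithm optimization_step contract, using a dummy algorithm."""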
+ +import pytest + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 +from torch import Tensor # noqa: E402 + +from lerobot.rl.algorithms.base import RLAlgorithm # noqa: E402 +from lerobot.rl.algorithms.configs import TrainingStats # noqa: E402 +from lerobot.rl.trainer import RLTrainer # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402 + + +class _DummyRLAlgorithmConfig: + """Dummy config for testing.""" + + +class _DummyRLAlgorithm(RLAlgorithm): + config_class = _DummyRLAlgorithmConfig + name = "dummy_rl_algorithm" + + def __init__(self): + self.configure_calls = 0 + self.update_calls = 0 + + def select_action(self, observation: dict[str, Tensor]) -> Tensor: + return torch.zeros(1) + + def configure_data_iterator( + self, + data_mixer, + batch_size: int, + *, + async_prefetch: bool = True, + queue_size: int = 2, + ): + self.configure_calls += 1 + return data_mixer.get_iterator( + batch_size=batch_size, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + def make_optimizers_and_scheduler(self): + return {} + + def update(self, batch_iterator): + self.update_calls += 1 + _ = next(batch_iterator) + return TrainingStats(losses={"dummy": 1.0}) + + def load_weights(self, weights, device="cpu") -> None: + _ = (weights, device) + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state_dict, device="cpu") -> None: + _ = (state_dict, device) + + +class _SimpleMixer: + def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2): + _ = (async_prefetch, queue_size) + while True: + yield { + "state": {OBS_STATE: torch.randn(batch_size, 3)}, + ACTION: torch.randn(batch_size, 2), + "reward": torch.randn(batch_size), + "next_state": {OBS_STATE: torch.randn(batch_size, 3)}, + "done": torch.zeros(batch_size), + "truncated": torch.zeros(batch_size), + "complementary_info": None, + } + + +def test_trainer_lazy_iterator_lifecycle_and_reset(): + algo = _DummyRLAlgorithm() + mixer = _SimpleMixer() + trainer = RLTrainer(algorithm=algo, data_mixer=mixer, batch_size=4) + + # First call builds iterator once. + trainer.training_step() + assert algo.configure_calls == 1 + assert algo.update_calls == 1 + + # Second call reuses existing iterator. + trainer.training_step() + assert algo.configure_calls == 1 + assert algo.update_calls == 2 + + # Explicit reset forces lazy rebuild on next step. 
+ trainer.reset_data_iterator() + trainer.training_step() + assert algo.configure_calls == 2 + assert algo.update_calls == 3 + + +def test_trainer_set_data_mixer_resets_by_default(): + algo = _DummyRLAlgorithm() + mixer_a = _SimpleMixer() + mixer_b = _SimpleMixer() + trainer = RLTrainer(algorithm=algo, data_mixer=mixer_a, batch_size=2) + + trainer.training_step() + assert algo.configure_calls == 1 + + trainer.set_data_mixer(mixer_b, reset=True) + trainer.training_step() + assert algo.configure_calls == 2 + + +def test_algorithm_optimization_step_contract_defaults(): + algo = _DummyRLAlgorithm() + assert algo.optimization_step == 0 + algo.optimization_step = 11 + assert algo.optimization_step == 11 diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py index 65b24aac4..1ede0bfeb 100644 --- a/tests/utils/test_process.py +++ b/tests/utils/test_process.py @@ -22,7 +22,7 @@ from unittest.mock import patch import pytest -pytest.importorskip("grpc") +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") from lerobot.utils.process import ProcessSignalHandler # noqa: E402 diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index 1b2af39f1..e6517596f 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -19,7 +19,6 @@ from collections.abc import Callable import pytest -pytest.importorskip("grpc") pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") import torch # noqa: E402 diff --git a/uv.lock b/uv.lock index f09c4a64f..28b906d89 100644 --- a/uv.lock +++ b/uv.lock @@ -2849,10 +2849,16 @@ hardware = [ { name = "pyserial" }, ] hilserl = [ + { name = "av" }, + { name = "datasets" }, { name = "grpcio" }, { name = "gym-hil" }, + { name = "jsonlines" }, + { name = "pandas" }, { name = "placo" }, { name = "protobuf" }, + { name = "pyarrow" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, { name = "transformers" }, ] hopejr = [ @@ -3069,6 +3075,7 @@ requires-dist = [ { name = "lerobot", extras = ["dataset"], marker = "extra == 'aloha'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'core-scripts'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'dataset-viz'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'hilserl'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'libero'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'metaworld'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'pusht'" },