lerobot/tests/rl/test_actor_learner.py
Khalil Meftah e963e5a0c4 RL stack refactoring (#3075)
* refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring

* chore: clarify torch.compile disabled note in SACAlgorithm

* fix(teleop): keyboard EE teleop not registering special keys and losing intervention state

Fixes #2345

Co-authored-by: jpizarrom <jpizarrom@gmail.com>

* fix: remove leftover normalization calls from reward classifier predict_reward

Fixes #2355

* fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample()

* refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference

* perf: remove redundant CPU→GPU→CPU transition move in learner

* Fix: add kwargs in reward classifier __init__()

* fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer

* fix: add try/finally to control_loop to ensure image writer cleanup on exit

* fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error

* fix: skip tests that require grpc if not available

* fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests

* fix(tests): skip tests that require grpc if not available

* refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages

* fix(config): update vision encoder model name to lerobot/resnet10

* fix(sac): clarify torch.compile status

* refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity

* refactor(sac): simplify optimizer return structure

* perf(rl): use async iterators in OnlineOfflineMixer.get_iterator

* refactor(sac): decouple algorithm hyperparameters from policy config

* update losses names in tests

* fix docstring

* remove unused type alias

* fix test for flat dict structure

* refactor(policies): rename policies/sac → policies/gaussian_actor

* refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

* perf(observation_processor): add CUDA support for image processing

* fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline

(cherry picked from commit 9c2af818ff)

* fix(rl): add time limit processor to environment pipeline

(cherry picked from commit cd105f65cb)

* fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100

(cherry picked from commit 494f469a2b)

* fix(rl): update neutral gripper action

(cherry picked from commit 9c9064e5be)

* fix(rl): merge environment and action-processor info in transition processing

(cherry picked from commit 30e1886b64)

* fix(rl): mirror gym_manipulator in actor

(cherry picked from commit d2a046dfc5)

* fix(rl): postprocess action in actor

(cherry picked from commit c2556439e5)

* fix(rl): improve action processing for discrete and continuous actions

(cherry picked from commit f887ab3f6a)

* fix(rl): enhance intervention handling in actor and learner

(cherry picked from commit ef8bfffbd7)

* Revert "perf(observation_processor): add CUDA support for image processing"

This reverts commit 38b88c414c.

* refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

* refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation

* refactor(rl): add type property to RLAlgorithmConfig for better clarity

* refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility

* refactor(tests): remove grpc import checks from test files for cleaner code

* fix(tests): gate RL tests on the `datasets` extra

* refactor: simplify docstrings for clarity and conciseness across multiple files

* fix(rl): update gripper position key and handle action absence during reset

* fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset

* refactor: clean up import statements

* chore: address reviewer comments

* chore: improve visual stats reshaping logic and update docstring for clarity

* refactor: enforce mandatory config_class and name attributes in RLAlgorithm

* refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer

* refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests

* refactor: add require_package calls for grpcio and gym-hil in relevant modules

* refactor(rl): move grpcio guards to runtime entry points

* feat(rl): consolidate HIL-SERL checkpoint into HF-style components

Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract
`state_dict()` / `load_state_dict()` for critic ensemble, target nets
and `log_alpha`, and persist them as a sibling `algorithm/` component
next to `pretrained_model/`. Replace the pickled `training_state.pt`
with an enriched `training_step.json` carrying `step` and
`interaction_step`, so resume restores actor + critics + target nets +
temperature + optimizers + RNG + counters from HF-standard files.

* refactor(rl): move actor weight-sync wire format from policy to algorithm

* refactor(rl): update type hints for learner and actor functions

* refactor(rl): hoist grpcio guard to module top in actor/learner

* chore(rl): manage import pattern in actor (#3564)

* chore(rl): manage import pattern in actor

* chore(rl): optional grpc imports in learner; quote grpc ServicerContext types

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>

* update uv.lock

* chore(doc): update doc

---------

Co-authored-by: jpizarrom <jpizarrom@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-12 15:49:54 +02:00

465 lines
16 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import threading
import time

import pytest
import torch

pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("grpc")

from torch.multiprocessing import Event, Queue

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import skip_if_package_missing


def create_test_transitions(count: int = 3) -> list[Transition]:
    """Create test transitions for integration testing."""
    transitions = []
    for i in range(count):
        transition = Transition(
            state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
            action=torch.randn(5),
            reward=torch.tensor(1.0 + i),
            done=torch.tensor(i == count - 1),  # Last transition is done
            truncated=torch.tensor(False),
            next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
            complementary_info={"step": torch.tensor(i), "episode_id": i // 2},
        )
        transitions.append(transition)
    return transitions


def create_test_interactions(count: int = 3) -> list[dict]:
    """Create test interactions for integration testing."""
    interactions = []
    for i in range(count):
        interaction = {
            "episode_reward": 10.0 + i * 5,
            "step": i * 100,
            "policy_fps": 30.0 + i,
            "intervention_rate": 0.1 * i,
            "episode_length": 200 + i * 50,
        }
        interactions.append(interaction)
    return interactions
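

def find_free_port() -> int:
    """Find a free TCP port on the local machine.

    The probe socket is closed on return, so another process could in principle
    take the port before the learner binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # Bind to port 0 to let the OS choose a free port
        s.listen(1)
        port = s.getsockname()[1]
    return port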


@pytest.fixture
def cfg():
    cfg = TrainRLServerPipelineConfig()

    port = find_free_port()

    policy_cfg = GaussianActorConfig()
    policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
    policy_cfg.actor_learner_config.learner_port = port
    policy_cfg.concurrency.actor = "threads"
    policy_cfg.concurrency.learner = "threads"
    policy_cfg.actor_learner_config.queue_get_timeout = 0.1

    cfg.policy = policy_cfg

    return cfg


@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(10)  # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
    """Test complete transitions flow from actor to learner."""
    from lerobot.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        push_transitions_to_transport_queue,
        send_transitions,
    )
    from lerobot.rl.learner import start_learner
    from lerobot.transport.utils import bytes_to_transitions
    from tests.transport.test_transport_utils import assert_transitions_equal

    transitions_actor_queue = Queue()
    transitions_learner_queue = Queue()
    interactions_queue = Queue()
    parameters_queue = Queue()
    shutdown_event = Event()

    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    send_transitions_thread = threading.Thread(
        target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel)
    )
    send_transitions_thread.start()

    input_transitions = create_test_transitions(count=5)
    push_transitions_to_transport_queue(input_transitions, transitions_actor_queue)

    # Give the actor time to stream the transitions to the learner
    time.sleep(0.1)

    shutdown_event.set()

    # Wait for both threads to exit before inspecting the learner queue
    learner_thread.join()
    send_transitions_thread.join()
    channel.close()

    received_transitions = []
    while not transitions_learner_queue.empty():
        received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get()))

    assert len(received_transitions) == len(input_transitions)
    for i, transition in enumerate(received_transitions):
        assert_transitions_equal(transition, input_transitions[i])
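
The end-to-end tests collect results with a `while not queue.empty()` loop. The Python documentation describes `multiprocessing.Queue.empty()` as not reliable, and `torch.multiprocessing` re-exports the standard `multiprocessing` API. A minimal sketch of an alternative, using a hypothetical `drain_queue` helper (not part of lerobot) that blocks briefly on each `get` and stops once the queue stays empty for `timeout` seconds:

```python
import queue
from collections.abc import Callable
from typing import Any


def drain_queue(q: Any, decode: Callable[[Any], Any] = lambda item: item, timeout: float = 0.05) -> list:
    """Pop and decode items until no new item arrives within `timeout` seconds.

    Hypothetical helper: it relies on `get(timeout=...)` raising `queue.Empty`,
    which holds for both `queue.Queue` and `multiprocessing.Queue`, instead of
    trusting `q.empty()`. The short blocking wait also picks up items that a
    multiprocessing feeder thread is still flushing.
    """
    items = []
    while True:
        try:
            items.append(decode(q.get(timeout=timeout)))
        except queue.Empty:
            return items
```

Under this assumption, the transitions check would become `[t for payload in drain_queue(transitions_learner_queue, bytes_to_transitions) for t in payload]`. The cost is one extra `timeout` of latency at the end of each drain.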


@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
    """Test complete interactions flow from actor to learner."""
    from lerobot.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        send_interactions,
    )
    from lerobot.rl.learner import start_learner
    from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes

    # Queues for actor-learner communication
    interactions_actor_queue = Queue()
    interactions_learner_queue = Queue()

    # Other queues required by the learner
    parameters_queue = Queue()
    transitions_learner_queue = Queue()

    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    # Establish connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's interaction sending process in a separate thread
    send_interactions_thread = threading.Thread(
        target=send_interactions,
        args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel),
    )
    send_interactions_thread.start()

    # Create and push test interactions to the actor's queue
    input_interactions = create_test_interactions(count=5)
    for interaction in input_interactions:
        interactions_actor_queue.put(python_object_to_bytes(interaction))

    # Give the actor time to stream the interactions to the learner
    time.sleep(0.1)

    # Signal shutdown and wait for threads to complete
    shutdown_event.set()
    learner_thread.join()
    send_interactions_thread.join()
    channel.close()

    # Verify that the learner received the interactions
    received_interactions = []
    while not interactions_learner_queue.empty():
        received_interactions.append(bytes_to_python_object(interactions_learner_queue.get()))

    assert len(received_interactions) == len(input_interactions)

    # Sort by a unique key to handle potential reordering in queues
    received_interactions.sort(key=lambda x: x["step"])
    input_interactions.sort(key=lambda x: x["step"])

    for received, expected in zip(received_interactions, input_interactions, strict=False):
        assert received == expected
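
Each end-to-end test sleeps for a fixed 0.1 s before setting `shutdown_event`, which assumes the gRPC round trip finishes inside that window. A sketch of a polling alternative, using a hypothetical `wait_until` helper (not part of lerobot or pytest) that returns as soon as a condition holds or a deadline passes:

```python
import time
from collections.abc import Callable


def wait_until(predicate: Callable[[], bool], timeout: float = 5.0, interval: float = 0.01) -> bool:
    """Poll `predicate` until it returns True or `timeout` seconds elapse.

    Hypothetical helper. It returns the final predicate result, so a caller can
    write `assert wait_until(...)` and get an explicit failure instead of a
    silent race against a fixed sleep.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    # One last check so a condition that became true at the deadline still counts
    return predicate()
```

In these tests the predicate would observe the learner side (for example, a count of items seen so far), and `timeout` should stay below the 10 s `pytest.mark.timeout` watchdog so the helper fails first with a clearer message.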


@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
    """Test complete parameter flow from learner to actor, with small and large data."""
    from lerobot.rl.actor import establish_learner_connection, learner_service_client, receive_policy
    from lerobot.rl.learner import start_learner
    from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes

    # Actor's local queue to receive params
    parameters_actor_queue = Queue()
    # Learner's queue to send params from
    parameters_learner_queue = Queue()

    # Other queues required by the learner
    transitions_learner_queue = Queue()
    interactions_learner_queue = Queue()

    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(
            parameters_learner_queue,
            transitions_learner_queue,
            interactions_learner_queue,
            shutdown_event,
            cfg,
        ),
    )
    learner_thread.start()

    # Establish connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's parameter receiving process in a separate thread
    receive_params_thread = threading.Thread(
        target=receive_policy,
        args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel),
    )
    receive_params_thread.start()

    # Create test parameters based on parametrization
    if data_size == "small":
        input_params = {"layer.weight": torch.randn(128, 64)}
    else:  # "large"
        # CHUNK_SIZE is 2MB, so this float32 tensor (4MB) will force chunking
        input_params = {"large_layer.weight": torch.randn(1024, 1024)}

    # Simulate learner having new parameters to send
    parameters_learner_queue.put(state_to_bytes(input_params))

    # Give the actor time to receive the parameters
    time.sleep(0.1)

    # Signal shutdown and wait for threads to complete
    shutdown_event.set()
    learner_thread.join()
    receive_params_thread.join()
    channel.close()

    # Verify that the actor received the parameters correctly
    received_params = bytes_to_state_dict(parameters_actor_queue.get())

    assert received_params.keys() == input_params.keys()
    for key in input_params:
        assert torch.allclose(received_params[key], input_params[key])
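
The `large` case only exercises chunking if the tensor is bigger than one chunk. A back-of-the-envelope check, assuming `CHUNK_SIZE` is 2 MiB (2 * 1024 * 1024 bytes) as the test comment states, and that `torch.randn` produces float32 (4 bytes per element) by default. The helper names below are illustrative, not lerobot functions:

```python
import math

FLOAT32_BYTES = 4  # default dtype of torch.randn is float32
CHUNK_SIZE = 2 * 1024 * 1024  # assumed from the "CHUNK_SIZE is 2MB" test comment


def raw_tensor_bytes(*shape: int, element_size: int = FLOAT32_BYTES) -> int:
    """Size of the raw tensor data, excluding any serialization header."""
    return math.prod(shape) * element_size


def num_chunks(payload_bytes: int, chunk_size: int = CHUNK_SIZE) -> int:
    """Number of fixed-size chunks needed to carry `payload_bytes` (at least one)."""
    return max(1, math.ceil(payload_bytes / chunk_size))


small = raw_tensor_bytes(128, 64)  # 32 KiB: fits in a single chunk
large = raw_tensor_bytes(1024, 1024)  # 4 MiB: spans at least two chunks
```

The serialized payload is slightly larger than the raw data, so the real chunk count for the large tensor may be one higher; what matters for the test is that it exceeds one chunk while the small tensor does not.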


def test_learner_algorithm_wiring():
    """Verify the learner-side wiring: make_algorithm builds a SACAlgorithm from config,
    make_optimizers_and_scheduler() creates the actor, critic and temperature optimizers,
    update() and RLTrainer.training_step() report a critic loss, and get_weights()
    output is serializable."""
    from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
    from lerobot.rl.algorithms.factory import make_algorithm
    from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
    from lerobot.transport.utils import state_to_bytes

    state_dim = 10
    action_dim = 6

    sac_cfg = GaussianActorConfig(
        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
        dataset_stats={
            OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
            ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
        },
    )
    sac_cfg.validate_features()

    policy = GaussianActorPolicy(config=sac_cfg)
    policy.train()

    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
    assert isinstance(algorithm, SACAlgorithm)

    optimizers = algorithm.make_optimizers_and_scheduler()
    assert "actor" in optimizers
    assert "critic" in optimizers
    assert "temperature" in optimizers

    batch_size = 4

    def batch_iterator():
        while True:
            yield {
                ACTION: torch.randn(batch_size, action_dim),
                "reward": torch.randn(batch_size),
                "state": {OBS_STATE: torch.randn(batch_size, state_dim)},
                "next_state": {OBS_STATE: torch.randn(batch_size, state_dim)},
                "done": torch.zeros(batch_size),
                "complementary_info": {},
            }

    stats = algorithm.update(batch_iterator())
    assert "loss_critic" in stats.losses

    # get_weights() output must serialize with state_to_bytes
    weights = algorithm.get_weights()
    assert len(weights) > 0
    serialized = state_to_bytes(weights)
    assert isinstance(serialized, bytes)
    assert len(serialized) > 0

    # RLTrainer with DataMixer
    from lerobot.rl.buffer import ReplayBuffer
    from lerobot.rl.data_sources import OnlineOfflineMixer
    from lerobot.rl.trainer import RLTrainer

    replay_buffer = ReplayBuffer(
        capacity=50,
        device="cpu",
        state_keys=[OBS_STATE],
        storage_device="cpu",
        use_drq=False,
    )
    for _ in range(50):
        replay_buffer.add(
            state={OBS_STATE: torch.randn(state_dim)},
            action=torch.randn(action_dim),
            reward=1.0,
            next_state={OBS_STATE: torch.randn(state_dim)},
            done=False,
            truncated=False,
        )

    data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None)
    trainer = RLTrainer(
        algorithm=algorithm,
        data_mixer=data_mixer,
        batch_size=batch_size,
    )
    trainer_stats = trainer.training_step()
    assert "loss_critic" in trainer_stats.losses


def test_initial_and_periodic_weight_push_consistency():
    """Both initial and periodic weight pushes should use algorithm.get_weights()
    and decode to state dicts with the same keys."""
    from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
    from lerobot.rl.algorithms.factory import make_algorithm
    from lerobot.rl.algorithms.sac import SACAlgorithmConfig
    from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes

    state_dim = 10
    action_dim = 6

    sac_cfg = GaussianActorConfig(
        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
        dataset_stats={
            OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
            ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
        },
    )
    sac_cfg.validate_features()

    policy = GaussianActorPolicy(config=sac_cfg)
    policy.train()

    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
    algorithm.make_optimizers_and_scheduler()

    # Simulate initial push (same code path the learner now uses)
    initial_weights = algorithm.get_weights()
    initial_bytes = state_to_bytes(initial_weights)

    # Simulate periodic push
    periodic_weights = algorithm.get_weights()
    periodic_bytes = state_to_bytes(periodic_weights)

    initial_decoded = bytes_to_state_dict(initial_bytes)
    periodic_decoded = bytes_to_state_dict(periodic_bytes)

    assert initial_decoded.keys() == periodic_decoded.keys()


def test_actor_side_algorithm_select_action_and_load_weights():
    """Simulate the actor side: build the algorithm without optimizers, run select_action,
    and load weights as if pushed by the learner."""
    from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
    from lerobot.rl.algorithms.factory import make_algorithm
    from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig

    state_dim = 10
    action_dim = 6

    sac_cfg = GaussianActorConfig(
        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
        dataset_stats={
            OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
            ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
        },
    )
    sac_cfg.validate_features()

    # Actor side: no optimizers
    policy = GaussianActorPolicy(config=sac_cfg)
    policy.eval()

    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
    assert isinstance(algorithm, SACAlgorithm)
    assert algorithm.optimizers == {}

    # select_action should work
    obs = {OBS_STATE: torch.randn(state_dim)}
    action = policy.select_action(obs)
    assert action.shape == (action_dim,)

    # Simulate receiving weights from learner
    fake_weights = algorithm.get_weights()
    algorithm.load_weights(fake_weights, device="cpu")