#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket
import threading
import time

import pytest
import torch
from torch.multiprocessing import Event, Queue

from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import require_package


def create_test_transitions(count: int = 3) -> list[Transition]:
    """Create test transitions for integration testing."""
    transitions = []
    for i in range(count):
        transition = Transition(
            state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
            action=torch.randn(5),
            reward=torch.tensor(1.0 + i),
            done=torch.tensor(i == count - 1),  # Last transition is done
            truncated=torch.tensor(False),
            next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
            complementary_info={"step": torch.tensor(i), "episode_id": i // 2},
        )
        transitions.append(transition)
    return transitions


def create_test_interactions(count: int = 3) -> list[dict]:
    """Create test interactions for integration testing."""
    interactions = []
    for i in range(count):
        interaction = {
            "episode_reward": 10.0 + i * 5,
            "step": i * 100,
            "policy_fps": 30.0 + i,
            "intervention_rate": 0.1 * i,
            "episode_length": 200 + i * 50,
        }
        interactions.append(interaction)
    return interactions


def find_free_port():
    """Finds a free port on the local machine."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # Bind to port 0 to let the OS choose a free port
        s.listen(1)
        port = s.getsockname()[1]
    return port


@pytest.fixture
def cfg():
    cfg = TrainRLServerPipelineConfig()
    port = find_free_port()

    policy_cfg = SACConfig()
    policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
    policy_cfg.actor_learner_config.learner_port = port
    policy_cfg.concurrency.actor = "threads"
    policy_cfg.concurrency.learner = "threads"
    policy_cfg.actor_learner_config.queue_get_timeout = 0.1
    cfg.policy = policy_cfg
    return cfg
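

# A lightweight sanity-check sketch for the helpers above. It exercises only
# behavior defined in this file and the standard library, not the lerobot
# transport stack, so it runs without a gRPC install.
def test_helper_factories_and_free_port():
    assert len(create_test_transitions(count=4)) == 4
    assert len(create_test_interactions(count=4)) == 4

    port = find_free_port()
    assert 0 < port < 65536
    # find_free_port() closed its socket before returning, so an immediate
    # re-bind is expected to succeed (not strictly guaranteed on a busy host).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", port))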


@require_package("grpcio", "grpc")
@pytest.mark.timeout(10)  # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
    """Test the complete transitions flow from actor to learner."""
    from lerobot.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        push_transitions_to_transport_queue,
        send_transitions,
    )
    from lerobot.rl.learner import start_learner
    from lerobot.transport.utils import bytes_to_transitions
    from tests.transport.test_transport_utils import assert_transitions_equal

    transitions_actor_queue = Queue()
    transitions_learner_queue = Queue()
    interactions_queue = Queue()
    parameters_queue = Queue()
    shutdown_event = Event()

    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )
    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    send_transitions_thread = threading.Thread(
        target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel)
    )
    send_transitions_thread.start()

    input_transitions = create_test_transitions(count=5)
    push_transitions_to_transport_queue(input_transitions, transitions_actor_queue)

    # Give the actor time to stream the transitions to the learner
    time.sleep(0.1)

    shutdown_event.set()

    # Wait for both sides to shut down before inspecting the learner queue
    learner_thread.join()
    send_transitions_thread.join()
    channel.close()

    received_transitions = []
    while not transitions_learner_queue.empty():
        received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get()))

    assert len(received_transitions) == len(input_transitions)
    for i, transition in enumerate(received_transitions):
        assert_transitions_equal(transition, input_transitions[i])


@require_package("grpcio", "grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
    """Test the complete interactions flow from actor to learner."""
    from lerobot.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        send_interactions,
    )
    from lerobot.rl.learner import start_learner
    from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes

    # Queues for actor-learner communication
    interactions_actor_queue = Queue()
    interactions_learner_queue = Queue()

    # Other queues required by the learner
    parameters_queue = Queue()
    transitions_learner_queue = Queue()

    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    # Establish the connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )
    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's interaction-sending loop in a separate thread
    send_interactions_thread = threading.Thread(
        target=send_interactions,
        args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel),
    )
    send_interactions_thread.start()

    # Create and push test interactions to the actor's queue
    input_interactions = create_test_interactions(count=5)
    for interaction in input_interactions:
        interactions_actor_queue.put(python_object_to_bytes(interaction))

    # Give the communication time to happen
    time.sleep(0.1)

    # Signal shutdown and wait for the threads to complete
    shutdown_event.set()
    learner_thread.join()
    send_interactions_thread.join()
    channel.close()

    # Verify that the learner received the interactions
    received_interactions = []
    while not interactions_learner_queue.empty():
        received_interactions.append(bytes_to_python_object(interactions_learner_queue.get()))

    assert len(received_interactions) == len(input_interactions)

    # Sort by a unique key to handle potential reordering in queues
    received_interactions.sort(key=lambda x: x["step"])
    input_interactions.sort(key=lambda x: x["step"])

    for received, expected in zip(received_interactions, input_interactions, strict=True):
        assert received == expected
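

# A minimal round-trip sketch for the object-serialization helpers exercised
# by the interactions test above; it uses only helpers this file already
# imports from lerobot.transport.utils, so a serialization regression fails
# fast without spinning up the gRPC learner.
@require_package("grpcio", "grpc")
def test_python_object_bytes_roundtrip():
    from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes

    interaction = create_test_interactions(count=1)[0]
    assert bytes_to_python_object(python_object_to_bytes(interaction)) == interaction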


@require_package("grpcio", "grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
    """Test the complete parameter flow from learner to actor, with small and large payloads."""
    from lerobot.rl.actor import establish_learner_connection, learner_service_client, receive_policy
    from lerobot.rl.learner import start_learner
    from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes

    # Actor's local queue to receive parameters
    parameters_actor_queue = Queue()
    # Learner's queue to send parameters from
    parameters_learner_queue = Queue()

    # Other queues required by the learner
    transitions_learner_queue = Queue()
    interactions_learner_queue = Queue()

    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(
            parameters_learner_queue,
            transitions_learner_queue,
            interactions_learner_queue,
            shutdown_event,
            cfg,
        ),
    )
    learner_thread.start()

    # Establish the connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )
    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's parameter-receiving loop in a separate thread
    receive_params_thread = threading.Thread(
        target=receive_policy,
        args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel),
    )
    receive_params_thread.start()

    # Create test parameters based on the parametrization
    if data_size == "small":
        input_params = {"layer.weight": torch.randn(128, 64)}
    else:  # "large"
        # CHUNK_SIZE is 2 MB, so this 4 MB float32 tensor will force chunking
        input_params = {"large_layer.weight": torch.randn(1024, 1024)}

    # Simulate the learner having new parameters to send
    parameters_learner_queue.put(state_to_bytes(input_params))

    # Wait for the actor to receive the parameters
    time.sleep(0.1)

    # Signal shutdown and wait for the threads to complete
    shutdown_event.set()
    learner_thread.join()
    receive_params_thread.join()
    channel.close()

    # Verify that the actor received the parameters correctly
    received_params = bytes_to_state_dict(parameters_actor_queue.get())
    assert received_params.keys() == input_params.keys()
    for key in input_params:
        assert torch.allclose(received_params[key], input_params[key])
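

# A minimal round-trip sketch for the state-dict serialization helpers used in
# the parameters test above, checked in isolation so chunking/transport issues
# can be distinguished from pure serialization issues.
@require_package("grpcio", "grpc")
def test_state_dict_bytes_roundtrip():
    from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes

    input_params = {"layer.weight": torch.randn(16, 8), "layer.bias": torch.randn(8)}
    decoded = bytes_to_state_dict(state_to_bytes(input_params))
    assert decoded.keys() == input_params.keys()
    for key in input_params:
        assert torch.allclose(decoded[key], input_params[key])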


# ---------------------------------------------------------------------------
# Regression tests: learner algorithm integration (no gRPC required)
# ---------------------------------------------------------------------------


def test_learner_algorithm_wiring():
    """Verify that make_algorithm constructs an SACAlgorithm from config,
    make_optimizers_and_scheduler() creates the right optimizers, update() works,
    and get_weights() output is serializable."""
    from lerobot.policies.sac.modeling_sac import SACPolicy
    from lerobot.rl.algorithms.factory import make_algorithm
    from lerobot.rl.algorithms.sac import SACAlgorithm
    from lerobot.transport.utils import state_to_bytes

    state_dim = 10
    action_dim = 6

    sac_cfg = SACConfig(
        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
        dataset_stats={
            OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
            ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
        },
        use_torch_compile=False,
    )
    sac_cfg.validate_features()

    policy = SACPolicy(config=sac_cfg)
    policy.train()

    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
    assert isinstance(algorithm, SACAlgorithm)

    optimizers = algorithm.make_optimizers_and_scheduler()
    assert "actor" in optimizers
    assert "critic" in optimizers
    assert "temperature" in optimizers

    batch_size = 4

    def batch_iterator():
        while True:
            yield {
                ACTION: torch.randn(batch_size, action_dim),
                "reward": torch.randn(batch_size),
                "state": {OBS_STATE: torch.randn(batch_size, state_dim)},
                "next_state": {OBS_STATE: torch.randn(batch_size, state_dim)},
                "done": torch.zeros(batch_size),
                "complementary_info": {},
            }

    stats = algorithm.update(batch_iterator())
    assert "critic" in stats.losses

    # get_weights -> state_to_bytes round-trip
    weights = algorithm.get_weights()
    assert len(weights) > 0
    serialized = state_to_bytes(weights)
    assert isinstance(serialized, bytes)
    assert len(serialized) > 0

    # RLTrainer with a DataMixer
    from lerobot.rl.buffer import ReplayBuffer
    from lerobot.rl.data_sources import OnlineOfflineMixer
    from lerobot.rl.trainer import RLTrainer

    replay_buffer = ReplayBuffer(
        capacity=50,
        device="cpu",
        state_keys=[OBS_STATE],
        storage_device="cpu",
        use_drq=False,
    )
    for _ in range(50):
        replay_buffer.add(
            state={OBS_STATE: torch.randn(state_dim)},
            action=torch.randn(action_dim),
            reward=1.0,
            next_state={OBS_STATE: torch.randn(state_dim)},
            done=False,
            truncated=False,
        )

    data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None)
    trainer = RLTrainer(
        algorithm=algorithm,
        data_mixer=data_mixer,
        batch_size=batch_size,
    )
    trainer_stats = trainer.training_step()
    assert "critic" in trainer_stats.losses


def test_initial_and_periodic_weight_push_consistency():
    """Both the initial and the periodic weight pushes should use algorithm.get_weights()
    and produce identical structures."""
    from lerobot.policies.sac.modeling_sac import SACPolicy
    from lerobot.rl.algorithms.factory import make_algorithm
    from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes

    state_dim = 10
    action_dim = 6

    sac_cfg = SACConfig(
        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
        dataset_stats={
            OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
            ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
        },
        use_torch_compile=False,
    )
    sac_cfg.validate_features()

    policy = SACPolicy(config=sac_cfg)
    policy.train()

    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
    algorithm.make_optimizers_and_scheduler()

    # Simulate the initial push (the same code path the learner now uses)
    initial_weights = algorithm.get_weights()
    initial_bytes = state_to_bytes(initial_weights)

    # Simulate a periodic push
    periodic_weights = algorithm.get_weights()
    periodic_bytes = state_to_bytes(periodic_weights)

    initial_decoded = bytes_to_state_dict(initial_bytes)
    periodic_decoded = bytes_to_state_dict(periodic_bytes)
    assert initial_decoded.keys() == periodic_decoded.keys()
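

# A hedged numeric follow-up to the key-level consistency check above: with no
# optimizer step in between, two consecutive get_weights() snapshots should
# match tensor-by-tensor. This assumes get_weights() returns a flat dict of
# tensors, consistent with the torch.allclose comparisons elsewhere in this file.
def test_consecutive_weight_snapshots_match_numerically():
    from lerobot.policies.sac.modeling_sac import SACPolicy
    from lerobot.rl.algorithms.factory import make_algorithm

    state_dim = 10
    action_dim = 6

    sac_cfg = SACConfig(
        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
        dataset_stats={
            OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
            ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
        },
        use_torch_compile=False,
    )
    sac_cfg.validate_features()

    policy = SACPolicy(config=sac_cfg)
    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")

    first = algorithm.get_weights()
    second = algorithm.get_weights()
    assert first.keys() == second.keys()
    for key in first:
        assert torch.allclose(first[key], second[key])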


def test_actor_side_algorithm_select_action_and_load_weights():
    """Simulate actor: create algorithm without optimizers, select_action, load_weights."""
    from lerobot.policies.sac.modeling_sac import SACPolicy
    from lerobot.rl.algorithms.factory import make_algorithm
    from lerobot.rl.algorithms.sac import SACAlgorithm

    state_dim = 10
    action_dim = 6

    sac_cfg = SACConfig(
        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
        dataset_stats={
            OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
            ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
        },
        use_torch_compile=False,
    )
    sac_cfg.validate_features()

    # Actor side: no optimizers
    policy = SACPolicy(config=sac_cfg)
    policy.eval()
    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
    assert isinstance(algorithm, SACAlgorithm)
    assert algorithm.optimizers == {}

    # select_action should work
    obs = {OBS_STATE: torch.randn(state_dim)}
    action = policy.select_action(obs)
    assert action.shape == (action_dim,)

    # Simulate receiving weights from the learner
    fake_weights = algorithm.get_weights()
    algorithm.load_weights(fake_weights, device="cpu")