[HIL-SERL] Migrate threading to multiprocessing (#759)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Eugene Mironov
2025-03-05 17:19:31 +07:00
committed by AdilZouitine
parent 38f5fa4523
commit db78fee9de
14 changed files with 900 additions and 492 deletions
+240 -200
@@ -15,15 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import queue
import shutil
import time
from pprint import pformat
from threading import Lock, Thread
import signal
from threading import Event
from concurrent.futures import ThreadPoolExecutor
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
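# Note: torch.multiprocessing is a drop-in wrapper around the stdlib multiprocessing
# module; its Queue additionally knows how to move tensors between processes through
# shared memory instead of pickling their storage.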
from lerobot.scripts.server.utils import setup_process_handlers
import grpc
# Import generated stubs
@@ -52,19 +55,19 @@ from lerobot.common.utils.utils import (
set_global_random_state,
set_global_seed,
)
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_transition_to_device,
move_state_dict_to_device,
bytes_to_transitions,
state_to_bytes,
bytes_to_python_object,
)
from lerobot.scripts.server import learner_service
logging.basicConfig(level=logging.INFO)
transition_queue = queue.Queue()
interaction_message_queue = queue.Queue()
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
if not cfg.resume:
@@ -195,67 +198,96 @@ def get_observation_features(
return observation_features, next_observation_features
def use_threads(cfg: DictConfig) -> bool:
return cfg.actor_learner_config.concurrency.learner == "threads"
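# For reference, the corresponding config entry is assumed (from the check above) to look roughly like:
#   actor_learner_config:
#     concurrency:
#       learner: threads   # any other value (e.g. "processes") selects a separate process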
def start_learner_threads(
cfg: DictConfig,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict,
policy: SACPolicy,
policy_lock: Lock,
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
shutdown_event: Event | None = None,
out_dir: str,
shutdown_event: any, # Event,
) -> None:
host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.learner_port
# Create multiprocessing queues
transition_queue = Queue()
interaction_message_queue = Queue()
parameters_queue = Queue()
transition_thread = Thread(
target=add_actor_information_and_train,
daemon=True,
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from torch.multiprocessing import Process
concurrency_entity = Process
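# threading.Thread and torch.multiprocessing.Process share the same
# target/args/daemon/start()/join() interface, so the communication worker below
# is constructed identically for both backends.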
communication_process = concurrency_entity(
target=start_learner_server,
args=(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
parameters_queue,
transition_queue,
interaction_message_queue,
shutdown_event,
cfg,
),
daemon=True,
)
communication_process.start()
transition_thread.start()
add_actor_information_and_train(
cfg,
logger,
out_dir,
shutdown_event,
transition_queue,
interaction_message_queue,
parameters_queue,
)
logging.info("[LEARNER] Training process stopped")
logging.info("[LEARNER] Closing queues")
transition_queue.close()
interaction_message_queue.close()
parameters_queue.close()
communication_process.join()
logging.info("[LEARNER] Communication process joined")
logging.info("[LEARNER] join queues")
transition_queue.cancel_join_thread()
interaction_message_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
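# cancel_join_thread() tells each queue's feeder thread not to block interpreter exit
# while flushing buffered items; any data still in flight is dropped, which is
# acceptable here since we are shutting down anyway.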
logging.info("[LEARNER] queues closed")
def start_learner_server(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
cfg: DictConfig,
):
if not use_threads(cfg):
# We need to initialize logging separately for MP
init_logging()
# Set up process handlers to handle the shutdown signal,
# but keep using the shutdown event passed in from the main process
# (handed back to this MP worker)
setup_process_handlers(False)
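# (With the "spawn" start method the child process gets a fresh interpreter, so it does
# not inherit the parent's logging configuration or signal handlers; hence the two
# explicit calls above.)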
service = learner_service.LearnerService(
shutdown_event,
policy,
policy_lock,
parameters_queue,
cfg.actor_learner_config.policy_parameters_push_frequency,
transition_queue,
interaction_message_queue,
)
server = start_learner_server(service, host, port)
shutdown_event.wait()
server.stop(learner_service.STUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
transition_thread.join()
logging.info("[LEARNER] Transition thread stopped")
def start_learner_server(
service: learner_service.LearnerService,
host="0.0.0.0",
port=50051,
) -> grpc.server:
server = grpc.server(
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
options=[
@@ -263,15 +295,23 @@ def start_learner_server(
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
],
)
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
service,
server,
)
host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.learner_port
server.add_insecure_port(f"{host}:{port}")
server.start()
logging.info("[LEARNER] gRPC server started")
return server
shutdown_event.wait()
logging.info("[LEARNER] Stopping gRPC server...")
server.stop(learner_service.STUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
def check_nan_in_transition(
@@ -287,19 +327,21 @@ def check_nan_in_transition(
logging.error("actions contains NaN values")
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
logging.debug("[LEARNER] Pushing actor policy to the queue")
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
state_bytes = state_to_bytes(state_dict)
parameters_queue.put(state_bytes)
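# The state dict is moved to CPU and serialized to plain bytes so the queue payload stays
# picklable for both the thread and the process backends, without relying on CUDA tensor
# sharing across process boundaries.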
def add_actor_information_and_train(
cfg,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict[str, torch.optim.Optimizer],
policy: nn.Module,
policy_lock: Lock,
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
shutdown_event: Event | None = None,
out_dir: str,
shutdown_event: any, # Event,
transition_queue: Queue,
interaction_message_queue: Queue,
parameters_queue: Queue,
):
"""
Handles data transfer from the actor to the learner, manages training updates,
@@ -322,17 +364,73 @@ def add_actor_information_and_train(
Args:
cfg: Configuration object containing hyperparameters.
device (str): The computing device (`"cpu"` or `"cuda"`).
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
batch_size (int): The number of transitions to sample per training step.
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
policy_lock (Lock): A threading lock to ensure safe policy updates.
logger (Logger): Logger instance for tracking training progress.
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
shutdown_event (Event | None): Event to signal shutdown.
out_dir (str): The output directory for storing training checkpoints and logs.
shutdown_event (Event): Event to signal shutdown.
transition_queue (Queue): Queue for receiving transitions from the actor.
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor.
"""
device = get_safe_torch_device(cfg.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
logging.info("Initializing policy")
### Instantiate the policy in both the actor and learner processes.
### To avoid sending a SACPolicy object over the network, we create a policy instance
### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: but if we do online training, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
)
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
push_actor_policy_to_queue(parameters_queue, policy)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
log_training_info(cfg, out_dir, policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffers
# NOTE: This function doesn't have a single responsibility; it should be split into multiple
# functions in the future. The reason we kept it monolithic is Python's GIL: with multiple
# threads the performance is divided by roughly 200, so a single thread has to do all the work.
@@ -345,33 +443,39 @@ def add_actor_information_and_train(
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
saved_data = False
while True:
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown signal received. Exiting...")
break
while not transition_queue.empty():
logging.debug("[LEARNER] Waiting for transitions")
while not transition_queue.empty() and not shutdown_event.is_set():
transition_list = transition_queue.get()
transition_list = bytes_to_transitions(transition_list)
for transition in transition_list:
transition = move_transition_to_device(transition, device=device)
replay_buffer.add(**transition)
if transition.get("complementary_info", {}).get("is_intervention"):
offline_replay_buffer.add(**transition)
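# Transitions flagged as human interventions are additionally stored in the offline
# buffer, presumably so they keep being sampled as demonstration-like data.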
while not interaction_message_queue.empty():
logging.debug("[LEARNER] Received transitions")
logging.debug("[LEARNER] Waiting for interactions")
while not interaction_message_queue.empty() and not shutdown_event.is_set():
interaction_message = interaction_message_queue.get()
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step by the last checkpointed step so that logging is not broken
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(
interaction_message, mode="train", custom_step_key="Interaction step"
)
# logging.info(f"Interaction message: {interaction_message}")
logging.debug("[LEARNER] Received interactions")
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
@@ -392,19 +496,18 @@ def add_actor_information_and_train(
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size)
@@ -427,46 +530,51 @@ def add_actor_information_and_train(
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["loss_temperature"] = loss_temperature.item()
if (
time.time() - last_time_policy_pushed
> cfg.actor_learner_config.policy_parameters_push_frequency
):
push_actor_policy_to_queue(parameters_queue, policy)
last_time_policy_pushed = time.time()
policy.update_target_networks()
if optimization_step % cfg.training.log_freq == 0:
@@ -595,104 +703,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes.
### To avoid sending a SACPolicy object over the network, we create a policy instance
### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
# TODO: At some point we should just need make sac policy
policy_lock = Lock()
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: but if we do online training, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
)
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
log_training_info(cfg, out_dir, policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffers
shutdown_event = Event()
def signal_handler(signum, frame):
print(
f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
)
shutdown_event.set()
# Register signal handlers
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
shutdown_event = setup_process_handlers(use_threads(cfg))
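# setup_process_handlers is expected to install the SIGINT/SIGTERM/SIGHUP/SIGQUIT
# handlers previously registered inline here, and to return the shared shutdown event
# (a threading or multiprocessing Event depending on the flag).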
start_learner_threads(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
out_dir,
shutdown_event,
)
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
logging.info("[LEARNER] train_cli finished")
if __name__ == "__main__":
train_cli()
logging.info("[LEARNER] main finished")