fix(rl): ensure queue and process cleanup on abnormal exit (#3063)

Wrap the main execution in actor_cli and start_learner_threads with
try/finally so that queues are closed and processes are joined even
when an unhandled exception occurs. Previously, exceptions in
act_with_policy or add_actor_information_and_train would skip all
cleanup code, leaking GPU/CPU resources.

Also sets the shutdown_event on exception so child processes exit
gracefully.

Fixes #3059

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
This commit is contained in:
Jash Shah
2026-04-13 07:25:42 -07:00
committed by GitHub
parent df0763a2bc
commit 9bd844a3b9
2 changed files with 51 additions and 45 deletions
+27 -24
View File
@@ -175,33 +175,36 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
interactions_process.start() interactions_process.start()
receive_policy_process.start() receive_policy_process.start()
act_with_policy( try:
cfg=cfg, act_with_policy(
shutdown_event=shutdown_event, cfg=cfg,
parameters_queue=parameters_queue, shutdown_event=shutdown_event,
transitions_queue=transitions_queue, parameters_queue=parameters_queue,
interactions_queue=interactions_queue, transitions_queue=transitions_queue,
) interactions_queue=interactions_queue,
logging.info("[ACTOR] Policy process joined") )
logging.info("[ACTOR] Policy loop finished")
except Exception:
logging.exception("[ACTOR] Unhandled exception in act_with_policy")
shutdown_event.set()
finally:
logging.info("[ACTOR] Closing queues")
transitions_queue.close()
interactions_queue.close()
parameters_queue.close()
logging.info("[ACTOR] Closing queues") transitions_process.join()
transitions_queue.close() logging.info("[ACTOR] Transitions process joined")
interactions_queue.close() interactions_process.join()
parameters_queue.close() logging.info("[ACTOR] Interactions process joined")
receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined")
transitions_process.join() transitions_queue.cancel_join_thread()
logging.info("[ACTOR] Transitions process joined") interactions_queue.cancel_join_thread()
interactions_process.join() parameters_queue.cancel_join_thread()
logging.info("[ACTOR] Interactions process joined")
receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined")
logging.info("[ACTOR] join queues") logging.info("[ACTOR] Cleanup complete")
transitions_queue.cancel_join_thread()
interactions_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[ACTOR] queues closed")
# Core algorithm functions # Core algorithm functions
+24 -21
View File
@@ -218,30 +218,33 @@ def start_learner_threads(
) )
communication_process.start() communication_process.start()
add_actor_information_and_train( try:
cfg=cfg, add_actor_information_and_train(
wandb_logger=wandb_logger, cfg=cfg,
shutdown_event=shutdown_event, wandb_logger=wandb_logger,
transition_queue=transition_queue, shutdown_event=shutdown_event,
interaction_message_queue=interaction_message_queue, transition_queue=transition_queue,
parameters_queue=parameters_queue, interaction_message_queue=interaction_message_queue,
) parameters_queue=parameters_queue,
logging.info("[LEARNER] Training process stopped") )
logging.info("[LEARNER] Training process stopped")
except Exception:
logging.exception("[LEARNER] Unhandled exception in training loop")
shutdown_event.set()
finally:
logging.info("[LEARNER] Closing queues")
transition_queue.close()
interaction_message_queue.close()
parameters_queue.close()
logging.info("[LEARNER] Closing queues") communication_process.join()
transition_queue.close() logging.info("[LEARNER] Communication process joined")
interaction_message_queue.close()
parameters_queue.close()
communication_process.join() transition_queue.cancel_join_thread()
logging.info("[LEARNER] Communication process joined") interaction_message_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[LEARNER] join queues") logging.info("[LEARNER] Cleanup complete")
transition_queue.cancel_join_thread()
interaction_message_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[LEARNER] queues closed")
# Core algorithm functions # Core algorithm functions