fix(rl): ensure queue and process cleanup on abnormal exit (#3063)

Wrap the main execution in actor_cli and start_learner_threads with
try/finally so that queues are closed and processes are joined even
when an unhandled exception occurs. Previously, exceptions in
act_with_policy or add_actor_information_and_train would skip all
cleanup code, leaking GPU/CPU resources.

Also sets the shutdown_event on exception so child processes exit
gracefully.

Fixes #3059

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
This commit is contained in:
Jash Shah
2026-04-13 07:25:42 -07:00
committed by GitHub
parent df0763a2bc
commit 9bd844a3b9
2 changed files with 51 additions and 45 deletions
+7 -4
View File
@@ -175,6 +175,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
interactions_process.start() interactions_process.start()
receive_policy_process.start() receive_policy_process.start()
try:
act_with_policy( act_with_policy(
cfg=cfg, cfg=cfg,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
@@ -182,8 +183,11 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
transitions_queue=transitions_queue, transitions_queue=transitions_queue,
interactions_queue=interactions_queue, interactions_queue=interactions_queue,
) )
logging.info("[ACTOR] Policy process joined") logging.info("[ACTOR] Policy loop finished")
except Exception:
logging.exception("[ACTOR] Unhandled exception in act_with_policy")
shutdown_event.set()
finally:
logging.info("[ACTOR] Closing queues") logging.info("[ACTOR] Closing queues")
transitions_queue.close() transitions_queue.close()
interactions_queue.close() interactions_queue.close()
@@ -196,12 +200,11 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
receive_policy_process.join() receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined") logging.info("[ACTOR] Receive policy process joined")
logging.info("[ACTOR] join queues")
transitions_queue.cancel_join_thread() transitions_queue.cancel_join_thread()
interactions_queue.cancel_join_thread() interactions_queue.cancel_join_thread()
parameters_queue.cancel_join_thread() parameters_queue.cancel_join_thread()
logging.info("[ACTOR] queues closed") logging.info("[ACTOR] Cleanup complete")
# Core algorithm functions # Core algorithm functions
+6 -3
View File
@@ -218,6 +218,7 @@ def start_learner_threads(
) )
communication_process.start() communication_process.start()
try:
add_actor_information_and_train( add_actor_information_and_train(
cfg=cfg, cfg=cfg,
wandb_logger=wandb_logger, wandb_logger=wandb_logger,
@@ -227,7 +228,10 @@ def start_learner_threads(
parameters_queue=parameters_queue, parameters_queue=parameters_queue,
) )
logging.info("[LEARNER] Training process stopped") logging.info("[LEARNER] Training process stopped")
except Exception:
logging.exception("[LEARNER] Unhandled exception in training loop")
shutdown_event.set()
finally:
logging.info("[LEARNER] Closing queues") logging.info("[LEARNER] Closing queues")
transition_queue.close() transition_queue.close()
interaction_message_queue.close() interaction_message_queue.close()
@@ -236,12 +240,11 @@ def start_learner_threads(
communication_process.join() communication_process.join()
logging.info("[LEARNER] Communication process joined") logging.info("[LEARNER] Communication process joined")
logging.info("[LEARNER] join queues")
transition_queue.cancel_join_thread() transition_queue.cancel_join_thread()
interaction_message_queue.cancel_join_thread() interaction_message_queue.cancel_join_thread()
parameters_queue.cancel_join_thread() parameters_queue.cancel_join_thread()
logging.info("[LEARNER] queues closed") logging.info("[LEARNER] Cleanup complete")
# Core algorithm functions # Core algorithm functions