Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-16 17:20:05 +00:00)

[HIL-SERL] Migrate threading to multiprocessing (#759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Committed by: AdilZouitine
Parent: 38f5fa4523
Commit: db78fee9de
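The commit makes the learner's communication worker selectable between a Python thread and a torch.multiprocessing process, keyed off cfg.actor_learner_config.concurrency.learner (see use_threads in the diff below). As orientation, here is a minimal, self-contained sketch of that selection pattern; the worker body and the start_worker helper are illustrative stand-ins, not the repo's code.

import queue

from torch.multiprocessing import Queue


def worker(shutdown_event, in_queue) -> None:
    # Illustrative worker body: poll the queue until shutdown is requested.
    while not shutdown_event.is_set():
        try:
            _ = in_queue.get(timeout=0.1)
        except queue.Empty:
            continue


def start_worker(use_threads: bool):
    # Pick the concurrency primitive from a config-style flag.
    if use_threads:
        from threading import Event, Thread

        concurrency_entity, shutdown_event = Thread, Event()
    else:
        from torch.multiprocessing import Event, Process

        concurrency_entity, shutdown_event = Process, Event()

    in_queue = Queue()  # torch.multiprocessing.Queue works in both modes
    w = concurrency_entity(target=worker, args=(shutdown_event, in_queue), daemon=True)
    w.start()
    return w, shutdown_event, in_queue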
@@ -15,15 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import queue
import shutil
import time
from pprint import pformat
from threading import Lock, Thread
import signal
from threading import Event
from concurrent.futures import ThreadPoolExecutor

# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue

from lerobot.scripts.server.utils import setup_process_handlers

import grpc

# Import generated stubs
@@ -52,19 +55,19 @@ from lerobot.common.utils.utils import (
    set_global_random_state,
    set_global_seed,
)

from lerobot.scripts.server.buffer import (
    ReplayBuffer,
    concatenate_batch_transitions,
    move_transition_to_device,
    move_state_dict_to_device,
    bytes_to_transitions,
    state_to_bytes,
    bytes_to_python_object,
)

from lerobot.scripts.server import learner_service

logging.basicConfig(level=logging.INFO)

transition_queue = queue.Queue()
interaction_message_queue = queue.Queue()


def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
    if not cfg.resume:
@@ -195,67 +198,96 @@ def get_observation_features(
    return observation_features, next_observation_features


def use_threads(cfg: DictConfig) -> bool:
    return cfg.actor_learner_config.concurrency.learner == "threads"


def start_learner_threads(
    cfg: DictConfig,
    device: str,
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    batch_size: int,
    optimizers: dict,
    policy: SACPolicy,
    policy_lock: Lock,
    logger: Logger,
    resume_optimization_step: int | None = None,
    resume_interaction_step: int | None = None,
    shutdown_event: Event | None = None,
    out_dir: str,
    shutdown_event: any,  # Event,
) -> None:
    host = cfg.actor_learner_config.learner_host
    port = cfg.actor_learner_config.learner_port
    # Create multiprocessing queues
    transition_queue = Queue()
    interaction_message_queue = Queue()
    parameters_queue = Queue()

    transition_thread = Thread(
        target=add_actor_information_and_train,
        daemon=True,
    concurrency_entity = None

    if use_threads(cfg):
        from threading import Thread

        concurrency_entity = Thread
    else:
        from torch.multiprocessing import Process

        concurrency_entity = Process

    communication_process = concurrency_entity(
        target=start_learner_server,
        args=(
            cfg,
            device,
            replay_buffer,
            offline_replay_buffer,
            batch_size,
            optimizers,
            policy,
            policy_lock,
            logger,
            resume_optimization_step,
            resume_interaction_step,
            parameters_queue,
            transition_queue,
            interaction_message_queue,
            shutdown_event,
            cfg,
        ),
        daemon=True,
    )
    communication_process.start()

    transition_thread.start()
    add_actor_information_and_train(
        cfg,
        logger,
        out_dir,
        shutdown_event,
        transition_queue,
        interaction_message_queue,
        parameters_queue,
    )
    logging.info("[LEARNER] Training process stopped")

    logging.info("[LEARNER] Closing queues")
    transition_queue.close()
    interaction_message_queue.close()
    parameters_queue.close()

    communication_process.join()
    logging.info("[LEARNER] Communication process joined")

    logging.info("[LEARNER] join queues")
    transition_queue.cancel_join_thread()
    interaction_message_queue.cancel_join_thread()
    parameters_queue.cancel_join_thread()

    logging.info("[LEARNER] queues closed")
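Note the teardown order in start_learner_threads above: the torch.multiprocessing queues are closed, the communication process is joined, and only then is cancel_join_thread() called so interpreter exit is not blocked by each queue's background feeder thread. A standalone sketch of that order (the producer function is illustrative, not the repo's code):

import torch.multiprocessing as mp


def producer(q) -> None:
    # Illustrative child worker: put one payload and exit.
    q.put(b"payload")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    q = mp.Queue()
    p = mp.Process(target=producer, args=(q,), daemon=True)
    p.start()
    print(q.get())          # drain what we need before tearing down
    q.close()               # no more puts from this side
    p.join()                # wait for the worker to finish
    q.cancel_join_thread()  # don't block interpreter exit on the feeder thread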
def start_learner_server(
    parameters_queue: Queue,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    shutdown_event: any,  # Event,
    cfg: DictConfig,
):
    if not use_threads(cfg):
        # We need to init logging separately for the MP case
        init_logging()

        # Setup process handlers to handle the shutdown signal,
        # but use the shutdown event from the main process.
        # Return it back for MP.
        setup_process_handlers(False)

    service = learner_service.LearnerService(
        shutdown_event,
        policy,
        policy_lock,
        parameters_queue,
        cfg.actor_learner_config.policy_parameters_push_frequency,
        transition_queue,
        interaction_message_queue,
    )
    server = start_learner_server(service, host, port)

    shutdown_event.wait()
    server.stop(learner_service.STUTDOWN_TIMEOUT)
    logging.info("[LEARNER] gRPC server stopped")

    transition_thread.join()
    logging.info("[LEARNER] Transition thread stopped")


def start_learner_server(
    service: learner_service.LearnerService,
    host="0.0.0.0",
    port=50051,
) -> grpc.server:
    server = grpc.server(
        ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
        options=[
@@ -263,15 +295,23 @@ def start_learner_server(
            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
        ],
    )

    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
        service,
        server,
    )

    host = cfg.actor_learner_config.learner_host
    port = cfg.actor_learner_config.learner_port

    server.add_insecure_port(f"{host}:{port}")
    server.start()
    logging.info("[LEARNER] gRPC server started")

    return server
    shutdown_event.wait()
    logging.info("[LEARNER] Stopping gRPC server...")
    server.stop(learner_service.STUTDOWN_TIMEOUT)
    logging.info("[LEARNER] gRPC server stopped")


def check_nan_in_transition(
@@ -287,19 +327,21 @@ def check_nan_in_transition(
        logging.error("actions contains NaN values")


def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
    logging.debug("[LEARNER] Pushing actor policy to the queue")
    state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
    state_bytes = state_to_bytes(state_dict)
    parameters_queue.put(state_bytes)


def add_actor_information_and_train(
    cfg,
    device: str,
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    batch_size: int,
    optimizers: dict[str, torch.optim.Optimizer],
    policy: nn.Module,
    policy_lock: Lock,
    logger: Logger,
    resume_optimization_step: int | None = None,
    resume_interaction_step: int | None = None,
    shutdown_event: Event | None = None,
    out_dir: str,
    shutdown_event: any,  # Event,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    parameters_queue: Queue,
):
    """
    Handles data transfer from the actor to the learner, manages training updates,
@@ -322,17 +364,73 @@ def add_actor_information_and_train(
    Args:
        cfg: Configuration object containing hyperparameters.
        device (str): The computing device (`"cpu"` or `"cuda"`).
        replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
        offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
        batch_size (int): The number of transitions to sample per training step.
        optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
        policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
        policy_lock (Lock): A threading lock to ensure safe policy updates.
        logger (Logger): Logger instance for tracking training progress.
        resume_optimization_step (int | None): In the case of resumed training, start from the last optimization step reached.
        resume_interaction_step (int | None): In the case of resumed training, shift the interaction step by the last saved step in order to not break logging.
        shutdown_event (Event | None): Event to signal shutdown.
        out_dir (str): The output directory for storing training checkpoints and logs.
        shutdown_event (Event): Event to signal shutdown.
        transition_queue (Queue): Queue for receiving transitions from the actor.
        interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
        parameters_queue (Queue): Queue for sending policy parameters to the actor.
    """

    device = get_safe_torch_device(cfg.device, log=True)
    storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)

    logging.info("Initializing policy")
    ### Instantiate the policy in both the actor and learner processes.
    ### To avoid sending a SACPolicy object through the port, we create a policy instance
    ### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
    # TODO: At some point we should just need make sac policy

    policy: SACPolicy = make_policy(
        hydra_cfg=cfg,
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: if we do online training, we do not need dataset_stats
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
        if cfg.resume
        else None,
    )
    # compile policy
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    push_actor_policy_to_queue(parameters_queue, policy)

    last_time_policy_pushed = time.time()

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
    resume_optimization_step, resume_interaction_step = load_training_state(
        cfg, logger, optimizers
    )

    log_training_info(cfg, out_dir, policy)

    replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
    batch_size = cfg.training.batch_size
    offline_replay_buffer = None

    if cfg.dataset_repo_id is not None:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
        logging.info("Converting to an offline replay buffer")
        active_action_dims = None
        if cfg.env.wrapper.joint_masking_action_space is not None:
            active_action_dims = [
                i
                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
                if mask
            ]
        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
            offline_dataset,
            device=device,
            state_keys=cfg.policy.input_shapes.keys(),
            action_mask=active_action_dims,
            action_delta=cfg.env.wrapper.delta_action,
            storage_device=storage_device,
            optimize_memory=True,
        )
        batch_size: int = batch_size // 2  # We will sample from both replay buffers

    # NOTE: This function doesn't have a single responsibility; it should be split into multiple
    # functions in the future. The reason we did it this way is the GIL in Python: with separate
    # threads the performance is divided by roughly 200, so we need a single thread that does all the work.
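push_actor_policy_to_queue stages the actor's state dict on the CPU and turns it into bytes before putting it on the parameters queue, so only plain bytes cross the process boundary. A sketch of that staging-and-serialization pattern, assuming a torch.save/torch.load encoding (the repo's state_to_bytes and move_state_dict_to_device helpers may be implemented differently):

import io

import torch
import torch.nn as nn


def state_dict_to_bytes(module: nn.Module) -> bytes:
    # Detach and move every tensor to CPU, then serialize the whole dict.
    cpu_state = {k: v.detach().cpu() for k, v in module.state_dict().items()}
    buffer = io.BytesIO()
    torch.save(cpu_state, buffer)
    return buffer.getvalue()


def bytes_to_state_dict(raw: bytes) -> dict:
    return torch.load(io.BytesIO(raw), map_location="cpu")


if __name__ == "__main__":
    actor = nn.Linear(4, 2)
    payload = state_dict_to_bytes(actor)      # what the learner would queue.put()
    restored = bytes_to_state_dict(payload)   # what the actor would load
    nn.Linear(4, 2).load_state_dict(restored)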
@@ -345,33 +443,39 @@ def add_actor_information_and_train(
    interaction_step_shift = (
        resume_interaction_step if resume_interaction_step is not None else 0
    )
    saved_data = False

    while True:
        if shutdown_event is not None and shutdown_event.is_set():
            logging.info("[LEARNER] Shutdown signal received. Exiting...")
            break

        while not transition_queue.empty():
        logging.debug("[LEARNER] Waiting for transitions")
        while not transition_queue.empty() and not shutdown_event.is_set():
            transition_list = transition_queue.get()
            transition_list = bytes_to_transitions(transition_list)

            for transition in transition_list:
                transition = move_transition_to_device(transition, device=device)
                replay_buffer.add(**transition)

                if transition.get("complementary_info", {}).get("is_intervention"):
                    offline_replay_buffer.add(**transition)

        while not interaction_message_queue.empty():
            logging.debug("[LEARNER] Received transitions")
        logging.debug("[LEARNER] Waiting for interactions")
        while not interaction_message_queue.empty() and not shutdown_event.is_set():
            interaction_message = interaction_message_queue.get()
            interaction_message = bytes_to_python_object(interaction_message)
            # If cfg.resume, shift the interaction step by the last checkpointed step in order to not break the logging
            interaction_message["Interaction step"] += interaction_step_shift
            logger.log_dict(
                interaction_message, mode="train", custom_step_key="Interaction step"
            )
            # logging.info(f"Interaction message: {interaction_message}")

        logging.debug("[LEARNER] Received interactions")

        if len(replay_buffer) < cfg.training.online_step_before_learning:
            continue

        logging.debug("[LEARNER] Starting optimization loop")
        time_for_one_optimization_step = time.time()
        for _ in range(cfg.policy.utd_ratio - 1):
            batch = replay_buffer.sample(batch_size)
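The drain loops above gate on Queue.empty(), which is only approximate for multiprocessing queues. A more defensive variant, shown here as a sketch rather than the repo's code, pops with get_nowait() and treats queue.Empty as the stop condition:

import queue


def drain(q, shutdown_event, handle) -> None:
    # Keep popping until the queue reports Empty or shutdown is requested.
    while not shutdown_event.is_set():
        try:
            item = q.get_nowait()
        except queue.Empty:
            break
        handle(item)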
@@ -392,19 +496,18 @@
        observation_features, next_observation_features = get_observation_features(
            policy, observations, next_observations
        )
        with policy_lock:
            loss_critic = policy.compute_loss_critic(
                observations=observations,
                actions=actions,
                rewards=rewards,
                next_observations=next_observations,
                done=done,
                observation_features=observation_features,
                next_observation_features=next_observation_features,
            )
            optimizers["critic"].zero_grad()
            loss_critic.backward()
            optimizers["critic"].step()
        loss_critic = policy.compute_loss_critic(
            observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=next_observations,
            done=done,
            observation_features=observation_features,
            next_observation_features=next_observation_features,
        )
        optimizers["critic"].zero_grad()
        loss_critic.backward()
        optimizers["critic"].step()

        batch = replay_buffer.sample(batch_size)
@@ -427,46 +530,51 @@
        observation_features, next_observation_features = get_observation_features(
            policy, observations, next_observations
        )
        with policy_lock:
            loss_critic = policy.compute_loss_critic(
                observations=observations,
                actions=actions,
                rewards=rewards,
                next_observations=next_observations,
                done=done,
                observation_features=observation_features,
                next_observation_features=next_observation_features,
            )
            optimizers["critic"].zero_grad()
            loss_critic.backward()
            optimizers["critic"].step()
        loss_critic = policy.compute_loss_critic(
            observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=next_observations,
            done=done,
            observation_features=observation_features,
            next_observation_features=next_observation_features,
        )
        optimizers["critic"].zero_grad()
        loss_critic.backward()
        optimizers["critic"].step()

        training_infos = {}
        training_infos["loss_critic"] = loss_critic.item()

        if optimization_step % cfg.training.policy_update_freq == 0:
            for _ in range(cfg.training.policy_update_freq):
                with policy_lock:
                    loss_actor = policy.compute_loss_actor(
                        observations=observations,
                        observation_features=observation_features,
                    )
                loss_actor = policy.compute_loss_actor(
                    observations=observations,
                    observation_features=observation_features,
                )

                optimizers["actor"].zero_grad()
                loss_actor.backward()
                optimizers["actor"].step()
                optimizers["actor"].zero_grad()
                loss_actor.backward()
                optimizers["actor"].step()

                training_infos["loss_actor"] = loss_actor.item()
                training_infos["loss_actor"] = loss_actor.item()

                loss_temperature = policy.compute_loss_temperature(
                    observations=observations,
                    observation_features=observation_features,
                )
                optimizers["temperature"].zero_grad()
                loss_temperature.backward()
                optimizers["temperature"].step()
                loss_temperature = policy.compute_loss_temperature(
                    observations=observations,
                    observation_features=observation_features,
                )
                optimizers["temperature"].zero_grad()
                loss_temperature.backward()
                optimizers["temperature"].step()

                training_infos["loss_temperature"] = loss_temperature.item()
                training_infos["loss_temperature"] = loss_temperature.item()

        if (
            time.time() - last_time_policy_pushed
            > cfg.actor_learner_config.policy_parameters_push_frequency
        ):
            push_actor_policy_to_queue(parameters_queue, policy)
            last_time_policy_pushed = time.time()

        policy.update_target_networks()
        if optimization_step % cfg.training.log_freq == 0:
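The parameter push above is throttled by wall-clock time: the actor's weights are re-sent only when policy_parameters_push_frequency seconds have elapsed since the last push. A small sketch of the same throttling wrapped in a helper, using time.monotonic instead of time.time (that substitution is a choice made here, not the repo's):

import time


class ThrottledPush:
    def __init__(self, push_fn, min_interval_s: float):
        self.push_fn = push_fn
        self.min_interval_s = min_interval_s
        self.last_push = time.monotonic()

    def maybe_push(self, *args) -> bool:
        # Push only if enough time has passed since the previous push.
        now = time.monotonic()
        if now - self.last_push > self.min_interval_s:
            self.push_fn(*args)
            self.last_push = now
            return True
        return False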
@@ -595,104 +703,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    set_global_seed(cfg.seed)

    device = get_safe_torch_device(cfg.device, log=True)
    storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("make_policy")

    ### Instantiate the policy in both the actor and learner processes.
    ### To avoid sending a SACPolicy object through the port, we create a policy instance
    ### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
    # TODO: At some point we should just need make sac policy

    policy_lock = Lock()
    policy: SACPolicy = make_policy(
        hydra_cfg=cfg,
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: if we do online training, we do not need dataset_stats
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
        if cfg.resume
        else None,
    )
    # compile policy
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
    resume_optimization_step, resume_interaction_step = load_training_state(
        cfg, logger, optimizers
    )

    log_training_info(cfg, out_dir, policy)

    replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
    batch_size = cfg.training.batch_size
    offline_replay_buffer = None

    if cfg.dataset_repo_id is not None:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
        logging.info("Converting to an offline replay buffer")
        active_action_dims = None
        if cfg.env.wrapper.joint_masking_action_space is not None:
            active_action_dims = [
                i
                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
                if mask
            ]
        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
            offline_dataset,
            device=device,
            state_keys=cfg.policy.input_shapes.keys(),
            action_mask=active_action_dims,
            action_delta=cfg.env.wrapper.delta_action,
            storage_device=storage_device,
            optimize_memory=True,
        )
        batch_size: int = batch_size // 2  # We will sample from both replay buffers

    shutdown_event = Event()

    def signal_handler(signum, frame):
        print(
            f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
        )
        shutdown_event.set()

    # Register signal handlers
    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Termination request
    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed
    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
    shutdown_event = setup_process_handlers(use_threads(cfg))

    start_learner_threads(
        cfg,
        device,
        replay_buffer,
        offline_replay_buffer,
        batch_size,
        optimizers,
        policy,
        policy_lock,
        logger,
        resume_optimization_step,
        resume_interaction_step,
        out_dir,
        shutdown_event,
    )


@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
    if not use_threads(cfg):
        import torch.multiprocessing as mp

        mp.set_start_method("spawn")

    train(
        cfg,
        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
    )

    logging.info("[LEARNER] train_cli finished")


if __name__ == "__main__":
    train_cli()

    logging.info("[LEARNER] main finished")
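train_cli switches the multiprocessing start method to "spawn" before any worker process is created; a forked child would inherit an already-initialized CUDA context and typically fail. A guarded sketch of that setup (the try/except and the assertion are additions for illustration, not the repo's code):

import torch.multiprocessing as mp


def ensure_spawn() -> None:
    # Select the "spawn" start method before creating any worker process.
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        # The start method was already set (e.g. by an outer launcher); verify it.
        assert mp.get_start_method() == "spawn", "CUDA work requires the spawn start method"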