mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-16 17:20:05 +00:00
[WIP] Non functional yet
Add ManiSkill environment configuration and wrappers

- Introduced `VideoRecordConfig` for video recording settings.
- Added `ManiskillEnvConfig` to encapsulate environment-specific configurations.
- Implemented various wrappers for the ManiSkill environment, including observation and action scaling.
- Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing.
- Updated the `actor_server` and `learner_server` to utilize the new configuration structure.
- Refactored the training pipeline to accommodate the new environment and policy configurations.
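None of the new configuration classes appear on this page, so the snippet below is only a rough sketch of the draccus-style dataclass pattern the message describes; every field name and default here is an assumption for illustration, not the commit's actual implementation.

from dataclasses import dataclass, field


@dataclass
class VideoRecordConfig:
    """Sketch only: illustrative video-recording settings (all fields assumed)."""
    enabled: bool = False
    record_dir: str = "videos"
    trajectory_name: str = "trajectory"


@dataclass
class ManiskillEnvConfig:
    """Sketch only: illustrative environment settings (all fields assumed)."""
    task: str = "PushCube-v1"          # hypothetical ManiSkill task id
    fps: int = 20                      # control frequency
    episode_length: int = 50           # max steps per episode
    image_size: int = 64               # camera observation resolution
    device: str = "cuda"
    video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)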
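Similarly, `make_maniskill` itself is not part of the diff shown below; a minimal sketch of the wrapping order the message describes (action scaling, observation processing, optional video recording), written with standard gymnasium wrappers and an assumed signature, could look like this.

import gymnasium as gym
from gymnasium.wrappers import RecordVideo, RescaleAction, TimeLimit


def make_maniskill(cfg):
    """Sketch only: cfg is the ManiskillEnvConfig sketched above; the real factory
    also handles observation processing for the policy."""
    env = gym.make(cfg.task, obs_mode="rgb", render_mode="rgb_array")  # ManiSkill kwargs assumed
    env = TimeLimit(env, max_episode_steps=cfg.episode_length)
    env = RescaleAction(env, min_action=-1.0, max_action=1.0)  # scale actions to [-1, 1]
    if cfg.video_record.enabled:
        env = RecordVideo(env, video_folder=cfg.video_record.record_dir)
    return env

The rest of the page is the diff of the learner-side training script, which drops the Hydra/OmegaConf configuration in favor of the new `TrainPipelineConfig` structure.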
@@ -19,40 +19,45 @@ import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
import os
from pathlib import Path

import draccus
import grpc

# Import generated stubs
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn

# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer

from lerobot.common.datasets.factory import make_dataset

from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs import parser
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
format_big_number,
get_global_random_state,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_random_state,
set_global_seed,
)

from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
@@ -64,102 +69,167 @@ from lerobot.scripts.server.buffer import (
state_to_bytes,
)
from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)


def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
"""
Handle the resume logic for training.

If resume is True:
- Verifies that a checkpoint exists
- Loads the checkpoint configuration
- Logs resumption details
- Returns the checkpoint configuration

If resume is False:
- Checks if an output directory exists (to prevent accidental overwriting)
- Returns the original configuration

Args:
cfg (TrainPipelineConfig): The training configuration

Returns:
TrainPipelineConfig: The updated configuration

Raises:
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
"""
out_dir = cfg.output_dir

# Case 1: Not resuming, but need to check if directory exists to prevent overwrites
if not cfg.resume:
if Logger.get_last_checkpoint_dir(out_dir).exists():
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if os.path.exists(checkpoint_dir):
raise RuntimeError(
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
f"Output directory {checkpoint_dir} already exists. "
"Use `resume=true` to resume training."
)
return cfg

# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
# Case 2: Resuming training
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if not os.path.exists(checkpoint_dir):
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")

checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
# Log that we found a valid checkpoint and are resuming
logging.info(
colored(
"Resume=True detected, resuming previous run",
"Valid checkpoint found: resume=True detected, resuming previous run",
color="yellow",
attrs=["bold"],
)
)

checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))

if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]

if len(diff) > 0:
logging.warning(
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
"Checkpoint configuration takes precedence."
)

# Load config using Draccus
checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)

# Ensure resume flag is set in returned config
checkpoint_cfg.resume = True
return checkpoint_cfg

def load_training_state(
cfg: DictConfig,
logger: Logger,
optimizers: Optimizer | dict,
cfg: TrainPipelineConfig,
optimizers: Optimizer | dict[str, Optimizer],
):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.

Args:
cfg (TrainPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into

Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
"""
if not cfg.resume:
return None, None

training_state = torch.load(
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
)

if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])

set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], training_state["interaction_step"]
# Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)

logging.info(f"Loading training state from {checkpoint_dir}")

try:
# Use the utility function from train_utils which loads the optimizer state
# The function returns (step, updated_optimizer, scheduler)
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)

# For interaction step, we still need to load the training_state.pt file
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
training_state = torch.load(training_state_path, weights_only=False)
interaction_step = training_state.get("interaction_step", 0)

logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step

except Exception as e:
logging.error(f"Failed to load training state: {e}")
return None, None

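For orientation, both resume helpers above build their paths from the constants imported at the top of the file; the lines below are a small sketch of those joins, with the concrete directory names ("checkpoints", "last", "pretrained_model", "training_state") assumed from lerobot's usual constants rather than taken from this diff.

import os

# Assumed values; the real ones come from lerobot.common.constants.
CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
PRETRAINED_MODEL_DIR = "pretrained_model"
TRAINING_STATE_DIR = "training_state"

output_dir = "outputs/train/example_run"  # hypothetical cfg.output_dir
last_checkpoint = os.path.join(output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
# handle_resume_logic reads the resumed config from here:
train_config_path = os.path.join(last_checkpoint, PRETRAINED_MODEL_DIR, "train_config.json")
# load_training_state reads the optimizer and step state from here:
training_state_path = os.path.join(last_checkpoint, TRAINING_STATE_DIR, "training_state.pt")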
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
"""
Log information about the training process.

Args:
cfg (TrainPipelineConfig): Training configuration
policy (nn.Module): Policy model
"""
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())

log_output_dir(out_dir)

logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{cfg.policy.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str, storage_device: str
cfg: TrainPipelineConfig,
device: str,
storage_device: str
) -> ReplayBuffer:
"""
Initialize a replay buffer, either empty or from a dataset if resuming.

Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization

Returns:
ReplayBuffer: Initialized replay buffer
"""
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
capacity=cfg.policy.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
state_keys=cfg.policy.input_features.keys(),
storage_device=storage_device,
optimize_memory=True,
)

logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset")
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
repo_id=cfg.dataset.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset",
root=dataset_path,
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity,
capacity=cfg.policy.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
optimize_memory=True,
@@ -167,33 +237,45 @@ def initialize_replay_buffer(

def initialize_offline_replay_buffer(
cfg: DictConfig,
logger: Logger,
cfg: TrainPipelineConfig,
device: str,
storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.

Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking

Returns:
ReplayBuffer: Initialized offline replay buffer
"""
if not cfg.resume:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
if cfg.resume:
else:
logging.info("load offline dataset")
dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
offline_dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
repo_id=cfg.dataset.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset_offline",
root=dataset_offline_path,
)

logging.info("Convert to a offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.training.offline_buffer_capacity,
capacity=cfg.policy.offline_buffer_capacity,
)
return offline_replay_buffer

@@ -215,16 +297,23 @@ def get_observation_features(
return observation_features, next_observation_features

def use_threads(cfg: DictConfig) -> bool:
return cfg.actor_learner_config.concurrency.learner == "threads"
def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.policy.concurrency["learner"] == "threads"


def start_learner_threads(
cfg: DictConfig,
logger: Logger,
out_dir: str,
cfg: TrainPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
) -> None:
"""
Start the learner threads for training.

Args:
cfg (TrainPipelineConfig): Training configuration
wandb_logger (WandBLogger | None): Logger for metrics
shutdown_event: Event to signal shutdown
"""
# Create multiprocessing queues
transition_queue = Queue()
interaction_message_queue = Queue()
@@ -255,13 +344,12 @@ def start_learner_threads(
communication_process.start()

add_actor_information_and_train(
cfg,
logger,
out_dir,
shutdown_event,
transition_queue,
interaction_message_queue,
parameters_queue,
cfg=cfg,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
transition_queue=transition_queue,
interaction_message_queue=interaction_message_queue,
parameters_queue=parameters_queue,
)
logging.info("[LEARNER] Training process stopped")

@@ -286,7 +374,7 @@ def start_learner_server(
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
cfg: DictConfig,
cfg: TrainPipelineConfig,
):
if not use_threads(cfg):
# We need init logging for MP separataly
@@ -298,11 +386,11 @@ def start_learner_server(
setup_process_handlers(False)

service = learner_service.LearnerService(
shutdown_event,
parameters_queue,
cfg.actor_learner_config.policy_parameters_push_frequency,
transition_queue,
interaction_message_queue,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
transition_queue=transition_queue,
interaction_message_queue=interaction_message_queue,
)

server = grpc.server(
@@ -318,8 +406,8 @@ def start_learner_server(
server,
)

host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.learner_port
host = cfg.policy.actor_learner_config["learner_host"]
port = cfg.policy.actor_learner_config["learner_port"]

server.add_insecure_port(f"{host}:{port}")
server.start()
@@ -385,9 +473,8 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):


def add_actor_information_and_train(
cfg,
logger: Logger,
out_dir: str,
cfg: TrainPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
transition_queue: Queue,
interaction_message_queue: Queue,
@@ -405,69 +492,60 @@ def add_actor_information_and_train(
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.

**NOTE:**
- This function performs multiple responsibilities (data transfer, training, and logging).
It should ideally be split into smaller functions in the future.
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
significantly reduces performance. Instead, this function executes all operations in a single thread.

Args:
cfg: Configuration object containing hyperparameters.
device (str): The computing device (`"cpu"` or `"cuda"`).
logger (Logger): Logger instance for tracking training progress.
out_dir (str): The output directory for storing training checkpoints and logs.
cfg (TrainPipelineConfig): Configuration object containing hyperparameters.
wandb_logger (WandBLogger | None): Logger for tracking training progress.
shutdown_event (Event): Event to signal shutdown.
transition_queue (Queue): Queue for receiving transitions from the actor.
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor.
"""

device = get_safe_torch_device(cfg.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)

logging.info("Initializing policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy intance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy

# Get checkpoint dir for resuming
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None

# TODO(Adil): This don't work anymore !
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
cfg=cfg.policy,
# ds_meta=cfg.dataset,
env_cfg=cfg.env
)

# Update the policy config with the grad_clip_norm value from training config if it exists
clip_grad_norm_value = cfg.training.grad_clip_norm
clip_grad_norm_value:float = cfg.policy.grad_clip_norm

# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
policy.train()

push_actor_policy_to_queue(parameters_queue, policy)
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)

last_time_policy_pushed = time.time()

optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)

log_training_info(cfg, out_dir, policy)
log_training_info(cfg=cfg, policy= policy)

replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
batch_size = cfg.training.batch_size
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
batch_size = cfg.batch_size
offline_replay_buffer = None

if cfg.dataset_repo_id is not None:
if cfg.dataset is not None:
active_action_dims = None
# TODO: FIX THIS
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
logger=logger,
device=device,
storage_device=storage_device,
active_action_dims=active_action_dims,
@@ -484,18 +562,22 @@ def add_actor_information_and_train(
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0

# Extract variables from cfg
online_step_before_learning = cfg.training.online_step_before_learning
online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
dataset_repo_id = cfg.dataset_repo_id
fps = cfg.fps
log_freq = cfg.training.log_freq
save_freq = cfg.training.save_freq
device = cfg.device
storage_device = cfg.training.storage_device
policy_update_freq = cfg.training.policy_update_freq
policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
save_checkpoint = cfg.training.save_checkpoint
online_steps = cfg.training.online_steps

dataset_repo_id = None
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id

fps = cfg.env.fps
log_freq = cfg.log_freq
save_freq = cfg.save_freq
device = cfg.policy.device
storage_device = cfg.policy.storage_device
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
save_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps

while True:
if shutdown_event is not None and shutdown_event.is_set():
@@ -516,7 +598,7 @@ def add_actor_information_and_train(
continue
replay_buffer.add(**transition)

if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention"
):
offline_replay_buffer.add(**transition)
@@ -528,7 +610,17 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")

# Log interaction messages with WandB if available
if wandb_logger:
wandb_logger.log_dict(
d=interaction_message,
mode="train",
custom_step_key="Interaction step"
)
else:
# Log to console if no WandB logger
logging.info(f"Interaction: {interaction_message}")

logging.debug("[LEARNER] Received interactions")

@@ -538,11 +630,11 @@ def add_actor_information_and_train(
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
batch = replay_buffer.sample(batch_size=batch_size)

if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)

actions = batch["action"]
rewards = batch["reward"]
@@ -552,7 +644,7 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -568,15 +660,15 @@ def add_actor_information_and_train(

# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)

optimizers["critic"].step()

batch = replay_buffer.sample(batch_size)
batch = replay_buffer.sample(batch_size=batch_size)

if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
@@ -590,7 +682,7 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -606,7 +698,7 @@ def add_actor_information_and_train(

# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()

optimizers["critic"].step()
@@ -627,7 +719,7 @@ def add_actor_information_and_train(

# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.actor.parameters_to_optimize, clip_grad_norm_value
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()

optimizers["actor"].step()
@@ -645,7 +737,7 @@ def add_actor_information_and_train(

# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
[policy.log_alpha], clip_grad_norm_value
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()

optimizers["temperature"].step()
@@ -655,7 +747,7 @@ def add_actor_information_and_train(
training_infos["temperature"] = policy.temperature

if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue, policy)
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()

policy.update_target_networks()
@@ -665,22 +757,33 @@ def add_actor_information_and_train(
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
# logging.info(f"Training infos: {training_infos}")

# Log training metrics
if wandb_logger:
wandb_logger.log_dict(
d=training_infos,
mode="train",
custom_step_key="Optimization step"
)
else:
# Log to console if no WandB logger
logging.info(f"Training: {training_infos}")

time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

logger.log_dict(
{
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
"Optimization step": optimization_step,
},
mode="train",
custom_step_key="Optimization step",
)
# Log optimization frequency
if wandb_logger:
wandb_logger.log_dict(
{
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
"Optimization step": optimization_step,
},
mode="train",
custom_step_key="Optimization step",
)

optimization_step += 1
if optimization_step % log_freq == 0:
@@ -693,35 +796,45 @@ def add_actor_information_and_train(
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
)
logger.save_checkpoint(
optimization_step,
policy,

# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)

# Save checkpoint
save_checkpoint(
checkpoint_dir,
optimization_step,
cfg,
policy,
optimizers,
scheduler=None,
identifier=step_identifier,
interaction_step=interaction_step,
scheduler=None
)

# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)

# TODO : temporarly save replay buffer here, remove later when on the robot
# We want to control this with the keyboard inputs
dataset_dir = logger.log_dir / "dataset"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")
dataset_dir = os.path.join(cfg.output_dir, "dataset")
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
shutil.rmtree(dataset_dir)

# Save dataset
replay_buffer.to_lerobot_dataset(
dataset_repo_id,
fps=fps,
root=dataset_dir
)

if offline_replay_buffer is not None:
dataset_dir = logger.log_dir / "dataset_offline"

if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
shutil.rmtree(dataset_offline_dir)

offline_replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id,
fps=cfg.fps,
root=logger.log_dir / "dataset_offline",
cfg.dataset.dataset_repo_id,
fps=cfg.env.fps,
root=dataset_offline_dir,
)

logging.info("Resume training")
@@ -756,12 +869,12 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize,
lr=policy.config.actor_lr,
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@@ -771,19 +884,38 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
return optimizers, lr_scheduler

def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
"""
Main training function that initializes and runs the training process.

Args:
cfg (TrainPipelineConfig): The training configuration
job_name (str | None, optional): Job name for logging. Defaults to None.
"""
if cfg.output_dir is None:
raise ValueError("Output directory must be specified in config")

if job_name is None:
raise NotImplementedError()
job_name = cfg.job_name

if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter")

init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
logging.info(pformat(cfg.to_dict()))

logger = Logger(cfg, out_dir, wandb_job_name=job_name)
cfg = handle_resume_logic(cfg, out_dir)
# Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project:
from lerobot.common.utils.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

# Handle resume logic
cfg = handle_resume_logic(cfg)

set_global_seed(cfg.seed)
set_seed(seed=cfg.seed)

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -791,24 +923,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
shutdown_event = setup_process_handlers(use_threads(cfg))

start_learner_threads(
cfg,
logger,
out_dir,
shutdown_event,
cfg=cfg,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
)


@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
@parser.wrap()
def train_cli(cfg: TrainPipelineConfig):

if not use_threads(cfg):
import torch.multiprocessing as mp

mp.set_start_method("spawn")

# Use the job_name from the config
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
job_name=cfg.job_name,
)

logging.info("[LEARNER] train_cli finished")
@@ -816,5 +947,4 @@ def train_cli(cfg: dict):

if __name__ == "__main__":
train_cli()

logging.info("[LEARNER] main finished")