Change HILSerlRobotEnvConfig to inherit from EnvConfig
Added support for the hil_serl classifier to be trained with train.py. Run classifier training with: python lerobot/scripts/train.py --policy.type=hilserl_classifier. Also includes fixes in find_joint_limits, control_robot, and end_effector_control_utils.
Committed by: AdilZouitine
Parent: 052a4acfc2
Commit: d0b7690bc0
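The headline change hooks the HIL-SERL robot environment config into the same draccus-based config registry as the built-in envs, so it can be selected and parsed from the CLI. A minimal sketch of the pattern, assuming EnvConfig is the draccus ChoiceRegistry base from lerobot.common.envs.configs; the registration name, fields, and defaults below are illustrative, not the ones from this commit:

    from dataclasses import dataclass

    from lerobot.common.envs.configs import EnvConfig


    @EnvConfig.register_subclass("hil")  # illustrative registration name
    @dataclass
    class HILSerlRobotEnvConfig(EnvConfig):
        # Illustrative fields; the real config carries the robot/env wiring.
        task: str = "PandaPickCube-v0"
        fps: int = 10

        @property
        def gym_kwargs(self) -> dict:
            # EnvConfig's abstract hook; kwargs forwarded to env construction.
            return {}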
@@ -15,12 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
 import shutil
 import time
 from concurrent.futures import ThreadPoolExecutor
-from pprint import pformat
-import os
 from pathlib import Path
+from pprint import pformat
 
 import draccus
 import grpc
@@ -30,35 +30,42 @@ import hilserl_pb2_grpc  # type: ignore
 import torch
 from termcolor import colored
 from torch import nn
-
 from torch.multiprocessing import Queue
 from torch.optim.optimizer import Optimizer
 
 from lerobot.common.constants import (
     CHECKPOINTS_DIR,
     LAST_CHECKPOINT_LINK,
     PRETRAINED_MODEL_DIR,
     TRAINING_STATE_DIR,
     TRAINING_STEP,
 )
 from lerobot.common.datasets.factory import make_dataset
 
-from lerobot.configs.train import TrainPipelineConfig
-from lerobot.configs import parser
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
+from lerobot.common.policies.sac.modeling_sac import SACConfig, SACPolicy
+from lerobot.common.policies.utils import get_device_from_parameters
+from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.train_utils import (
     get_step_checkpoint_dir,
     get_step_identifier,
+    load_training_state as utils_load_training_state,
     save_checkpoint,
-    update_last_checkpoint,
+    save_training_state,
+    update_last_checkpoint,
 )
-from lerobot.common.utils.train_utils import (
-    load_training_state as utils_load_training_state,
-)
-from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_logging,
 )
 
-from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.wandb_utils import WandBLogger
+from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
 from lerobot.scripts.server import learner_service
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
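One detail in the consolidated import block: train_utils.load_training_state is imported under an alias because this module defines its own load_training_state wrapper further down. A minimal sketch of the shadowing problem the alias avoids (paths are illustrative):

    from pathlib import Path

    from lerobot.common.utils.train_utils import (
        load_training_state as utils_load_training_state,
    )


    def load_training_state(cfg, optimizers):
        # Delegates to the aliased helper; without the alias, this name
        # would resolve to the local function and recurse.
        checkpoint_dir = Path(cfg.output_dir) / "checkpoints" / "last"
        return utils_load_training_state(checkpoint_dir, optimizers, None)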
@@ -70,47 +77,39 @@ from lerobot.scripts.server.buffer import (
     state_to_bytes,
 )
 from lerobot.scripts.server.utils import setup_process_handlers
-from lerobot.common.constants import (
-    CHECKPOINTS_DIR,
-    LAST_CHECKPOINT_LINK,
-    PRETRAINED_MODEL_DIR,
-    TRAINING_STATE_DIR,
-    TRAINING_STEP,
-)
 
 
 def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
     """
     Handle the resume logic for training.
 
     If resume is True:
     - Verifies that a checkpoint exists
     - Loads the checkpoint configuration
     - Logs resumption details
     - Returns the checkpoint configuration
 
     If resume is False:
     - Checks if an output directory exists (to prevent accidental overwriting)
     - Returns the original configuration
 
     Args:
         cfg (TrainPipelineConfig): The training configuration
 
     Returns:
         TrainPipelineConfig: The updated configuration
 
     Raises:
         RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
     """
     out_dir = cfg.output_dir
 
     # Case 1: Not resuming, but need to check if directory exists to prevent overwrites
     if not cfg.resume:
         checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
         if os.path.exists(checkpoint_dir):
             raise RuntimeError(
-                f"Output directory {checkpoint_dir} already exists. "
-                "Use `resume=true` to resume training."
+                f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training."
             )
     return cfg
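Both the overwrite check above and the loaders below walk a fixed checkpoint layout. Reconstructed from the constants and paths visible in this diff; the run directory and step identifier are illustrative:

    # Checkpoint layout assumed by handle_resume_logic / load_training_state.
    # Directory names in comments come from lerobot.common.constants.
    #
    # <output_dir>/
    # └── checkpoints/                     # CHECKPOINTS_DIR
    #     ├── 0005000/                     # one dir per get_step_checkpoint_dir(...)
    #     │   ├── pretrained_model/        # PRETRAINED_MODEL_DIR
    #     │   │   └── train_config.json    # TrainPipelineConfig.from_pretrained(...)
    #     │   └── training_state/          # TRAINING_STATE_DIR
    #     │       └── training_state.pt    # {"step": ..., "interaction_step": ...}
    #     └── last -> 0005000/             # LAST_CHECKPOINT_LINK, set by update_last_checkpoint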
@@ -131,7 +130,7 @@ def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
     # Load config using Draccus
     checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
     checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
 
     # Ensure resume flag is set in returned config
     checkpoint_cfg.resume = True
     return checkpoint_cfg
@@ -143,11 +142,11 @@ def load_training_state(
 ):
     """
     Loads the training state (optimizers, step count, etc.) from a checkpoint.
 
     Args:
         cfg (TrainPipelineConfig): Training configuration
         optimizers (Optimizer | dict): Optimizers to load state into
 
     Returns:
         tuple: (optimization_step, interaction_step) or (None, None) if not resuming
     """
@@ -156,23 +155,23 @@ def load_training_state(
 
     # Construct path to the last checkpoint directory
     checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
 
     logging.info(f"Loading training state from {checkpoint_dir}")
 
     try:
         # Use the utility function from train_utils which loads the optimizer state
         step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
 
         # Load interaction step separately from training_state.pt
         training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
         interaction_step = 0
         if os.path.exists(training_state_path):
             training_state = torch.load(training_state_path, weights_only=False)
             interaction_step = training_state.get("interaction_step", 0)
 
         logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
         return step, interaction_step
 
     except Exception as e:
         logging.error(f"Failed to load training state: {e}")
         return None, None
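Note how the interaction step travels outside the optimizer state: utils_load_training_state restores optimizers and the optimization step, while the actor-side counter rides in training_state.pt. A round-trip sketch of that file's layout, matching the save call later in this diff (numbers illustrative):

    import torch

    # Save side, mirroring the checkpointing code later in this diff.
    training_state = {"step": 5000, "interaction_step": 12000}
    torch.save(training_state, "training_state.pt")

    # Load side, defaulting to 0 when the key is absent.
    loaded = torch.load("training_state.pt", weights_only=False)
    interaction_step = loaded.get("interaction_step", 0)
    assert interaction_step == 12000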
@@ -181,7 +180,7 @@ def load_training_state(
 def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
     """
     Log information about the training process.
 
     Args:
         cfg (TrainPipelineConfig): Training configuration
         policy (nn.Module): Policy model
@@ -189,7 +188,6 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-
     logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.policy.online_steps=}")
@@ -197,19 +195,15 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
 
-def initialize_replay_buffer(
-    cfg: TrainPipelineConfig,
-    device: str,
-    storage_device: str
-) -> ReplayBuffer:
+def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
     """
     Initialize a replay buffer, either empty or from a dataset if resuming.
 
     Args:
         cfg (TrainPipelineConfig): Training configuration
         device (str): Device to store tensors on
         storage_device (str): Device for storage optimization
 
     Returns:
         ReplayBuffer: Initialized replay buffer
     """
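The device/storage_device split lets transitions live on a cheap device while sampled batches land on the compute device. A sketch of the idea only, not the ReplayBuffer API:

    import torch

    storage_device = torch.device("cpu")  # where the buffer keeps tensors
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    stored = torch.randn(1000, 7, device=storage_device)  # stand-in for stored transitions
    idx = torch.randint(0, stored.shape[0], (256,))
    batch = stored[idx].to(device)  # sampling moves only the batch to the training device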
@@ -224,7 +218,7 @@ def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
 
         logging.info("Resume training load the online dataset")
         dataset_path = os.path.join(cfg.output_dir, "dataset")
 
         # NOTE: In RL is possible to not have a dataset.
         repo_id = None
         if cfg.dataset is not None:
@@ -250,13 +244,13 @@ def initialize_offline_replay_buffer(
 ) -> ReplayBuffer:
     """
     Initialize an offline replay buffer from a dataset.
 
     Args:
         cfg (TrainPipelineConfig): Training configuration
         device (str): Device to store tensors on
         storage_device (str): Device for storage optimization
         active_action_dims (list[int] | None): Active action dimensions for masking
 
     Returns:
         ReplayBuffer: Initialized offline replay buffer
     """
@@ -314,7 +308,7 @@ def start_learner_threads(
 ) -> None:
     """
     Start the learner threads for training.
 
     Args:
         cfg (TrainPipelineConfig): Training configuration
         wandb_logger (WandBLogger | None): Logger for metrics
@@ -512,17 +506,19 @@ def add_actor_information_and_train(
 
     logging.info("Initializing policy")
     # Get checkpoint dir for resuming
-    checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
+    checkpoint_dir = (
+        os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
+    )
     pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
 
     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
         # ds_meta=cfg.dataset,
-        env_cfg=cfg.env
+        env_cfg=cfg.env,
     )
 
     # Update the policy config with the grad_clip_norm value from training config if it exists
-    clip_grad_norm_value:float = cfg.policy.grad_clip_norm
+    clip_grad_norm_value: float = cfg.policy.grad_clip_norm
 
     # compile policy
     policy = torch.compile(policy)
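clip_grad_norm_value is read from the policy config here and applied during the update steps. A generic sketch of how such a value is consumed, using the standard PyTorch utility rather than this file's exact update code:

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    # clip_grad_norm_value would be passed as max_norm; 10.0 is illustrative.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
    optimizer.step()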
@@ -536,7 +532,7 @@ def add_actor_information_and_train(
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
     resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
 
-    log_training_info(cfg=cfg, policy= policy)
+    log_training_info(cfg=cfg, policy=policy)
 
     replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
     batch_size = cfg.batch_size
@@ -615,14 +611,10 @@ def add_actor_information_and_train(
             interaction_message = bytes_to_python_object(interaction_message)
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
             interaction_message["Interaction step"] += interaction_step_shift
 
 
             # Log interaction messages with WandB if available
             if wandb_logger:
-                wandb_logger.log_dict(
-                    d=interaction_message,
-                    mode="train",
-                    custom_step_key="Interaction step"
-                )
+                wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step")
 
             logging.debug("[LEARNER] Received interactions")
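The shift exists purely for logging continuity: after a resume the actor restarts its local counter at zero, so the last checkpointed value is added back before the message reaches WandB. With illustrative numbers:

    interaction_step_shift = 12000  # restored from the checkpoint on resume
    interaction_message = {"Interaction step": 37}  # actor counts from 0 after restart

    interaction_message["Interaction step"] += interaction_step_shift
    assert interaction_message["Interaction step"] == 12037  # x-axis stays monotonic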
@@ -636,7 +628,9 @@ def add_actor_information_and_train(
 
             if dataset_repo_id is not None:
                 batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
-                batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
+                batch = concatenate_batch_transitions(
+                    left_batch_transitions=batch, right_batch_transition=batch_offline
+                )
 
             actions = batch["action"]
             rewards = batch["reward"]
@@ -759,14 +753,10 @@ def add_actor_information_and_train(
             if offline_replay_buffer is not None:
                 training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
             training_infos["Optimization step"] = optimization_step
 
             # Log training metrics
             if wandb_logger:
-                wandb_logger.log_dict(
-                    d=training_infos,
-                    mode="train",
-                    custom_step_key="Optimization step"
-                )
+                wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
 
         time_for_one_optimization_step = time.time() - time_for_one_optimization_step
         frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
@@ -795,29 +785,19 @@ def add_actor_information_and_train(
             interaction_step = (
                 interaction_message["Interaction step"] if interaction_message is not None else 0
             )
 
 
             # Create checkpoint directory
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
 
             # Save checkpoint
-            save_checkpoint(
-                checkpoint_dir,
-                optimization_step,
-                cfg,
-                policy,
-                optimizers,
-                scheduler=None
-            )
+            save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
 
             # Save interaction step manually
             training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
             os.makedirs(training_state_dir, exist_ok=True)
-            training_state = {
-                "step": optimization_step,
-                "interaction_step": interaction_step
-            }
+            training_state = {"step": optimization_step, "interaction_step": interaction_step}
             torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
 
             # Update the "last" symlink
             update_last_checkpoint(checkpoint_dir)
@@ -826,17 +806,13 @@ def add_actor_information_and_train(
             dataset_dir = os.path.join(cfg.output_dir, "dataset")
             if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
                 shutil.rmtree(dataset_dir)
 
             # Save dataset
             # NOTE: Handle the case where the dataset repo id is not specified in the config
             # eg. RL training without demonstrations data
             repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
-            replay_buffer.to_lerobot_dataset(
-                repo_id=repo_id_buffer_save,
-                fps=fps,
-                root=dataset_dir
-            )
+            replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
 
             if offline_replay_buffer is not None:
                 dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
                 if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
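The repo id fallback keeps dataset export working for pure RL runs that never configured demonstration data: the env task name stands in as the identifier. Illustratively:

    dataset_repo_id = None  # no demonstrations were configured
    task = "PandaPickCubeKeyboard-v0"  # illustrative cfg.env.task

    repo_id_buffer_save = task if dataset_repo_id is None else dataset_repo_id
    assert repo_id_buffer_save == "PandaPickCubeKeyboard-v0"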
@@ -882,9 +858,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(
-        params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
-    )
+    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
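The hunk cuts off at optimizers = {; given the three Adam instances above, the dict presumably keys each optimizer by component so the update loop and save_checkpoint can address them individually. A self-contained sketch with assumed key names (the diff does not show them):

    import torch
    from torch import nn

    # Stand-ins for policy.actor, policy.critic_ensemble, policy.log_alpha.
    actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
    log_alpha = torch.zeros(1, requires_grad=True)

    optimizers = {
        "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
        "critic": torch.optim.Adam(critic.parameters(), lr=3e-4),
        "temperature": torch.optim.Adam([log_alpha], lr=3e-4),
    }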
@@ -898,19 +872,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
 def train(cfg: TrainPipelineConfig, job_name: str | None = None):
     """
     Main training function that initializes and runs the training process.
 
     Args:
         cfg (TrainPipelineConfig): The training configuration
         job_name (str | None, optional): Job name for logging. Defaults to None.
     """
 
     cfg.validate()
     # if cfg.output_dir is None:
     #     raise ValueError("Output directory must be specified in config")
 
     if job_name is None:
         job_name = cfg.job_name
 
     if job_name is None:
         raise ValueError("Job name must be specified either in config or as a parameter")
@@ -920,11 +894,12 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
     # Setup WandB logging if enabled
     if cfg.wandb.enable and cfg.wandb.project:
+        from lerobot.common.utils.wandb_utils import WandBLogger
+
         wandb_logger = WandBLogger(cfg)
     else:
         wandb_logger = None
         logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
 
     # Handle resume logic
     cfg = handle_resume_logic(cfg)
@@ -944,9 +919,9 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
 
 @parser.wrap()
 def train_cli(cfg: TrainPipelineConfig):
     if not use_threads(cfg):
         import torch.multiprocessing as mp
 
         mp.set_start_method("spawn")
 
     # Use the job_name from the config
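The spawn start method is needed here because the learner's subprocesses touch CUDA and gRPC state that typically does not survive a fork; spawn gives each child a fresh interpreter. A minimal sketch of the pattern:

    import torch.multiprocessing as mp


    def worker(rank: int) -> None:
        # Fresh interpreter per worker, so CUDA can be initialized safely here.
        print(f"worker {rank} started")


    if __name__ == "__main__":
        mp.set_start_method("spawn")
        mp.spawn(worker, nprocs=2)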