Change HILSerlRobotEnvConfig to inherit from EnvConfig

Added support for the hil_serl classifier to be trained with train.py.
Run classifier training with: python lerobot/scripts/train.py --policy.type=hilserl_classifier
Fixes in find_joint_limits, control_robot and end_effector_control_utils.
Michel Aractingi
2025-03-27 10:23:14 +01:00
committed by AdilZouitine
parent 052a4acfc2
commit d0b7690bc0
13 changed files with 389 additions and 340 deletions
+73 -98
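# Note on the new CLI entry point: a minimal sketch, assuming lerobot's usual
# draccus-based config registry, of how `--policy.type=hilserl_classifier` is
# resolved by train.py. The class and field names below are illustrative
# assumptions, not code from this commit.
from dataclasses import dataclass

import draccus


@dataclass
class PolicyConfig(draccus.ChoiceRegistry):
    """Hypothetical stand-in for the base policy config registry."""


@PolicyConfig.register_subclass("hilserl_classifier")
@dataclass
class HILSerlClassifierConfig(PolicyConfig):
    # Illustrative field only; the real classifier config defines its own options.
    learning_rate: float = 1e-4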
@@ -15,12 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import os
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
-from pprint import pformat
-import os
from pathlib import Path
+from pprint import pformat
import draccus
import grpc
@@ -30,35 +30,42 @@ import hilserl_pb2_grpc # type: ignore
import torch
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
+from lerobot.common.constants import (
+    CHECKPOINTS_DIR,
+    LAST_CHECKPOINT_LINK,
+    PRETRAINED_MODEL_DIR,
+    TRAINING_STATE_DIR,
+    TRAINING_STEP,
+)
from lerobot.common.datasets.factory import make_dataset
-from lerobot.configs.train import TrainPipelineConfig
-from lerobot.configs import parser
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
+from lerobot.common.policies.sac.modeling_sac import SACConfig, SACPolicy
+from lerobot.common.policies.utils import get_device_from_parameters
+from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
    get_step_checkpoint_dir,
    get_step_identifier,
+    load_training_state as utils_load_training_state,
    save_checkpoint,
-    update_last_checkpoint,
    save_training_state,
+    update_last_checkpoint,
)
-from lerobot.common.utils.train_utils import (
-    load_training_state as utils_load_training_state,
-)
-from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_logging,
)
-from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.wandb_utils import WandBLogger
+from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
@@ -70,47 +77,39 @@ from lerobot.scripts.server.buffer import (
state_to_bytes,
)
from lerobot.scripts.server.utils import setup_process_handlers
-from lerobot.common.constants import (
-    CHECKPOINTS_DIR,
-    LAST_CHECKPOINT_LINK,
-    PRETRAINED_MODEL_DIR,
-    TRAINING_STATE_DIR,
-    TRAINING_STEP,
-)
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
"""
Handle the resume logic for training.
If resume is True:
- Verifies that a checkpoint exists
- Loads the checkpoint configuration
- Logs resumption details
- Returns the checkpoint configuration
If resume is False:
- Checks if an output directory exists (to prevent accidental overwriting)
- Returns the original configuration
Args:
cfg (TrainPipelineConfig): The training configuration
Returns:
TrainPipelineConfig: The updated configuration
Raises:
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
"""
out_dir = cfg.output_dir
# Case 1: Not resuming, but need to check if directory exists to prevent overwrites
if not cfg.resume:
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if os.path.exists(checkpoint_dir):
raise RuntimeError(
f"Output directory {checkpoint_dir} already exists. "
"Use `resume=true` to resume training."
f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training."
)
return cfg
@@ -131,7 +130,7 @@ def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
# Load config using Draccus
checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
# Ensure resume flag is set in returned config
checkpoint_cfg.resume = True
return checkpoint_cfg
@@ -143,11 +142,11 @@ def load_training_state(
):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.
Args:
cfg (TrainPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into
Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
"""
@@ -156,23 +155,23 @@ def load_training_state(
# Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
logging.info(f"Loading training state from {checkpoint_dir}")
try:
# Use the utility function from train_utils which loads the optimizer state
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
interaction_step = 0
if os.path.exists(training_state_path):
training_state = torch.load(training_state_path, weights_only=False)
interaction_step = training_state.get("interaction_step", 0)
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
except Exception as e:
logging.error(f"Failed to load training state: {e}")
return None, None
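# Aside: a minimal, self-contained sketch of the training_state.pt round trip
# that load_training_state (above) and the checkpoint-saving code (further
# down) implement; directory paths here are illustrative.
import os

import torch


def save_training_counters(training_state_dir: str, optimization_step: int, interaction_step: int) -> None:
    os.makedirs(training_state_dir, exist_ok=True)
    # Same dict shape the checkpointing code writes.
    state = {"step": optimization_step, "interaction_step": interaction_step}
    torch.save(state, os.path.join(training_state_dir, "training_state.pt"))


def load_interaction_step(training_state_dir: str) -> int:
    path = os.path.join(training_state_dir, "training_state.pt")
    if not os.path.exists(path):
        return 0  # same default the loader above falls back to
    state = torch.load(path, weights_only=False)
    return state.get("interaction_step", 0)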
@@ -181,7 +180,7 @@ def load_training_state(
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
"""
Log information about the training process.
Args:
cfg (TrainPipelineConfig): Training configuration
policy (nn.Module): Policy model
@@ -189,7 +188,6 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.policy.online_steps=}")
@@ -197,19 +195,15 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(
cfg: TrainPipelineConfig,
device: str,
storage_device: str
) -> ReplayBuffer:
def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
"""
Initialize a replay buffer, either empty or from a dataset if resuming.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
Returns:
ReplayBuffer: Initialized replay buffer
"""
@@ -224,7 +218,7 @@ def initialize_replay_buffer(
logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset")
# NOTE: In RL is possible to not have a dataset.
repo_id = None
if cfg.dataset is not None:
@@ -250,13 +244,13 @@ def initialize_offline_replay_buffer(
) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking
Returns:
ReplayBuffer: Initialized offline replay buffer
"""
@@ -314,7 +308,7 @@ def start_learner_threads(
) -> None:
"""
Start the learner threads for training.
Args:
cfg (TrainPipelineConfig): Training configuration
wandb_logger (WandBLogger | None): Logger for metrics
@@ -512,17 +506,19 @@ def add_actor_information_and_train(
logging.info("Initializing policy")
# Get checkpoint dir for resuming
-checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
+checkpoint_dir = (
+    os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
+)
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# ds_meta=cfg.dataset,
-env_cfg=cfg.env
+env_cfg=cfg.env,
)
# Update the policy config with the grad_clip_norm value from training config if it exists
-clip_grad_norm_value:float = cfg.policy.grad_clip_norm
+clip_grad_norm_value: float = cfg.policy.grad_clip_norm
# compile policy
policy = torch.compile(policy)
@@ -536,7 +532,7 @@ def add_actor_information_and_train(
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
-log_training_info(cfg=cfg, policy= policy)
+log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
batch_size = cfg.batch_size
@@ -615,14 +611,10 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
# Log interaction messages with WandB if available
if wandb_logger:
-wandb_logger.log_dict(
-    d=interaction_message,
-    mode="train",
-    custom_step_key="Interaction step"
-)
+wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step")
logging.debug("[LEARNER] Received interactions")
@@ -636,7 +628,9 @@ def add_actor_information_and_train(
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
-batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
+batch = concatenate_batch_transitions(
+    left_batch_transitions=batch, right_batch_transition=batch_offline
+)
actions = batch["action"]
rewards = batch["reward"]
@@ -759,14 +753,10 @@ def add_actor_information_and_train(
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step
# Log training metrics
if wandb_logger:
-wandb_logger.log_dict(
-    d=training_infos,
-    mode="train",
-    custom_step_key="Optimization step"
-)
+wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
@@ -795,29 +785,19 @@ def add_actor_information_and_train(
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
)
# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint
-save_checkpoint(
-    checkpoint_dir,
-    optimization_step,
-    cfg,
-    policy,
-    optimizers,
-    scheduler=None
-)
+save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
# Save interaction step manually
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True)
-training_state = {
-    "step": optimization_step,
-    "interaction_step": interaction_step
-}
+training_state = {"step": optimization_step, "interaction_step": interaction_step}
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
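# For reference, the checkpoint layout implied by the constants used above
# (directory names are those the paths in this file are built from; the exact
# step-directory naming comes from get_step_checkpoint_dir and is assumed here):
#
#   <output_dir>/checkpoints/
#       <step_id>/
#           pretrained_model/      # PRETRAINED_MODEL_DIR: model + train_config.json
#           training_state/        # TRAINING_STATE_DIR: training_state.pt, optimizer state
#       last -> <step_id>/         # LAST_CHECKPOINT_LINK, updated by update_last_checkpoint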
@@ -826,17 +806,13 @@ def add_actor_information_and_train(
dataset_dir = os.path.join(cfg.output_dir, "dataset")
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
shutil.rmtree(dataset_dir)
# Save dataset
# NOTE: Handle the case where the dataset repo id is not specified in the config
# eg. RL training without demonstrations data
repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
-replay_buffer.to_lerobot_dataset(
-    repo_id=repo_id_buffer_save,
-    fps=fps,
-    root=dataset_dir
-)
+replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
if offline_replay_buffer is not None:
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
@@ -882,9 +858,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
params=policy.actor.parameters_to_optimize,
lr=cfg.policy.actor_lr,
)
-optimizer_critic = torch.optim.Adam(
-    params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
-)
+optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
@@ -898,19 +872,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
"""
Main training function that initializes and runs the training process.
Args:
cfg (TrainPipelineConfig): The training configuration
job_name (str | None, optional): Job name for logging. Defaults to None.
"""
cfg.validate()
# if cfg.output_dir is None:
# raise ValueError("Output directory must be specified in config")
if job_name is None:
job_name = cfg.job_name
if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter")
@@ -920,11 +894,12 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
# Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project:
from lerobot.common.utils.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
# Handle resume logic
cfg = handle_resume_logic(cfg)
@@ -944,9 +919,9 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
@parser.wrap()
def train_cli(cfg: TrainPipelineConfig):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
# Use the job_name from the config