mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
[WIP] Update SAC configuration and environment settings
- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200. - Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations. - Improved input and output feature management in `SACConfig`. - Refactored `actor_server` and `learner_server` to access configuration properties directly. - Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
This commit is contained in:
@@ -48,6 +48,7 @@ from lerobot.common.utils.train_utils import (
|
||||
load_training_state as utils_load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
save_training_state,
|
||||
)
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
@@ -160,13 +161,14 @@ def load_training_state(
|
||||
|
||||
try:
|
||||
# Use the utility function from train_utils which loads the optimizer state
|
||||
# The function returns (step, updated_optimizer, scheduler)
|
||||
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
|
||||
|
||||
# For interaction step, we still need to load the training_state.pt file
|
||||
# Load interaction step separately from training_state.pt
|
||||
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
|
||||
training_state = torch.load(training_state_path, weights_only=False)
|
||||
interaction_step = training_state.get("interaction_step", 0)
|
||||
interaction_step = 0
|
||||
if os.path.exists(training_state_path):
|
||||
training_state = torch.load(training_state_path, weights_only=False)
|
||||
interaction_step = training_state.get("interaction_step", 0)
|
||||
|
||||
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
|
||||
return step, interaction_step
|
||||
@@ -222,16 +224,20 @@ def initialize_replay_buffer(
|
||||
|
||||
logging.info("Resume training load the online dataset")
|
||||
dataset_path = os.path.join(cfg.output_dir, "dataset")
|
||||
|
||||
# NOTE: In RL is possible to not have a dataset.
|
||||
repo_id = None
|
||||
if cfg.dataset is not None:
|
||||
repo_id = cfg.dataset.dataset_repo_id
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset.dataset_repo_id,
|
||||
local_files_only=True,
|
||||
repo_id=repo_id,
|
||||
root=dataset_path,
|
||||
)
|
||||
return ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=dataset,
|
||||
capacity=cfg.policy.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
state_keys=cfg.policy.input_features.keys(),
|
||||
optimize_memory=True,
|
||||
)
|
||||
|
||||
@@ -298,7 +304,7 @@ def get_observation_features(
|
||||
|
||||
|
||||
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency["learner"] == "threads"
|
||||
return cfg.policy.concurrency.learner == "threads"
|
||||
|
||||
|
||||
def start_learner_threads(
|
||||
@@ -388,7 +394,7 @@ def start_learner_server(
|
||||
service = learner_service.LearnerService(
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
|
||||
seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
|
||||
transition_queue=transition_queue,
|
||||
interaction_message_queue=interaction_message_queue,
|
||||
)
|
||||
@@ -406,8 +412,8 @@ def start_learner_server(
|
||||
server,
|
||||
)
|
||||
|
||||
host = cfg.policy.actor_learner_config["learner_host"]
|
||||
port = cfg.policy.actor_learner_config["learner_port"]
|
||||
host = cfg.policy.actor_learner_config.learner_host
|
||||
port = cfg.policy.actor_learner_config.learner_port
|
||||
|
||||
server.add_insecure_port(f"{host}:{port}")
|
||||
server.start()
|
||||
@@ -509,7 +515,6 @@ def add_actor_information_and_train(
|
||||
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
|
||||
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
|
||||
|
||||
# TODO(Adil): This don't work anymore !
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
# ds_meta=cfg.dataset,
|
||||
@@ -575,8 +580,8 @@ def add_actor_information_and_train(
|
||||
device = cfg.policy.device
|
||||
storage_device = cfg.policy.storage_device
|
||||
policy_update_freq = cfg.policy.policy_update_freq
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
|
||||
save_checkpoint = cfg.save_checkpoint
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
|
||||
saving_checkpoint = cfg.save_checkpoint
|
||||
online_steps = cfg.policy.online_steps
|
||||
|
||||
while True:
|
||||
@@ -598,7 +603,7 @@ def add_actor_information_and_train(
|
||||
continue
|
||||
replay_buffer.add(**transition)
|
||||
|
||||
if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
"is_intervention"
|
||||
):
|
||||
offline_replay_buffer.add(**transition)
|
||||
@@ -618,9 +623,6 @@ def add_actor_information_and_train(
|
||||
mode="train",
|
||||
custom_step_key="Interaction step"
|
||||
)
|
||||
else:
|
||||
# Log to console if no WandB logger
|
||||
logging.info(f"Interaction: {interaction_message}")
|
||||
|
||||
logging.debug("[LEARNER] Received interactions")
|
||||
|
||||
@@ -765,9 +767,6 @@ def add_actor_information_and_train(
|
||||
mode="train",
|
||||
custom_step_key="Optimization step"
|
||||
)
|
||||
else:
|
||||
# Log to console if no WandB logger
|
||||
logging.info(f"Training: {training_infos}")
|
||||
|
||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
||||
@@ -789,7 +788,7 @@ def add_actor_information_and_train(
|
||||
if optimization_step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
|
||||
if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
_num_digits = max(6, len(str(online_steps)))
|
||||
step_identifier = f"{optimization_step:0{_num_digits}d}"
|
||||
@@ -810,6 +809,15 @@ def add_actor_information_and_train(
|
||||
scheduler=None
|
||||
)
|
||||
|
||||
# Save interaction step manually
|
||||
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
|
||||
os.makedirs(training_state_dir, exist_ok=True)
|
||||
training_state = {
|
||||
"step": optimization_step,
|
||||
"interaction_step": interaction_step
|
||||
}
|
||||
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
|
||||
|
||||
# Update the "last" symlink
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
|
||||
@@ -820,8 +828,11 @@ def add_actor_information_and_train(
|
||||
shutil.rmtree(dataset_dir)
|
||||
|
||||
# Save dataset
|
||||
# NOTE: Handle the case where the dataset repo id is not specified in the config
|
||||
# eg. RL training without demonstrations data
|
||||
repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
|
||||
replay_buffer.to_lerobot_dataset(
|
||||
dataset_repo_id,
|
||||
repo_id=repo_id_buffer_save,
|
||||
fps=fps,
|
||||
root=dataset_dir
|
||||
)
|
||||
@@ -892,8 +903,10 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
|
||||
cfg (TrainPipelineConfig): The training configuration
|
||||
job_name (str | None, optional): Job name for logging. Defaults to None.
|
||||
"""
|
||||
if cfg.output_dir is None:
|
||||
raise ValueError("Output directory must be specified in config")
|
||||
|
||||
cfg.validate()
|
||||
# if cfg.output_dir is None:
|
||||
# raise ValueError("Output directory must be specified in config")
|
||||
|
||||
if job_name is None:
|
||||
job_name = cfg.job_name
|
||||
|
||||
Reference in New Issue
Block a user