mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

committed by AdilZouitine
parent 2945bbb221
commit 7c05755823
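pre-commit.ci runs the repository's configured pre-commit hooks against the pull request and pushes the resulting auto-fix commit; the diff below is such a commit, consisting only of import reordering and of wrapped statements collapsed onto single lines. The hook configuration itself is not part of this diff. As a rough orientation, a minimal .pre-commit-config.yaml of the kind that produces fixes like these might look like the sketch below (the hook repository, revision, and arguments are assumptions for illustration, not the project's actual pins):

    # hypothetical .pre-commit-config.yaml sketch -- not taken from this commit
    repos:
      - repo: https://github.com/astral-sh/ruff-pre-commit
        rev: v0.9.0         # assumed revision
        hooks:
          - id: ruff         # linter with autofix (e.g. import sorting)
            args: [--fix]
          - id: ruff-format  # formatter; with a generous line length, wrapped calls collapse to one line

The same fixes can usually be reproduced locally by running pre-commit run --all-files before pushing.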
@@ -17,15 +17,8 @@
 import logging
 import shutil
 import time
-from pprint import pformat
 from concurrent.futures import ThreadPoolExecutor
-
-# from torch.multiprocessing import Event, Queue, Process
-# from threading import Event, Thread
-# from torch.multiprocessing import Queue, Event
-from torch.multiprocessing import Queue
-
-from lerobot.scripts.server.utils import setup_process_handlers
+from pprint import pformat
 
 import grpc
 
@@ -37,6 +30,11 @@ from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import nn
+
+# from torch.multiprocessing import Event, Queue, Process
+# from threading import Event, Thread
+# from torch.multiprocessing import Queue, Event
+from torch.multiprocessing import Queue
 from torch.optim.optimizer import Optimizer
 
 from lerobot.common.datasets.factory import make_dataset
@@ -55,18 +53,17 @@ from lerobot.common.utils.utils import (
     set_global_random_state,
     set_global_seed,
 )
-
+from lerobot.scripts.server import learner_service
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
-    concatenate_batch_transitions,
-    move_transition_to_device,
-    move_state_dict_to_device,
-    bytes_to_transitions,
-    state_to_bytes,
     bytes_to_python_object,
+    bytes_to_transitions,
+    concatenate_batch_transitions,
+    move_state_dict_to_device,
+    move_transition_to_device,
+    state_to_bytes,
 )
-
-from lerobot.scripts.server import learner_service
+from lerobot.scripts.server.utils import setup_process_handlers
 
 
 def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
@@ -81,13 +78,9 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
     # if resume == True
     checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
     if not checkpoint_dir.exists():
-        raise RuntimeError(
-            f"No model checkpoint found in {checkpoint_dir} for resume=True"
-        )
+        raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
 
-    checkpoint_cfg_path = str(
-        Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
-    )
+    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
     logging.info(
         colored(
             "Resume=True detected, resuming previous run",
@@ -136,9 +129,7 @@ def load_training_state(
 
 
 def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
-    num_learnable_params = sum(
-        p.numel() for p in policy.parameters() if p.requires_grad
-    )
+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
     log_output_dir(out_dir)
@@ -210,22 +201,15 @@ def initialize_offline_replay_buffer(
 def get_observation_features(
     policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
 ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
-    if (
-        policy.config.vision_encoder_name is None
-        or not policy.config.freeze_vision_encoder
-    ):
+    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
         return None, None
 
     with torch.no_grad():
         observation_features = (
-            policy.actor.encoder(observations)
-            if policy.actor.encoder is not None
-            else None
+            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
         )
         next_observation_features = (
-            policy.actor.encoder(next_observations)
-            if policy.actor.encoder is not None
-            else None
+            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
         )
 
     return observation_features, next_observation_features
@@ -452,9 +436,7 @@ def add_actor_information_and_train(
         # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
         # Hack: But if we do online traning, we do not need dataset_stats
         dataset_stats=None,
-        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
-        if cfg.resume
-        else None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
 
     # Update the policy config with the grad_clip_norm value from training config if it exists
@@ -469,9 +451,7 @@ def add_actor_information_and_train(
     last_time_policy_pushed = time.time()
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
-    resume_optimization_step, resume_interaction_step = load_training_state(
-        cfg, logger, optimizers
-    )
+    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
 
     log_training_info(cfg, out_dir, policy)
 
@@ -483,9 +463,7 @@ def add_actor_information_and_train(
         active_action_dims = None
         if cfg.env.wrapper.joint_masking_action_space is not None:
             active_action_dims = [
-                i
-                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
-                if mask
+                i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
             ]
         offline_replay_buffer = initialize_offline_replay_buffer(
             cfg=cfg,
@@ -502,12 +480,8 @@ def add_actor_information_and_train(
     time.time()
     logging.info("Starting learner thread")
     interaction_message, transition = None, None
-    optimization_step = (
-        resume_optimization_step if resume_optimization_step is not None else 0
-    )
-    interaction_step_shift = (
-        resume_interaction_step if resume_interaction_step is not None else 0
-    )
+    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
+    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
 
     # Extract variables from cfg
     online_step_before_learning = cfg.training.online_step_before_learning
@@ -519,9 +493,7 @@ def add_actor_information_and_train(
     device = cfg.device
     storage_device = cfg.training.storage_device
     policy_update_freq = cfg.training.policy_update_freq
-    policy_parameters_push_frequency = (
-        cfg.actor_learner_config.policy_parameters_push_frequency
-    )
+    policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
     save_checkpoint = cfg.training.save_checkpoint
     online_steps = cfg.training.online_steps
 
@@ -544,9 +516,9 @@ def add_actor_information_and_train(
                     continue
                 replay_buffer.add(**transition)
 
-                if cfg.dataset_repo_id is not None and transition.get(
-                    "complementary_info", {}
-                ).get("is_intervention"):
+                if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
+                    "is_intervention"
+                ):
                     offline_replay_buffer.add(**transition)
 
         logging.debug("[LEARNER] Received transitions")
@@ -556,9 +528,7 @@ def add_actor_information_and_train(
             interaction_message = bytes_to_python_object(interaction_message)
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
            interaction_message["Interaction step"] += interaction_step_shift
-            logger.log_dict(
-                interaction_message, mode="train", custom_step_key="Interaction step"
-            )
+            logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
 
         logging.debug("[LEARNER] Received interactions")
 
@@ -579,9 +549,7 @@ def add_actor_information_and_train(
             observations = batch["state"]
             next_observations = batch["next_state"]
             done = batch["done"]
-            check_nan_in_transition(
-                observations=observations, actions=actions, next_state=next_observations
-            )
+            check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
             observation_features, next_observation_features = get_observation_features(
                 policy, observations, next_observations
@@ -619,9 +587,7 @@ def add_actor_information_and_train(
         next_observations = batch["next_state"]
         done = batch["done"]
 
-        check_nan_in_transition(
-            observations=observations, actions=actions, next_state=next_observations
-        )
+        check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
         observation_features, next_observation_features = get_observation_features(
             policy, observations, next_observations
@@ -697,23 +663,15 @@ def add_actor_information_and_train(
         if optimization_step % log_freq == 0:
             training_infos["replay_buffer_size"] = len(replay_buffer)
             if offline_replay_buffer is not None:
-                training_infos["offline_replay_buffer_size"] = len(
-                    offline_replay_buffer
-                )
+                training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
             training_infos["Optimization step"] = optimization_step
-            logger.log_dict(
-                d=training_infos, mode="train", custom_step_key="Optimization step"
-            )
+            logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
             # logging.info(f"Training infos: {training_infos}")
 
         time_for_one_optimization_step = time.time() - time_for_one_optimization_step
-        frequency_for_one_optimization_step = 1 / (
-            time_for_one_optimization_step + 1e-9
-        )
+        frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
 
-        logging.info(
-            f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
-        )
+        logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
 
         logger.log_dict(
             {
@@ -728,16 +686,12 @@ def add_actor_information_and_train(
         if optimization_step % log_freq == 0:
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
 
-        if save_checkpoint and (
-            optimization_step % save_freq == 0 or optimization_step == online_steps
-        ):
+        if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
             logging.info(f"Checkpoint policy after step {optimization_step}")
             _num_digits = max(6, len(str(online_steps)))
             step_identifier = f"{optimization_step:0{_num_digits}d}"
             interaction_step = (
-                interaction_message["Interaction step"]
-                if interaction_message is not None
-                else 0
+                interaction_message["Interaction step"] if interaction_message is not None else 0
             )
             logger.save_checkpoint(
                 optimization_step,
@@ -755,9 +709,7 @@ def add_actor_information_and_train(
                 shutil.rmtree(
                     dataset_dir,
                 )
-            replay_buffer.to_lerobot_dataset(
-                dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
-            )
+            replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")
             if offline_replay_buffer is not None:
                 dataset_dir = logger.log_dir / "dataset_offline"
 
@@ -809,9 +761,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     optimizer_critic = torch.optim.Adam(
         params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
     )
-    optimizer_temperature = torch.optim.Adam(
-        params=[policy.log_alpha], lr=policy.config.critic_lr
-    )
+    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,