[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot]
2025-03-24 13:41:27 +00:00
committed by AdilZouitine
parent 2945bbb221
commit 7c05755823
123 changed files with 1161 additions and 3425 deletions
+39 -89
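Note: every hunk in this diff is a mechanical style fix from the repository's pre-commit hooks (an import sorter plus a black/ruff-format-style formatter, judging by the changes): imports are regrouped and alphabetized, and statements that fit within the line-length limit (apparently around 110 characters) are collapsed onto one line. A minimal sketch of the recurring reflow pattern, using a stand-in `nn.Linear` in place of the SAC policy:

```python
from torch import nn

policy = nn.Linear(4, 2)  # stand-in module; the real script formats a SAC policy

# Before the hook runs: the call is wrapped even though it fits on one line.
num_learnable_params = sum(
    p.numel() for p in policy.parameters() if p.requires_grad
)

# After: the formatter collapses any statement that fits the line-length limit.
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
print(num_learnable_params)  # 10 for a 4 -> 2 linear layer (8 weights + 2 biases)
```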
@@ -17,15 +17,8 @@
 import logging
 import shutil
 import time
-from pprint import pformat
 from concurrent.futures import ThreadPoolExecutor
-# from torch.multiprocessing import Event, Queue, Process
-# from threading import Event, Thread
-# from torch.multiprocessing import Queue, Event
-from torch.multiprocessing import Queue
-from lerobot.scripts.server.utils import setup_process_handlers
+from pprint import pformat

 import grpc
@@ -37,6 +30,11 @@ from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import nn
+# from torch.multiprocessing import Event, Queue, Process
+# from threading import Event, Thread
+# from torch.multiprocessing import Queue, Event
+from torch.multiprocessing import Queue
 from torch.optim.optimizer import Optimizer

 from lerobot.common.datasets.factory import make_dataset
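The import hunks move lines rather than delete them: the sorter groups standard-library, third-party, and first-party imports and alphabetizes each group, which is why the commented-out `torch.multiprocessing` block travels with the live `Queue` import to its sorted position. An illustrative sketch of the grouping convention, with module names taken from this file:

```python
import logging  # standard library, alphabetized
import time

import grpc  # third-party packages come next
from torch.multiprocessing import Queue

from lerobot.scripts.server import learner_service  # first-party (project) imports last
from lerobot.scripts.server.buffer import ReplayBuffer
```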
@@ -55,18 +53,17 @@ from lerobot.common.utils.utils import (
     set_global_random_state,
     set_global_seed,
 )
-from lerobot.scripts.server import learner_service
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
-    concatenate_batch_transitions,
-    move_transition_to_device,
-    move_state_dict_to_device,
-    bytes_to_transitions,
-    state_to_bytes,
     bytes_to_python_object,
+    bytes_to_transitions,
+    concatenate_batch_transitions,
+    move_state_dict_to_device,
+    move_transition_to_device,
+    state_to_bytes,
 )
+from lerobot.scripts.server import learner_service
+from lerobot.scripts.server.utils import setup_process_handlers

 def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
@@ -81,13 +78,9 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
     # if resume == True
     checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
     if not checkpoint_dir.exists():
-        raise RuntimeError(
-            f"No model checkpoint found in {checkpoint_dir} for resume=True"
-        )
+        raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")

-    checkpoint_cfg_path = str(
-        Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
-    )
+    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
     logging.info(
         colored(
             "Resume=True detected, resuming previous run",
@@ -136,9 +129,7 @@ def load_training_state(
 def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
-    num_learnable_params = sum(
-        p.numel() for p in policy.parameters() if p.requires_grad
-    )
+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

     log_output_dir(out_dir)
@@ -210,22 +201,15 @@ def initialize_offline_replay_buffer(
 def get_observation_features(
     policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
 ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
-    if (
-        policy.config.vision_encoder_name is None
-        or not policy.config.freeze_vision_encoder
-    ):
+    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
         return None, None

     with torch.no_grad():
         observation_features = (
-            policy.actor.encoder(observations)
-            if policy.actor.encoder is not None
-            else None
+            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
         )
         next_observation_features = (
-            policy.actor.encoder(next_observations)
-            if policy.actor.encoder is not None
-            else None
+            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
         )

     return observation_features, next_observation_features
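For context on the reflowed helper above: when a pretrained vision encoder is frozen, its output for a batch is identical for every loss term, so `get_observation_features` computes the embeddings once under `torch.no_grad()` and the training loop reuses them. A minimal stand-in sketch of that caching idea (the `nn.Linear` encoder and shapes are placeholders, not lerobot's API):

```python
import torch
from torch import nn

encoder = nn.Linear(8, 4)      # placeholder for a frozen vision encoder
encoder.requires_grad_(False)  # mirrors config.freeze_vision_encoder

observations = torch.randn(32, 8)

# Encode once, gradient-free; reuse the features for the critic, actor,
# and temperature updates instead of re-running the encoder each time.
with torch.no_grad():
    observation_features = encoder(observations)

assert not observation_features.requires_grad
```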
@@ -452,9 +436,7 @@ def add_actor_information_and_train(
         # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
         # Hack: But if we do online traning, we do not need dataset_stats
         dataset_stats=None,
-        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
-        if cfg.resume
-        else None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )

     # Update the policy config with the grad_clip_norm value from training config if it exists
@@ -469,9 +451,7 @@ def add_actor_information_and_train(
     last_time_policy_pushed = time.time()

     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
-    resume_optimization_step, resume_interaction_step = load_training_state(
-        cfg, logger, optimizers
-    )
+    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)

     log_training_info(cfg, out_dir, policy)
@@ -483,9 +463,7 @@ def add_actor_information_and_train(
         active_action_dims = None
         if cfg.env.wrapper.joint_masking_action_space is not None:
             active_action_dims = [
-                i
-                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
-                if mask
+                i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
             ]
         offline_replay_buffer = initialize_offline_replay_buffer(
             cfg=cfg,
@@ -502,12 +480,8 @@ def add_actor_information_and_train(
     time.time()
     logging.info("Starting learner thread")
     interaction_message, transition = None, None
-    optimization_step = (
-        resume_optimization_step if resume_optimization_step is not None else 0
-    )
-    interaction_step_shift = (
-        resume_interaction_step if resume_interaction_step is not None else 0
-    )
+    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
+    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0

     # Extract variables from cfg
     online_step_before_learning = cfg.training.online_step_before_learning
@@ -519,9 +493,7 @@ def add_actor_information_and_train(
     device = cfg.device
     storage_device = cfg.training.storage_device
     policy_update_freq = cfg.training.policy_update_freq
-    policy_parameters_push_frequency = (
-        cfg.actor_learner_config.policy_parameters_push_frequency
-    )
+    policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
     save_checkpoint = cfg.training.save_checkpoint
     online_steps = cfg.training.online_steps
@@ -544,9 +516,9 @@ def add_actor_information_and_train(
                 continue
             replay_buffer.add(**transition)

-            if cfg.dataset_repo_id is not None and transition.get(
-                "complementary_info", {}
-            ).get("is_intervention"):
+            if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
+                "is_intervention"
+            ):
                 offline_replay_buffer.add(**transition)
         logging.debug("[LEARNER] Received transitions")
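The condition reflowed above routes transitions twice: every transition goes to the online replay buffer, while transitions flagged as human interventions in `complementary_info` are also added to the offline buffer when a dataset repo id is configured. A sketch of the payload shape this check assumes (field values are placeholders; only the `complementary_info`/`is_intervention` keys come from the diff):

```python
transition = {
    "state": None,
    "action": None,
    "complementary_info": {"is_intervention": True},
}
dataset_repo_id = "user/dataset"  # placeholder repo id

if dataset_repo_id is not None and transition.get("complementary_info", {}).get("is_intervention"):
    print("transition would also be stored in the offline replay buffer")
```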
@@ -556,9 +528,7 @@ def add_actor_information_and_train(
             interaction_message = bytes_to_python_object(interaction_message)
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
             interaction_message["Interaction step"] += interaction_step_shift
-            logger.log_dict(
-                interaction_message, mode="train", custom_step_key="Interaction step"
-            )
+            logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")

         logging.debug("[LEARNER] Received interactions")
@@ -579,9 +549,7 @@ def add_actor_information_and_train(
             observations = batch["state"]
             next_observations = batch["next_state"]
             done = batch["done"]
-            check_nan_in_transition(
-                observations=observations, actions=actions, next_state=next_observations
-            )
+            check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

             observation_features, next_observation_features = get_observation_features(
                 policy, observations, next_observations
@@ -619,9 +587,7 @@ def add_actor_information_and_train(
             next_observations = batch["next_state"]
             done = batch["done"]

-            check_nan_in_transition(
-                observations=observations, actions=actions, next_state=next_observations
-            )
+            check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

             observation_features, next_observation_features = get_observation_features(
                 policy, observations, next_observations
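`check_nan_in_transition`, reflowed in the two hunks above, guards each sampled batch before the gradient step; its implementation is not shown in this diff. A hypothetical stand-in that checks the same three tensors:

```python
import torch

def check_nan_in_transition(observations, actions, next_state) -> bool:
    """Hypothetical stand-in: True if any tensor in the transition contains a NaN."""
    return any(torch.isnan(t).any().item() for t in (observations, actions, next_state))

batch = {key: torch.randn(4, 3) for key in ("state", "action", "next_state")}
assert not check_nan_in_transition(batch["state"], batch["action"], batch["next_state"])
```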
@@ -697,23 +663,15 @@ def add_actor_information_and_train(
             if optimization_step % log_freq == 0:
                 training_infos["replay_buffer_size"] = len(replay_buffer)
                 if offline_replay_buffer is not None:
-                    training_infos["offline_replay_buffer_size"] = len(
-                        offline_replay_buffer
-                    )
+                    training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
                 training_infos["Optimization step"] = optimization_step
-                logger.log_dict(
-                    d=training_infos, mode="train", custom_step_key="Optimization step"
-                )
+                logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
                 # logging.info(f"Training infos: {training_infos}")

             time_for_one_optimization_step = time.time() - time_for_one_optimization_step
-            frequency_for_one_optimization_step = 1 / (
-                time_for_one_optimization_step + 1e-9
-            )
+            frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

-            logging.info(
-                f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
-            )
+            logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

             logger.log_dict(
                 {
@@ -728,16 +686,12 @@ def add_actor_information_and_train(
         if optimization_step % log_freq == 0:
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")

-        if save_checkpoint and (
-            optimization_step % save_freq == 0 or optimization_step == online_steps
-        ):
+        if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
             logging.info(f"Checkpoint policy after step {optimization_step}")
             _num_digits = max(6, len(str(online_steps)))
             step_identifier = f"{optimization_step:0{_num_digits}d}"
             interaction_step = (
-                interaction_message["Interaction step"]
-                if interaction_message is not None
-                else 0
+                interaction_message["Interaction step"] if interaction_message is not None else 0
             )
             logger.save_checkpoint(
                 optimization_step,
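The checkpoint naming in the hunk above zero-pads the step counter to at least six digits via `max(6, len(str(online_steps)))`, which keeps checkpoint directories sorted in training order. For example:

```python
online_steps = 1_000_000
optimization_step = 2500

_num_digits = max(6, len(str(online_steps)))  # 7 digits for a 1M-step run
step_identifier = f"{optimization_step:0{_num_digits}d}"
print(step_identifier)  # "0002500" -- sorts lexicographically before "0010000"
```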
@@ -755,9 +709,7 @@ def add_actor_information_and_train(
                 shutil.rmtree(
                     dataset_dir,
                 )

-            replay_buffer.to_lerobot_dataset(
-                dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
-            )
+            replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")

             if offline_replay_buffer is not None:
                 dataset_dir = logger.log_dir / "dataset_offline"
@@ -809,9 +761,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     optimizer_critic = torch.optim.Adam(
         params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
     )
-    optimizer_temperature = torch.optim.Adam(
-        params=[policy.log_alpha], lr=policy.config.critic_lr
-    )
+    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,