[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-28 17:20:38 +00:00
parent 808cf63221
commit c05e4835d0
16 changed files with 93 additions and 91 deletions
-1
View File
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import deque
import gymnasium as gym
+17 -14
View File
@@ -99,36 +99,37 @@ class SGDConfig(OptimizerConfig):
@dataclass
class MultiAdamConfig(OptimizerConfig):
"""Configuration for multiple Adam optimizers with different parameter groups.
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
Args:
lr: Default learning rate (used if not specified for a group)
weight_decay: Default weight decay (used if not specified for a group)
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
grad_clip_norm: Gradient clipping norm
"""
lr: float = 1e-3
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
"""Build multiple Adam optimizers.
Args:
params_dict: Dictionary mapping parameter group names to lists of parameters
The keys should match the keys in optimizer_groups
Returns:
Dictionary mapping parameter group names to their optimizers
"""
optimizers = {}
for name, params in params_dict.items():
# Get group-specific hyperparameters or use defaults
group_config = self.optimizer_groups.get(name, {})
# Create optimizer with merged parameters (defaults + group-specific)
optimizer_kwargs = {
"lr": group_config.get("lr", self.lr),
@@ -136,15 +137,17 @@ class MultiAdamConfig(OptimizerConfig):
"eps": group_config.get("eps", 1e-5),
"weight_decay": group_config.get("weight_decay", self.weight_decay),
}
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
return optimizers
def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> None:
"""Save optimizer state to disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to save the optimizer state.
@@ -173,11 +176,11 @@ def load_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""Load optimizer state from disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to load the optimizer state from.
Returns:
The updated optimizer(s) with loaded state.
"""
@@ -201,7 +204,7 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
# Handle case where 'state' key might not exist (for newly created optimizers)
if "state" in state:
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
+1 -1
View File
@@ -24,11 +24,11 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
@@ -1,10 +1,9 @@
from dataclasses import dataclass, field
from typing import Dict, List
from dataclasses import dataclass
from typing import List
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
+15 -7
View File
@@ -82,8 +82,10 @@ def create_stats_buffers(
if stats and key in stats:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" not in stats[key] or "std" not in stats[key]:
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
raise ValueError(
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
)
if isinstance(stats[key]["mean"], np.ndarray):
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
@@ -96,12 +98,16 @@ def create_stats_buffers(
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
raise ValueError(
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" not in stats[key] or "max" not in stats[key]:
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
raise ValueError(
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
)
if isinstance(stats[key]["min"], np.ndarray):
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
@@ -110,7 +116,9 @@ def create_stats_buffers(
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["min"])
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
raise ValueError(
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
)
stats_buffers[key] = buffer
return stats_buffers
@@ -19,7 +19,7 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.types import NormalizationMode
@dataclass
@@ -897,7 +897,6 @@ if __name__ == "__main__":
# for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
import draccus
from lerobot.configs import parser
+5 -7
View File
@@ -115,11 +115,13 @@ class WandBLogger:
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
def log_dict(
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
):
if mode not in {"train", "eval"}:
raise ValueError(mode)
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
raise ValueError("Either step or custom_step_key must be provided.")
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example,
@@ -142,10 +144,7 @@ class WandBLogger:
continue
# Do not log the custom step key itself.
if (
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if custom_step_key is not None:
@@ -160,7 +159,6 @@ class WandBLogger:
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
raise ValueError(mode)