Mirror of https://github.com/huggingface/lerobot.git
fix(optim): enable and resolve mypy type errors (#2683)
* fix(optim): enable and resolve mypy type errors

  Resolves #1729

  build(deps): add mypy as dependency and update pre-commit hook

* change build's type annotation
@@ -87,7 +87,7 @@ repos:
   # TODO(Steven): Uncomment when ready to use
   ##### Static Analysis & Typing #####
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.18.2
+    rev: v1.19.1
     hooks:
       - id: mypy
         args: [--config-file=pyproject.toml]
@@ -141,7 +141,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
 async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]

 # Development
-dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
+dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
 test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
 video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
@@ -320,9 +320,9 @@ disallow_untyped_defs = true
 disallow_incomplete_defs = true
 check_untyped_defs = true

-# [[tool.mypy.overrides]]
-# module = "lerobot.optim.*"
-# ignore_errors = false
+[[tool.mypy.overrides]]
+module = "lerobot.optim.*"
+ignore_errors = false

 [[tool.mypy.overrides]]
 module = "lerobot.model.*"
@@ -35,6 +35,8 @@ def make_optimizer_and_scheduler(
         tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
     """
     params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
+    if cfg.optimizer is None:
+        raise ValueError("Optimizer config is required but not provided in TrainPipelineConfig")
     optimizer = cfg.optimizer.build(params)
     lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
     return optimizer, lr_scheduler
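For context, a minimal standalone sketch (illustrative only; FakeTrainConfig and FakeOptimizerConfig are hypothetical names, not code from this diff) of why the explicit None check matters under mypy: when the optimizer attribute is typed as optional, raising first narrows the type so the following .build(params) call type-checks.

# Hypothetical sketch of the Optional-narrowing pattern; names are made up.
from dataclasses import dataclass


@dataclass
class FakeOptimizerConfig:
    lr: float = 1e-3

    def build(self) -> str:
        return f"Adam(lr={self.lr})"


@dataclass
class FakeTrainConfig:
    optimizer: FakeOptimizerConfig | None = None


def make_optimizer(cfg: FakeTrainConfig) -> str:
    if cfg.optimizer is None:
        raise ValueError("Optimizer config is required")
    # After the check, mypy narrows cfg.optimizer from "FakeOptimizerConfig | None"
    # to "FakeOptimizerConfig", so this call is accepted.
    return cfg.optimizer.build()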
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
 from typing import Any
@@ -29,6 +30,17 @@ from lerobot.utils.constants import (
 )
 from lerobot.utils.io_utils import deserialize_json_into_object

+# Type alias for parameters accepted by optimizer build() methods.
+# This matches PyTorch's optimizer signature while also supporting:
+# - dict[str, Parameter]: Named parameters for differential LR by name (e.g., XVLA)
+# - dict[str, Iterable]: Multiple parameter groups for multi-optimizer configs (e.g., SAC)
+OptimizerParams = (
+    Iterable[torch.nn.Parameter]  # From model.parameters()
+    | Iterable[dict[str, Any]]  # List of param groups with lr/weight_decay overrides
+    | dict[str, torch.nn.Parameter]  # From dict(model.named_parameters()) for name-based LR
+    | dict[str, Any]  # For multi-optimizer configs (SAC) with multiple param groups
+)
+

 @dataclass
 class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
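To make the four admitted shapes concrete, here is a small standalone sketch (illustrative only, using toy nn.Linear modules; not part of the diff):

# Toy illustration of the four shapes OptimizerParams is meant to admit.
from torch import nn

actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

# 1) Iterable[Parameter]: plain model.parameters()
params_plain = actor.parameters()

# 2) Iterable[dict]: explicit param groups with per-group overrides
params_groups = [
    {"params": [actor.weight], "lr": 1e-4},
    {"params": [actor.bias], "lr": 1e-3, "weight_decay": 0.0},
]

# 3) dict[str, Parameter]: named parameters, for name-based differential LR
params_named = dict(actor.named_parameters())

# 4) dict[str, Iterable]: one entry per optimizer, for multi-optimizer setups
params_multi = {"actor": actor.parameters(), "critic": critic.parameters()}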
@@ -45,13 +57,24 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
         return "adam"

     @abc.abstractmethod
-    def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
+    def build(self, params: OptimizerParams) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
         """
         Build the optimizer. It can be a single optimizer or a dictionary of optimizers.

         NOTE: Multiple optimizers are useful when you have different models to optimize.
         For example, you can have one optimizer for the policy and another one for the value function
         in reinforcement learning settings.

+        Args:
+            params: Parameters to optimize. Accepts multiple formats depending on the optimizer:
+                - Iterable[Parameter]: From model.parameters() - standard PyTorch usage
+                - Iterable[dict]: List of param groups with 'params' key and optional
+                  'lr', 'weight_decay' overrides (e.g., ACT, VQBeT policies)
+                - dict[str, Parameter]: From dict(model.named_parameters()) for optimizers
+                  that apply differential learning rates by parameter name (e.g., XVLA)
+                - dict[str, Iterable]: For multi-optimizer configs where each key maps to
+                  a separate optimizer's parameters (e.g., SAC with actor/critic/temperature)
+
         Returns:
             The optimizer or a dictionary of optimizers.
         """
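One reason the concrete configs below also move from `params: dict` to `params: OptimizerParams` is that mypy checks override compatibility: narrowing a parameter type in a subclass override is reported as a Liskov violation. A hypothetical sketch (Base, Bad, and Good are made-up names, not repo code):

# Hypothetical sketch of mypy's [override] check; not repo code.
import abc
from collections.abc import Iterable


class Base(abc.ABC):
    @abc.abstractmethod
    def build(self, params: Iterable[int] | dict[str, int]) -> str: ...


class Bad(Base):
    # mypy: Argument 1 of "build" is incompatible with supertype "Base" [override]
    def build(self, params: dict[str, int]) -> str:
        return "narrowed parameter type"


class Good(Base):
    # Accepting the same (or a wider) parameter type satisfies the check.
    def build(self, params: Iterable[int] | dict[str, int]) -> str:
        return "compatible override"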
@@ -67,7 +90,7 @@ class AdamConfig(OptimizerConfig):
     weight_decay: float = 0.0
     grad_clip_norm: float = 10.0

-    def build(self, params: dict) -> torch.optim.Optimizer:
+    def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
         kwargs = asdict(self)
         kwargs.pop("grad_clip_norm")
         return torch.optim.Adam(params, **kwargs)
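As a quick usage illustration of the plain-parameters path, a sketch assuming the module path lerobot.optim.optimizers and that AdamConfig's fields all have defaults; the toy model is hypothetical:

# Illustrative usage sketch; module path and defaults are assumptions.
from torch import nn

from lerobot.optim.optimizers import AdamConfig

model = nn.Linear(8, 2)
optimizer = AdamConfig().build(model.parameters())  # Iterable[Parameter] path
print(type(optimizer).__name__)  # Adam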
@@ -82,7 +105,7 @@ class AdamWConfig(OptimizerConfig):
     weight_decay: float = 1e-2
     grad_clip_norm: float = 10.0

-    def build(self, params: dict) -> torch.optim.Optimizer:
+    def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
         kwargs = asdict(self)
         kwargs.pop("grad_clip_norm")
         return torch.optim.AdamW(params, **kwargs)
@@ -98,7 +121,7 @@ class SGDConfig(OptimizerConfig):
     weight_decay: float = 0.0
     grad_clip_norm: float = 10.0

-    def build(self, params: dict) -> torch.optim.Optimizer:
+    def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
         kwargs = asdict(self)
         kwargs.pop("grad_clip_norm")
         return torch.optim.SGD(params, **kwargs)
@@ -139,21 +162,19 @@ class XVLAAdamWConfig(OptimizerConfig):
     soft_prompt_lr_scale: float = 1.0  # Scale factor for soft-prompt LR (1.0 = same as base LR)
     soft_prompt_warmup_lr_scale: float | None = None  # If set, start soft-prompts at this scale (e.g., 0.01)

-    def build(self, params: dict) -> torch.optim.Optimizer:
+    def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
         """
         Build AdamW optimizer with differential learning rates.

         Expects `named_parameters()` as input (dict of name -> param).
         Applies:
         - lr * 0.1 for all VLM-related parameters
         - lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
         - full lr for all other parameters

         Args:
-            params: Dictionary of parameter names to parameters (from named_parameters())
+            params: Must be a dict[str, Parameter] from dict(model.named_parameters())
+                or equivalent.

         Returns:
             AdamW optimizer with parameter groups for VLM, soft-prompts, and other components

         Raises:
             AssertionError: If params is not a dict (e.g., from model.parameters())
         """
         assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
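To make the named-parameters requirement concrete, a hedged sketch (toy model; assumes the lerobot.optim.optimizers module path and default field values):

# Illustrative sketch; the toy model, module path, and defaults are assumptions.
from torch import nn

from lerobot.optim.optimizers import XVLAAdamWConfig

model = nn.Linear(8, 2)

# Passing model.parameters() would trip the assert above; the config needs
# parameter names to route VLM and soft-prompt weights into their own LR groups,
# while everything else falls into the full-LR group.
named_params = dict(model.named_parameters())
optimizer = XVLAAdamWConfig().build(named_params)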
@@ -174,7 +195,7 @@ class XVLAAdamWConfig(OptimizerConfig):
             # Start at warmup scale, scheduler will warm up to soft_prompt_lr
             soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale

-        param_groups = [
+        param_groups: list[dict[str, Any]] = [
             {
                 "params": vlm_group,
                 "lr": self.lr * 0.1,
@@ -224,19 +245,25 @@ class MultiAdamConfig(OptimizerConfig):
     grad_clip_norm: float = 10.0
     optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)

-    def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
+    def build(self, params: OptimizerParams) -> dict[str, torch.optim.Optimizer]:
         """Build multiple Adam optimizers.

         Args:
-            params_dict: Dictionary mapping parameter group names to lists of parameters
-                The keys should match the keys in optimizer_groups
+            params: Must be a dict[str, Iterable[Parameter]] mapping parameter group names
+                to iterables of parameters. The keys should match the keys in optimizer_groups.
+                Typically from policies that need separate optimizers (e.g., SAC with
+                actor/critic/temperature).

         Returns:
             Dictionary mapping parameter group names to their optimizers

+        Raises:
+            AssertionError: If params is not a dict
         """
+        assert isinstance(params, dict), "MultiAdamConfig requires a dict of parameter groups as inputs."
         optimizers = {}

-        for name, params in params_dict.items():
+        for name, group_params in params.items():
             # Get group-specific hyperparameters or use defaults
             group_config = self.optimizer_groups.get(name, {})
@@ -248,7 +275,7 @@ class MultiAdamConfig(OptimizerConfig):
                 "weight_decay": group_config.get("weight_decay", self.weight_decay),
             }

-            optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
+            optimizers[name] = torch.optim.Adam(group_params, **optimizer_kwargs)

         return optimizers
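For the multi-optimizer path, a hedged sketch (toy actor/critic modules; the group settings are examples, assuming the lerobot.optim.optimizers module path and default field values):

# Illustrative multi-optimizer sketch; modules, settings, and module path are assumptions.
from torch import nn

from lerobot.optim.optimizers import MultiAdamConfig

actor = nn.Linear(8, 2)
critic = nn.Linear(8, 1)

cfg = MultiAdamConfig(optimizer_groups={"actor": {"lr": 3e-4}, "critic": {"lr": 1e-3}})
optimizers = cfg.build(
    {
        "actor": list(actor.parameters()),
        "critic": list(critic.parameters()),
    }
)
print(sorted(optimizers))  # ['actor', 'critic']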
@@ -30,7 +30,7 @@ from lerobot.utils.io_utils import deserialize_json_into_object
 @dataclass
 class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
-    num_warmup_steps: int
+    num_warmup_steps: int | None

     @property
     def type(self) -> str:
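Downstream code that consumes the now-optional num_warmup_steps has to guard against None. A hypothetical pattern (not the repo's scheduler code):

# Hypothetical guard for an optional warmup count; not repo code.
def warmup_factor(step: int, num_warmup_steps: int | None) -> float:
    """Linear warmup multiplier; no warmup when num_warmup_steps is None or 0."""
    if not num_warmup_steps:
        return 1.0
    return min(1.0, step / num_warmup_steps)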