mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
fix(optim): enable and resolve mypy type errors (#2683)
* fix(optim): enable and resolve mypy type errors Resolves #1729 build(deps): add mypy as dependency and update pre-commit hook * change build's type annotation
This commit is contained in:
@@ -87,7 +87,7 @@ repos:
|
|||||||
# TODO(Steven): Uncomment when ready to use
|
# TODO(Steven): Uncomment when ready to use
|
||||||
##### Static Analysis & Typing #####
|
##### Static Analysis & Typing #####
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.18.2
|
rev: v1.19.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
args: [--config-file=pyproject.toml]
|
args: [--config-file=pyproject.toml]
|
||||||
|
|||||||
+4
-4
@@ -141,7 +141,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
|
|||||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||||
|
|
||||||
@@ -320,9 +320,9 @@ disallow_untyped_defs = true
|
|||||||
disallow_incomplete_defs = true
|
disallow_incomplete_defs = true
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
# module = "lerobot.optim.*"
|
module = "lerobot.optim.*"
|
||||||
# ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = "lerobot.model.*"
|
module = "lerobot.model.*"
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ def make_optimizer_and_scheduler(
|
|||||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||||
"""
|
"""
|
||||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||||
|
if cfg.optimizer is None:
|
||||||
|
raise ValueError("Optimizer config is required but not provided in TrainPipelineConfig")
|
||||||
optimizer = cfg.optimizer.build(params)
|
optimizer = cfg.optimizer.build(params)
|
||||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||||
return optimizer, lr_scheduler
|
return optimizer, lr_scheduler
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -29,6 +30,17 @@ from lerobot.utils.constants import (
|
|||||||
)
|
)
|
||||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||||
|
|
||||||
|
# Type alias for parameters accepted by optimizer build() methods.
|
||||||
|
# This matches PyTorch's optimizer signature while also supporting:
|
||||||
|
# - dict[str, Parameter]: Named parameters for differential LR by name (e.g., XVLA)
|
||||||
|
# - dict[str, Iterable]: Multiple parameter groups for multi-optimizer configs (e.g., SAC)
|
||||||
|
OptimizerParams = (
|
||||||
|
Iterable[torch.nn.Parameter] # From model.parameters()
|
||||||
|
| Iterable[dict[str, Any]] # List of param groups with lr/weight_decay overrides
|
||||||
|
| dict[str, torch.nn.Parameter] # From dict(model.named_parameters()) for name-based LR
|
||||||
|
| dict[str, Any] # For multi-optimizer configs (SAC) with multiple param groups
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||||
@@ -45,13 +57,24 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
return "adam"
|
return "adam"
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
def build(self, params: OptimizerParams) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||||
"""
|
"""
|
||||||
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
||||||
|
|
||||||
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
||||||
For example, you can have one optimizer for the policy and another one for the value function
|
For example, you can have one optimizer for the policy and another one for the value function
|
||||||
in reinforcement learning settings.
|
in reinforcement learning settings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Parameters to optimize. Accepts multiple formats depending on the optimizer:
|
||||||
|
- Iterable[Parameter]: From model.parameters() - standard PyTorch usage
|
||||||
|
- Iterable[dict]: List of param groups with 'params' key and optional
|
||||||
|
'lr', 'weight_decay' overrides (e.g., ACT, VQBeT policies)
|
||||||
|
- dict[str, Parameter]: From dict(model.named_parameters()) for optimizers
|
||||||
|
that apply differential learning rates by parameter name (e.g., XVLA)
|
||||||
|
- dict[str, Iterable]: For multi-optimizer configs where each key maps to
|
||||||
|
a separate optimizer's parameters (e.g., SAC with actor/critic/temperature)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The optimizer or a dictionary of optimizers.
|
The optimizer or a dictionary of optimizers.
|
||||||
"""
|
"""
|
||||||
@@ -67,7 +90,7 @@ class AdamConfig(OptimizerConfig):
|
|||||||
weight_decay: float = 0.0
|
weight_decay: float = 0.0
|
||||||
grad_clip_norm: float = 10.0
|
grad_clip_norm: float = 10.0
|
||||||
|
|
||||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||||
kwargs = asdict(self)
|
kwargs = asdict(self)
|
||||||
kwargs.pop("grad_clip_norm")
|
kwargs.pop("grad_clip_norm")
|
||||||
return torch.optim.Adam(params, **kwargs)
|
return torch.optim.Adam(params, **kwargs)
|
||||||
@@ -82,7 +105,7 @@ class AdamWConfig(OptimizerConfig):
|
|||||||
weight_decay: float = 1e-2
|
weight_decay: float = 1e-2
|
||||||
grad_clip_norm: float = 10.0
|
grad_clip_norm: float = 10.0
|
||||||
|
|
||||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||||
kwargs = asdict(self)
|
kwargs = asdict(self)
|
||||||
kwargs.pop("grad_clip_norm")
|
kwargs.pop("grad_clip_norm")
|
||||||
return torch.optim.AdamW(params, **kwargs)
|
return torch.optim.AdamW(params, **kwargs)
|
||||||
@@ -98,7 +121,7 @@ class SGDConfig(OptimizerConfig):
|
|||||||
weight_decay: float = 0.0
|
weight_decay: float = 0.0
|
||||||
grad_clip_norm: float = 10.0
|
grad_clip_norm: float = 10.0
|
||||||
|
|
||||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||||
kwargs = asdict(self)
|
kwargs = asdict(self)
|
||||||
kwargs.pop("grad_clip_norm")
|
kwargs.pop("grad_clip_norm")
|
||||||
return torch.optim.SGD(params, **kwargs)
|
return torch.optim.SGD(params, **kwargs)
|
||||||
@@ -139,21 +162,19 @@ class XVLAAdamWConfig(OptimizerConfig):
|
|||||||
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
|
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
|
||||||
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
|
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
|
||||||
|
|
||||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||||
"""
|
"""
|
||||||
Build AdamW optimizer with differential learning rates.
|
Build AdamW optimizer with differential learning rates.
|
||||||
|
|
||||||
Expects `named_parameters()` as input (dict of name -> param).
|
|
||||||
Applies:
|
|
||||||
- lr * 0.1 for all VLM-related parameters
|
|
||||||
- lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
|
|
||||||
- full lr for all other parameters
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Dictionary of parameter names to parameters (from named_parameters())
|
params: Must be a dict[str, Parameter] from dict(model.named_parameters())
|
||||||
|
or equivalent.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
|
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If params is not a dict (e.g., from model.parameters())
|
||||||
"""
|
"""
|
||||||
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
|
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
|
||||||
|
|
||||||
@@ -174,7 +195,7 @@ class XVLAAdamWConfig(OptimizerConfig):
|
|||||||
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
|
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
|
||||||
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
|
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
|
||||||
|
|
||||||
param_groups = [
|
param_groups: list[dict[str, Any]] = [
|
||||||
{
|
{
|
||||||
"params": vlm_group,
|
"params": vlm_group,
|
||||||
"lr": self.lr * 0.1,
|
"lr": self.lr * 0.1,
|
||||||
@@ -224,19 +245,25 @@ class MultiAdamConfig(OptimizerConfig):
|
|||||||
grad_clip_norm: float = 10.0
|
grad_clip_norm: float = 10.0
|
||||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||||
|
|
||||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
def build(self, params: OptimizerParams) -> dict[str, torch.optim.Optimizer]:
|
||||||
"""Build multiple Adam optimizers.
|
"""Build multiple Adam optimizers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
params: Must be a dict[str, Iterable[Parameter]] mapping parameter group names
|
||||||
The keys should match the keys in optimizer_groups
|
to iterables of parameters. The keys should match the keys in optimizer_groups.
|
||||||
|
Typically from policies that need separate optimizers (e.g., SAC with
|
||||||
|
actor/critic/temperature).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping parameter group names to their optimizers
|
Dictionary mapping parameter group names to their optimizers
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If params is not a dict
|
||||||
"""
|
"""
|
||||||
|
assert isinstance(params, dict), "MultiAdamConfig requires a dict of parameter groups as inputs."
|
||||||
optimizers = {}
|
optimizers = {}
|
||||||
|
|
||||||
for name, params in params_dict.items():
|
for name, group_params in params.items():
|
||||||
# Get group-specific hyperparameters or use defaults
|
# Get group-specific hyperparameters or use defaults
|
||||||
group_config = self.optimizer_groups.get(name, {})
|
group_config = self.optimizer_groups.get(name, {})
|
||||||
|
|
||||||
@@ -248,7 +275,7 @@ class MultiAdamConfig(OptimizerConfig):
|
|||||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||||
}
|
}
|
||||||
|
|
||||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
optimizers[name] = torch.optim.Adam(group_params, **optimizer_kwargs)
|
||||||
|
|
||||||
return optimizers
|
return optimizers
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from lerobot.utils.io_utils import deserialize_json_into_object
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||||
num_warmup_steps: int
|
num_warmup_steps: int | None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user