refactor(config): Move device & amp args to PreTrainedConfig (#812)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Steven Palma
2025-03-06 17:59:28 +01:00
committed by GitHub
parent 10706ed753
commit 5e9473806c
19 changed files with 62 additions and 136 deletions
+3 -6
View File
@@ -16,7 +16,6 @@
import logging
import torch
from torch import nn
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
@@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
def make_policy(
cfg: PreTrainedConfig,
device: str | torch.device,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
@@ -88,7 +86,6 @@ def make_policy(
Args:
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
device (str): the device to load the policy onto.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
@@ -96,7 +93,7 @@ def make_policy(
Raises:
ValueError: Either ds_meta or env and env_cfg must be provided.
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
Returns:
PreTrainedPolicy: _description_
@@ -111,7 +108,7 @@ def make_policy(
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
if cfg.type == "vqbet" and str(device) == "mps":
if cfg.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
@@ -145,7 +142,7 @@ def make_policy(
# Make a fresh policy.
policy = policy_cls(**kwargs)
policy.to(device)
policy.to(cfg.device)
assert isinstance(policy, nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead")