mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
feat(policies): Allow users to register 3rd party policies - pip install lerobot_policy_mypolicy (#2308)
* feat: Register external policies * ruff fix * move policy util functions to policy factory * refactor register_third_party_devices -> register_third_party_plugins * feat: Update docs with bring your own policies * Improve docs for new policies * fix: Inconsistent quotation marks * fix: Remove print statement * fix: wrong base class name in documentation * fix: Handle better how the models are parsed * fix: precommit passing * Update docs/source/bring_your_own_policies.mdx Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Signed-off-by: Daniel San José Pro <42489409+danielsanjosepro@users.noreply.github.com> --------- Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> Signed-off-by: Daniel San José Pro <42489409+danielsanjosepro@users.noreply.github.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
committed by
GitHub
parent
0b497fc37d
commit
9ec9ee781a
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
|
||||
@@ -108,7 +109,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
return GrootPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Policy type '{name}' is not available.") from e
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
@@ -151,7 +155,11 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
return config_cls(**kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.") from e
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
@@ -331,7 +339,13 @@ def make_pre_post_processors(
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e
|
||||
|
||||
return processors
|
||||
|
||||
@@ -425,3 +439,65 @@ def make_policy(
|
||||
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
|
||||
"""Get policy class from its registered name using dynamic imports.
|
||||
|
||||
This is used as a helper function to import policies from 3rd party lerobot plugins.
|
||||
|
||||
Args:
|
||||
name: The name of the policy.
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
"""
|
||||
if name not in PreTrainedConfig.get_known_choices():
|
||||
raise ValueError(
|
||||
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
|
||||
)
|
||||
|
||||
config_cls = PreTrainedConfig.get_choice_class(name)
|
||||
config_cls_name = config_cls.__name__
|
||||
|
||||
model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
|
||||
if model_name == config_cls_name:
|
||||
raise ValueError(
|
||||
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
|
||||
f"Make sure it ends with 'Config'!"
|
||||
)
|
||||
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
|
||||
module_path = config_cls.__module__.replace(
|
||||
"configuration_", "modeling_"
|
||||
) # e.g., configuration_diffusion -> modeling_diffusion
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
policy_cls = getattr(module, cls_name)
|
||||
return policy_cls
|
||||
|
||||
|
||||
def _make_processors_from_policy_config(
|
||||
config: PreTrainedConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Create pre- and post-processors from a policy configuration using dynamic imports.
|
||||
|
||||
This is used as a helper function to import processor factories from 3rd party lerobot plugins.
|
||||
|
||||
Args:
|
||||
config: The policy configuration object.
|
||||
dataset_stats: Dataset statistics for normalization.
|
||||
Returns:
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
"""
|
||||
|
||||
policy_type = config.type
|
||||
function_name = f"make_{policy_type}_pre_post_processors"
|
||||
module_path = config.__class__.__module__.replace(
|
||||
"configuration_", "processor_"
|
||||
) # e.g., configuration_diffusion -> processor_diffusion
|
||||
logging.debug(
|
||||
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
|
||||
)
|
||||
module = importlib.import_module(module_path)
|
||||
function = getattr(module, function_name)
|
||||
return function(config, dataset_stats=dataset_stats)
|
||||
|
||||
@@ -52,7 +52,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so100_leader,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ def calibrate(cfg: CalibrateConfig):
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_devices()
|
||||
register_third_party_plugins()
|
||||
calibrate()
|
||||
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
@@ -792,6 +793,7 @@ def eval_policy_all(
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
register_third_party_plugins()
|
||||
eval_main()
|
||||
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ from lerobot.utils.control_utils import (
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
@@ -512,7 +512,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_devices()
|
||||
register_third_party_plugins()
|
||||
record()
|
||||
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -127,7 +127,7 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_devices()
|
||||
register_third_party_plugins()
|
||||
replay()
|
||||
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so100_leader,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging, move_cursor_up
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
@@ -217,7 +217,7 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_devices()
|
||||
register_third_party_plugins()
|
||||
teleoperate()
|
||||
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
from lerobot.scripts.lerobot_eval import eval_policy_all
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.train_utils import (
|
||||
@@ -446,6 +447,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_plugins()
|
||||
train()
|
||||
|
||||
|
||||
|
||||
@@ -130,14 +130,14 @@ def make_device_from_device_class(config: ChoiceRegistry) -> Any:
|
||||
)
|
||||
|
||||
|
||||
def register_third_party_devices() -> None:
|
||||
def register_third_party_plugins() -> None:
|
||||
"""
|
||||
Discover and import third-party lerobot_* plugins so they can register themselves.
|
||||
|
||||
Scans top-level modules on sys.path for packages starting with
|
||||
'lerobot_robot_', 'lerobot_camera_' or 'lerobot_teleoperator_' and imports them.
|
||||
'lerobot_robot_', 'lerobot_camera_', 'lerobot_teleoperator_' or 'lerobot_policy_' and imports them.
|
||||
"""
|
||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_")
|
||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
|
||||
imported: list[str] = []
|
||||
failed: list[str] = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user