feat(policies): Allow users to register 3rd party policies - pip install lerobot_policy_mypolicy (#2308)

* feat: Register external policies

* ruff fix

* move policy util functions to policy factory

* refactor register_third_party_devices -> register_third_party_plugins

* feat: Update docs with bring your own policies

* Improve docs for new policies

* fix: Inconsistent quotation marks

* fix: Remove print statement

* fix: wrong base class name in documentation

* fix: Handle better how the models are parsed

* fix: precommit passing

* Update docs/source/bring_your_own_policies.mdx

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: Daniel San José Pro <42489409+danielsanjosepro@users.noreply.github.com>

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: Daniel San José Pro <42489409+danielsanjosepro@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Daniel San José Pro
2025-12-03 12:09:24 +01:00
committed by GitHub
parent 0b497fc37d
commit 9ec9ee781a
10 changed files with 271 additions and 14 deletions
+79 -3
View File
@@ -16,6 +16,7 @@
from __future__ import annotations
import importlib
import logging
from typing import Any, TypedDict
@@ -108,7 +109,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
return GrootPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
try:
return _get_policy_cls_from_policy_name(name=name)
except Exception as e:
raise ValueError(f"Policy type '{name}' is not available.") from e
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
@@ -151,7 +155,11 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
elif policy_type == "groot":
return GrootConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
return config_cls(**kwargs)
except Exception as e:
raise ValueError(f"Policy type '{policy_type}' is not available.") from e
class ProcessorConfigKwargs(TypedDict, total=False):
@@ -331,7 +339,13 @@ def make_pre_post_processors(
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
try:
processors = _make_processors_from_policy_config(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
except Exception as e:
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e
return processors
@@ -425,3 +439,65 @@ def make_policy(
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
return policy
def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
"""Get policy class from its registered name using dynamic imports.
This is used as a helper function to import policies from 3rd party lerobot plugins.
Args:
name: The name of the policy.
Returns:
The policy class corresponding to the given name.
"""
if name not in PreTrainedConfig.get_known_choices():
raise ValueError(
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
)
config_cls = PreTrainedConfig.get_choice_class(name)
config_cls_name = config_cls.__name__
model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
if model_name == config_cls_name:
raise ValueError(
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
f"Make sure it ends with 'Config'!"
)
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
module_path = config_cls.__module__.replace(
"configuration_", "modeling_"
) # e.g., configuration_diffusion -> modeling_diffusion
module = importlib.import_module(module_path)
policy_cls = getattr(module, cls_name)
return policy_cls
def _make_processors_from_policy_config(
config: PreTrainedConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
"""Create pre- and post-processors from a policy configuration using dynamic imports.
This is used as a helper function to import processor factories from 3rd party lerobot plugins.
Args:
config: The policy configuration object.
dataset_stats: Dataset statistics for normalization.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
"""
policy_type = config.type
function_name = f"make_{policy_type}_pre_post_processors"
module_path = config.__class__.__module__.replace(
"configuration_", "processor_"
) # e.g., configuration_diffusion -> processor_diffusion
logging.debug(
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
)
module = importlib.import_module(module_path)
function = getattr(module, function_name)
return function(config, dataset_stats=dataset_stats)
+2 -2
View File
@@ -52,7 +52,7 @@ from lerobot.teleoperators import ( # noqa: F401
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.utils import init_logging
@@ -84,7 +84,7 @@ def calibrate(cfg: CalibrateConfig):
def main():
register_third_party_devices()
register_third_party_plugins()
calibrate()
+2
View File
@@ -82,6 +82,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -792,6 +793,7 @@ def eval_policy_all(
def main():
init_logging()
register_third_party_plugins()
eval_main()
+2 -2
View File
@@ -118,7 +118,7 @@ from lerobot.utils.control_utils import (
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (
get_safe_torch_device,
@@ -512,7 +512,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
def main():
register_third_party_devices()
register_third_party_plugins()
record()
+2 -2
View File
@@ -61,7 +61,7 @@ from lerobot.robots import ( # noqa: F401
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (
init_logging,
@@ -127,7 +127,7 @@ def replay(cfg: ReplayConfig):
def main():
register_third_party_devices()
register_third_party_plugins()
replay()
+2 -2
View File
@@ -88,7 +88,7 @@ from lerobot.teleoperators import ( # noqa: F401
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, move_cursor_up
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@@ -217,7 +217,7 @@ def teleoperate(cfg: TeleoperateConfig):
def main():
register_third_party_devices()
register_third_party_plugins()
teleoperate()
+2
View File
@@ -36,6 +36,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
@@ -446,6 +447,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
def main():
register_third_party_plugins()
train()
+3 -3
View File
@@ -130,14 +130,14 @@ def make_device_from_device_class(config: ChoiceRegistry) -> Any:
)
def register_third_party_devices() -> None:
def register_third_party_plugins() -> None:
"""
Discover and import third-party lerobot_* plugins so they can register themselves.
Scans top-level modules on sys.path for packages starting with
'lerobot_robot_', 'lerobot_camera_' or 'lerobot_teleoperator_' and imports them.
'lerobot_robot_', 'lerobot_camera_', 'lerobot_teleoperator_' or 'lerobot_policy_' and imports them.
"""
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_")
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
imported: list[str] = []
failed: list[str] = []