diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 0dd9db516..3a01aee88 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import importlib.metadata import logging -import pkgutil from typing import Any from draccus.choice_types import ChoiceRegistry @@ -132,24 +132,30 @@ def make_device_from_device_class(config: ChoiceRegistry) -> Any: def register_third_party_plugins() -> None: """ - Discover and import third-party lerobot_* plugins so they can register themselves. + Discover and import third-party LeRobot plugins so they can register themselves. - Scans top-level modules on sys.path for packages starting with - 'lerobot_robot_', 'lerobot_camera_', 'lerobot_teleoperator_' or 'lerobot_policy_' and imports them. + This function uses `importlib.metadata` to find packages installed in the environment + (including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_', + 'lerobot_teleoperator_', or 'lerobot_policy_' and imports them. """ prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_") imported: list[str] = [] failed: list[str] = [] - for module_info in pkgutil.iter_modules(): - name = module_info.name - if name.startswith(prefixes): - try: - importlib.import_module(name) - imported.append(name) - logging.info("Imported third-party plugin: %s", name) - except Exception: - logging.exception("Could not import third-party plugin: %s", name) - failed.append(name) + def attempt_import(module_name: str): + try: + importlib.import_module(module_name) + imported.append(module_name) + logging.info("Imported third-party plugin: %s", module_name) + except Exception: + logging.exception("Could not import third-party plugin: %s", module_name) + failed.append(module_name) + + for dist in importlib.metadata.distributions(): + dist_name = dist.metadata.get("Name") + if not dist_name: + continue + if dist_name.startswith(prefixes): + attempt_import(dist_name) logging.debug("Third-party plugin import summary: imported=%s failed=%s", imported, failed)