mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
feat(devices): add lazy loading for 3rd party robots cameras and teleoperators (#2123)
* feat(devices): add lazy loading for 3rd party robots cameras and teleoperators Co-authored-by: Darko Lukić <lukicdarkoo@gmail.com> * feat(devices): load device class based on assumptions in naming * docs(devices): instructions for using 3rd party devices * docs: address review feedback * chore(docs): add example for 3rd party devices --------- Co-authored-by: Darko Lukić <lukicdarkoo@gmail.com>
This commit is contained in:
@@ -15,15 +15,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
from typing import cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig, Cv2Rotation
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
|
||||
cameras = {}
|
||||
cameras: dict[str, Camera] = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
if cfg.type == "opencv":
|
||||
from .opencv import OpenCVCamera
|
||||
|
||||
@@ -40,7 +44,10 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
try:
|
||||
cameras[key] = cast(Camera, make_device_from_device_class(cfg))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating camera {key} with config {cfg}: {e}") from e
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
@@ -14,13 +14,16 @@
|
||||
|
||||
import logging
|
||||
from pprint import pformat
|
||||
from typing import cast
|
||||
|
||||
from lerobot.robots import RobotConfig
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
|
||||
from .config import RobotConfig
|
||||
from .robot import Robot
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
if config.type == "koch_follower":
|
||||
from .koch_follower import KochFollower
|
||||
|
||||
@@ -66,7 +69,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
|
||||
return MockRobot(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
try:
|
||||
return cast(Robot, make_device_from_device_class(config))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating robot with config {config}: {e}") from e
|
||||
|
||||
|
||||
# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
|
||||
|
||||
@@ -52,6 +52,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so100_leader,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -83,6 +84,7 @@ def calibrate(cfg: CalibrateConfig):
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_devices()
|
||||
calibrate()
|
||||
|
||||
|
||||
|
||||
@@ -117,6 +117,7 @@ from lerobot.utils.control_utils import (
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
@@ -513,6 +514,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_devices()
|
||||
record()
|
||||
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -126,6 +127,7 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_devices()
|
||||
replay()
|
||||
|
||||
|
||||
|
||||
@@ -88,6 +88,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so100_leader,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import init_logging, move_cursor_up
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
@@ -215,6 +216,7 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_devices()
|
||||
teleoperate()
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
@@ -29,6 +32,7 @@ class TeleopEvents(Enum):
|
||||
|
||||
|
||||
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
if config.type == "keyboard":
|
||||
from .keyboard import KeyboardTeleop
|
||||
|
||||
@@ -82,4 +86,7 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
|
||||
return Reachy2Teleoperator(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
try:
|
||||
return cast(Teleoperator, make_device_from_device_class(config))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating robot with config {config}: {e}") from e
|
||||
|
||||
@@ -15,6 +15,10 @@
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import logging
|
||||
import pkgutil
|
||||
from typing import Any
|
||||
|
||||
from draccus.choice_types import ChoiceRegistry
|
||||
|
||||
|
||||
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
||||
@@ -58,3 +62,93 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
|
||||
|
||||
_transformers_available = is_package_available("transformers")
|
||||
|
||||
|
||||
def make_device_from_device_class(config: ChoiceRegistry) -> Any:
|
||||
"""
|
||||
Dynamically instantiates an object from its `ChoiceRegistry` configuration.
|
||||
|
||||
This factory uses the module path and class name from the `config` object's
|
||||
type to locate and instantiate the corresponding device class (not the config).
|
||||
It derives the device class name by removing a trailing 'Config' from the config
|
||||
class name and tries a few candidate modules where the device implementation is
|
||||
commonly located.
|
||||
"""
|
||||
if not isinstance(config, ChoiceRegistry):
|
||||
raise ValueError(f"Config should be an instance of `ChoiceRegistry`, got {type(config)}")
|
||||
|
||||
config_cls = config.__class__
|
||||
module_path = config_cls.__module__ # typical: lerobot_teleop_mydevice.config_mydevice
|
||||
config_name = config_cls.__name__ # typical: MyDeviceConfig
|
||||
|
||||
# Derive device class name (strip "Config")
|
||||
if not config_name.endswith("Config"):
|
||||
raise ValueError(f"Config class name '{config_name}' does not end with 'Config'")
|
||||
|
||||
device_class_name = config_name[:-6] # typical: MyDeviceConfig -> MyDevice
|
||||
|
||||
# Build candidate modules to search for the device class
|
||||
parts = module_path.split(".")
|
||||
parent_module = ".".join(parts[:-1]) if len(parts) > 1 else module_path
|
||||
candidates = [
|
||||
parent_module, # typical: lerobot_teleop_mydevice
|
||||
parent_module + "." + device_class_name.lower(), # typical: lerobot_teleop_mydevice.mydevice
|
||||
]
|
||||
|
||||
# handle modules named like "config_xxx" -> try replacing that piece with "xxx"
|
||||
last = parts[-1] if parts else ""
|
||||
if last.startswith("config_"):
|
||||
candidates.append(".".join(parts[:-1] + [last.replace("config_", "")]))
|
||||
|
||||
# de-duplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
candidates = [c for c in candidates if not (c in seen or seen.add(c))]
|
||||
|
||||
tried: list[str] = []
|
||||
for candidate in candidates:
|
||||
tried.append(candidate)
|
||||
try:
|
||||
module = importlib.import_module(candidate)
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
if hasattr(module, device_class_name):
|
||||
cls = getattr(module, device_class_name)
|
||||
if callable(cls):
|
||||
try:
|
||||
return cls(config)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
f"Failed to instantiate '{device_class_name}' from module '{candidate}': {e}"
|
||||
) from e
|
||||
|
||||
raise ImportError(
|
||||
f"Could not locate device class '{device_class_name}' for config '{config_name}'. "
|
||||
f"Tried modules: {tried}. Ensure your device class name is the config class name without "
|
||||
f"'Config' and that it's importable from one of those modules."
|
||||
)
|
||||
|
||||
|
||||
def register_third_party_devices() -> None:
|
||||
"""
|
||||
Discover and import third-party lerobot_* plugins so they can register themselves.
|
||||
|
||||
Scans top-level modules on sys.path for packages starting with
|
||||
'lerobot_robot_', 'lerobot_camera_' or 'lerobot_teleoperator_' and imports them.
|
||||
"""
|
||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_")
|
||||
imported: list[str] = []
|
||||
failed: list[str] = []
|
||||
|
||||
for module_info in pkgutil.iter_modules():
|
||||
name = module_info.name
|
||||
if name.startswith(prefixes):
|
||||
try:
|
||||
importlib.import_module(name)
|
||||
imported.append(name)
|
||||
logging.info("Imported third-party plugin: %s", name)
|
||||
except Exception:
|
||||
logging.exception("Could not import third-party plugin: %s", name)
|
||||
failed.append(name)
|
||||
|
||||
logging.debug("Third-party plugin import summary: imported=%s failed=%s", imported, failed)
|
||||
|
||||
Reference in New Issue
Block a user