# Patch summary (review notes placed ahead of the first "diff --git" header, outside every hunk).
#
# * robots/config.py - inside class RobotConfig, directly after the existing camera check, adds a
#   check that raises ValueError when the config has a non-empty `microphones` mapping and any
#   entry has `sampling_rate` or `channels` set to None.
# * config_hope_jr.py (HopeJrHandConfig, HopeJrArmConfig), config_koch_follower.py
#   (KochFollowerConfig), config_lekiwi.py (LeKiwiConfig, LeKiwiClientConfig) and
#   config_so_follower.py (SOFollowerConfig) - import MicrophoneConfig and add a
#   `microphones: dict[str, MicrophoneConfig]` field whose default is an empty dict.
# * KochFollower, LeKiwi, SOFollower - build `self.microphones` with
#   make_microphones_from_configs(config.microphones); add a `_microphones_ft` property mapping
#   each microphone name to (sampling_rate, channels); merge it into `observation_features`;
#   require every microphone to be connected in `is_connected`; connect microphones after the
#   cameras in `connect`; and disconnect them after the cameras, just before the
#   "disconnected." log line.
# * LeKiwiClient - adds a cached `_microphones_ft` built from `self.config.microphones` and
#   merges it into `observation_features`.
#
# NOTE(review): no hunk shown here changes how observations are read, so confirm elsewhere that
# the microphone keys advertised by `observation_features` are actually produced.
# NOTE(review): line breaks, indentation and blank context lines below were restored from the
# hunk headers' line counts - verify against the target files before applying.
diff --git a/src/lerobot/robots/config.py b/src/lerobot/robots/config.py
index a85a83169..39799a45a 100644
--- a/src/lerobot/robots/config.py
+++ b/src/lerobot/robots/config.py
@@ -34,6 +34,13 @@ class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
                         raise ValueError(
                             f"Specifying '{attr}' is required for the camera to be used in a robot"
                         )
+        if hasattr(self, "microphones") and self.microphones:
+            for _, config in self.microphones.items():
+                for attr in ["sampling_rate", "channels"]:
+                    if getattr(config, attr) is None:
+                        raise ValueError(
+                            f"Specifying '{attr}' is required for the microphone to be used in a robot"
+                        )
 
     @property
     def type(self) -> str:
diff --git a/src/lerobot/robots/hope_jr/config_hope_jr.py b/src/lerobot/robots/hope_jr/config_hope_jr.py
index f2af5f47c..b55114739 100644
--- a/src/lerobot/robots/hope_jr/config_hope_jr.py
+++ b/src/lerobot/robots/hope_jr/config_hope_jr.py
@@ -17,6 +17,7 @@
 from dataclasses import dataclass, field
 
 from lerobot.cameras import CameraConfig
+from lerobot.microphones import MicrophoneConfig
 
 from ..config import RobotConfig
 
@@ -31,6 +32,8 @@ class HopeJrHandConfig(RobotConfig):
 
     cameras: dict[str, CameraConfig] = field(default_factory=dict)
 
+    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
+
     def __post_init__(self):
         super().__post_init__()
         if self.side not in ["right", "left"]:
@@ -49,3 +52,5 @@ class HopeJrArmConfig(RobotConfig):
     max_relative_target: float | dict[str, float] | None = None
 
     cameras: dict[str, CameraConfig] = field(default_factory=dict)
+
+    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
diff --git a/src/lerobot/robots/koch_follower/config_koch_follower.py b/src/lerobot/robots/koch_follower/config_koch_follower.py
index 02a95ef4e..7bd708800 100644
--- a/src/lerobot/robots/koch_follower/config_koch_follower.py
+++ b/src/lerobot/robots/koch_follower/config_koch_follower.py
@@ -15,6 +15,7 @@
 from dataclasses import dataclass, field
 
 from lerobot.cameras import CameraConfig
+from lerobot.microphones import MicrophoneConfig
 
 from ..config import RobotConfig
 
@@ -35,5 +36,8 @@ class KochFollowerConfig(RobotConfig):
     # cameras
     cameras: dict[str, CameraConfig] = field(default_factory=dict)
 
+    # microphones
+    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
+
     # Set to `True` for backward compatibility with previous policies/dataset
     use_degrees: bool = False
diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py
index fee0adba9..ece2a047b 100644
--- a/src/lerobot/robots/koch_follower/koch_follower.py
+++ b/src/lerobot/robots/koch_follower/koch_follower.py
@@ -19,6 +19,7 @@ import time
 from functools import cached_property
 
 from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.microphones.utils import make_microphones_from_configs
 from lerobot.motors import Motor, MotorCalibration, MotorNormMode
 from lerobot.motors.dynamixel import (
     DynamixelMotorsBus,
@@ -61,6 +62,7 @@ class KochFollower(Robot):
             calibration=self.calibration,
         )
         self.cameras = make_cameras_from_configs(config.cameras)
+        self.microphones = make_microphones_from_configs(config.microphones)
 
     @property
     def _motors_ft(self) -> dict[str, type]:
@@ -72,9 +74,16 @@ class KochFollower(Robot):
             cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
         }
 
+    @property
+    def _microphones_ft(self) -> dict[str, tuple]:
+        return {
+            mic: (self.config.microphones[mic].sampling_rate, self.config.microphones[mic].channels)
+            for mic in self.microphones
+        }
+
     @cached_property
     def observation_features(self) -> dict[str, type | tuple]:
-        return {**self._motors_ft, **self._cameras_ft}
+        return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
 
     @cached_property
     def action_features(self) -> dict[str, type]:
@@ -82,7 +91,11 @@ class KochFollower(Robot):
 
     @property
     def is_connected(self) -> bool:
-        return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
+        return (
+            self.bus.is_connected
+            and all(cam.is_connected for cam in self.cameras.values())
+            and all(mic.is_connected for mic in self.microphones.values())
+        )
 
     @check_if_already_connected
     def connect(self, calibrate: bool = True) -> None:
@@ -101,6 +114,9 @@ class KochFollower(Robot):
         for cam in self.cameras.values():
             cam.connect()
 
+        for mic in self.microphones.values():
+            mic.connect()
+
         self.configure()
         logger.info(f"{self} connected.")
 
@@ -232,5 +248,7 @@ class KochFollower(Robot):
         self.bus.disconnect(self.config.disable_torque_on_disconnect)
         for cam in self.cameras.values():
             cam.disconnect()
+        for mic in self.microphones.values():
+            mic.disconnect()
 
         logger.info(f"{self} disconnected.")
diff --git a/src/lerobot/robots/lekiwi/config_lekiwi.py b/src/lerobot/robots/lekiwi/config_lekiwi.py
index acaf5f0ec..fcfe03ff3 100644
--- a/src/lerobot/robots/lekiwi/config_lekiwi.py
+++ b/src/lerobot/robots/lekiwi/config_lekiwi.py
@@ -16,6 +16,7 @@ from dataclasses import dataclass, field
 
 from lerobot.cameras.configs import CameraConfig, Cv2Rotation
 from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.microphones import MicrophoneConfig
 
 from ..config import RobotConfig
 
@@ -45,6 +46,8 @@ class LeKiwiConfig(RobotConfig):
 
     cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
 
+    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
+
     # Set to `True` for backward compatibility with previous policies/dataset
     use_degrees: bool = False
 
@@ -92,5 +95,7 @@ class LeKiwiClientConfig(RobotConfig):
 
     cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
 
+    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
+
     polling_timeout_ms: int = 15
     connect_timeout_s: int = 5
diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py
index 54848f49d..52d269939 100644
--- a/src/lerobot/robots/lekiwi/lekiwi.py
+++ b/src/lerobot/robots/lekiwi/lekiwi.py
@@ -23,6 +23,7 @@ from typing import Any
 import numpy as np
 
 from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.microphones.utils import make_microphones_from_configs
 from lerobot.motors import Motor, MotorCalibration, MotorNormMode
 from lerobot.motors.feetech import (
     FeetechMotorsBus,
@@ -73,6 +74,7 @@ class LeKiwi(Robot):
         self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")]
         self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")]
         self.cameras = make_cameras_from_configs(config.cameras)
+        self.microphones = make_microphones_from_configs(config.microphones)
 
     @property
     def _state_ft(self) -> dict[str, type]:
@@ -97,9 +99,16 @@ class LeKiwi(Robot):
             cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
         }
 
+    @property
+    def _microphones_ft(self) -> dict[str, tuple]:
+        return {
+            mic: (self.config.microphones[mic].sampling_rate, self.config.microphones[mic].channels)
+            for mic in self.microphones
+        }
+
     @cached_property
     def observation_features(self) -> dict[str, type | tuple]:
-        return {**self._state_ft, **self._cameras_ft}
+        return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
 
     @cached_property
     def action_features(self) -> dict[str, type]:
@@ -107,7 +116,11 @@ class LeKiwi(Robot):
 
     @property
     def is_connected(self) -> bool:
-        return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
+        return (
+            self.bus.is_connected
+            and all(cam.is_connected for cam in self.cameras.values())
+            and all(mic.is_connected for mic in self.microphones.values())
+        )
 
     @check_if_already_connected
     def connect(self, calibrate: bool = True) -> None:
@@ -121,6 +134,9 @@ class LeKiwi(Robot):
         for cam in self.cameras.values():
             cam.connect()
 
+        for mic in self.microphones.values():
+            mic.connect()
+
         self.configure()
         logger.info(f"{self} connected.")
 
@@ -413,5 +429,7 @@ class LeKiwi(Robot):
         self.bus.disconnect(self.config.disable_torque_on_disconnect)
         for cam in self.cameras.values():
             cam.disconnect()
+        for mic in self.microphones.values():
+            mic.disconnect()
 
         logger.info(f"{self} disconnected.")
diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py
index 1d5ea64a6..43b925bcf 100644
--- a/src/lerobot/robots/lekiwi/lekiwi_client.py
+++ b/src/lerobot/robots/lekiwi/lekiwi_client.py
@@ -97,9 +97,13 @@ class LeKiwiClient(Robot):
     def _cameras_ft(self) -> dict[str, tuple[int, int, int]]:
         return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
 
+    @cached_property
+    def _microphones_ft(self) -> dict[str, tuple]:
+        return {name: (cfg.sampling_rate, cfg.channels) for name, cfg in self.config.microphones.items()}
+
     @cached_property
     def observation_features(self) -> dict[str, type | tuple]:
-        return {**self._state_ft, **self._cameras_ft}
+        return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
 
     @cached_property
     def action_features(self) -> dict[str, type]:
diff --git a/src/lerobot/robots/so_follower/config_so_follower.py b/src/lerobot/robots/so_follower/config_so_follower.py
index e9ce27123..3a83b8b00 100644
--- a/src/lerobot/robots/so_follower/config_so_follower.py
+++ b/src/lerobot/robots/so_follower/config_so_follower.py
@@ -18,6 +18,7 @@ from dataclasses import dataclass, field
 from typing import TypeAlias
 
 from lerobot.cameras import CameraConfig
+from lerobot.microphones import MicrophoneConfig
 
 from ..config import RobotConfig
 
@@ -39,6 +40,9 @@ class SOFollowerConfig:
     # cameras
     cameras: dict[str, CameraConfig] = field(default_factory=dict)
 
+    # microphones
+    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
+
     # Set to `True` for backward compatibility with previous policies/dataset
     use_degrees: bool = False
 
diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py
index b4d11fe3f..eb8d69498 100644
--- a/src/lerobot/robots/so_follower/so_follower.py
+++ b/src/lerobot/robots/so_follower/so_follower.py
@@ -20,6 +20,7 @@ from functools import cached_property
 from typing import TypeAlias
 
 from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.microphones.utils import make_microphones_from_configs
 from lerobot.motors import Motor, MotorCalibration, MotorNormMode
 from lerobot.motors.feetech import (
     FeetechMotorsBus,
@@ -62,6 +63,7 @@ class SOFollower(Robot):
             calibration=self.calibration,
         )
         self.cameras = make_cameras_from_configs(config.cameras)
+        self.microphones = make_microphones_from_configs(config.microphones)
 
     @property
     def _motors_ft(self) -> dict[str, type]:
@@ -73,9 +75,16 @@ class SOFollower(Robot):
             cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
         }
 
+    @property
+    def _microphones_ft(self) -> dict[str, tuple]:
+        return {
+            mic: (self.config.microphones[mic].sampling_rate, self.config.microphones[mic].channels)
+            for mic in self.microphones
+        }
+
     @cached_property
     def observation_features(self) -> dict[str, type | tuple]:
-        return {**self._motors_ft, **self._cameras_ft}
+        return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
 
     @cached_property
     def action_features(self) -> dict[str, type]:
@@ -83,7 +92,11 @@ class SOFollower(Robot):
 
     @property
     def is_connected(self) -> bool:
-        return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
+        return (
+            self.bus.is_connected
+            and all(cam.is_connected for cam in self.cameras.values())
+            and all(mic.is_connected for mic in self.microphones.values())
+        )
 
     @check_if_already_connected
     def connect(self, calibrate: bool = True) -> None:
@@ -102,6 +115,9 @@ class SOFollower(Robot):
         for cam in self.cameras.values():
             cam.connect()
 
+        for mic in self.microphones.values():
+            mic.connect()
+
         self.configure()
         logger.info(f"{self} connected.")
 
@@ -226,6 +242,8 @@ class SOFollower(Robot):
         self.bus.disconnect(self.config.disable_torque_on_disconnect)
         for cam in self.cameras.values():
             cam.disconnect()
+        for mic in self.microphones.values():
+            mic.disconnect()
 
         logger.info(f"{self} disconnected.")
 