Integrate microphones in Robot class

This commit is contained in:
CarolinePascal
2025-03-28 17:15:19 +01:00
parent 7e5f3b35e9
commit d998660aa1
9 changed files with 90 additions and 7 deletions
+7
View File
@@ -34,6 +34,13 @@ class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
raise ValueError( raise ValueError(
f"Specifying '{attr}' is required for the camera to be used in a robot" f"Specifying '{attr}' is required for the camera to be used in a robot"
) )
if hasattr(self, "microphones") and self.microphones:
for _, config in self.microphones.items():
for attr in ["sampling_rate", "channels"]:
if getattr(config, attr) is None:
raise ValueError(
f"Specifying '{attr}' is required for the microphone to be used in a robot"
)
@property @property
def type(self) -> str: def type(self) -> str:
@@ -17,6 +17,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig from lerobot.cameras import CameraConfig
from lerobot.microphones import MicrophoneConfig
from ..config import RobotConfig from ..config import RobotConfig
@@ -31,6 +32,8 @@ class HopeJrHandConfig(RobotConfig):
cameras: dict[str, CameraConfig] = field(default_factory=dict) cameras: dict[str, CameraConfig] = field(default_factory=dict)
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
if self.side not in ["right", "left"]: if self.side not in ["right", "left"]:
@@ -49,3 +52,5 @@ class HopeJrArmConfig(RobotConfig):
max_relative_target: float | dict[str, float] | None = None max_relative_target: float | dict[str, float] | None = None
cameras: dict[str, CameraConfig] = field(default_factory=dict) cameras: dict[str, CameraConfig] = field(default_factory=dict)
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
@@ -15,6 +15,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig from lerobot.cameras import CameraConfig
from lerobot.microphones import MicrophoneConfig
from ..config import RobotConfig from ..config import RobotConfig
@@ -35,5 +36,8 @@ class KochFollowerConfig(RobotConfig):
# cameras # cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict) cameras: dict[str, CameraConfig] = field(default_factory=dict)
# microphones
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset # Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False use_degrees: bool = False
@@ -19,6 +19,7 @@ import time
from functools import cached_property from functools import cached_property
from lerobot.cameras.utils import make_cameras_from_configs from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import ( from lerobot.motors.dynamixel import (
DynamixelMotorsBus, DynamixelMotorsBus,
@@ -61,6 +62,7 @@ class KochFollower(Robot):
calibration=self.calibration, calibration=self.calibration,
) )
self.cameras = make_cameras_from_configs(config.cameras) self.cameras = make_cameras_from_configs(config.cameras)
self.microphones = make_microphones_from_configs(config.microphones)
@property @property
def _motors_ft(self) -> dict[str, type]: def _motors_ft(self) -> dict[str, type]:
@@ -72,9 +74,16 @@ class KochFollower(Robot):
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
} }
@property
def _microphones_ft(self) -> dict[str, tuple]:
    """Build the per-microphone feature description.

    Returns:
        A dict keyed by each name in ``self.microphones``. Every value is the
        ``(sampling_rate, channels)`` pair read from the entry of
        ``self.config.microphones`` stored under the same name.
    """
    features: dict[str, tuple] = {}
    for name in self.microphones:
        mic_config = self.config.microphones[name]
        features[name] = (mic_config.sampling_rate, mic_config.channels)
    return features
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft} return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
@cached_property @cached_property
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
@@ -82,7 +91,11 @@ class KochFollower(Robot):
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) return (
self.bus.is_connected
and all(cam.is_connected for cam in self.cameras.values())
and all(mic.is_connected for mic in self.microphones.values())
)
@check_if_already_connected @check_if_already_connected
def connect(self, calibrate: bool = True) -> None: def connect(self, calibrate: bool = True) -> None:
@@ -101,6 +114,9 @@ class KochFollower(Robot):
for cam in self.cameras.values(): for cam in self.cameras.values():
cam.connect() cam.connect()
for mic in self.microphones.values():
mic.connect()
self.configure() self.configure()
logger.info(f"{self} connected.") logger.info(f"{self} connected.")
@@ -232,5 +248,7 @@ class KochFollower(Robot):
self.bus.disconnect(self.config.disable_torque_on_disconnect) self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values(): for cam in self.cameras.values():
cam.disconnect() cam.disconnect()
for mic in self.microphones.values():
mic.disconnect()
logger.info(f"{self} disconnected.") logger.info(f"{self} disconnected.")
@@ -16,6 +16,7 @@ from dataclasses import dataclass, field
from lerobot.cameras.configs import CameraConfig, Cv2Rotation from lerobot.cameras.configs import CameraConfig, Cv2Rotation
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.microphones import MicrophoneConfig
from ..config import RobotConfig from ..config import RobotConfig
@@ -45,6 +46,8 @@ class LeKiwiConfig(RobotConfig):
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config) cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset # Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False use_degrees: bool = False
@@ -92,5 +95,7 @@ class LeKiwiClientConfig(RobotConfig):
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config) cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
polling_timeout_ms: int = 15 polling_timeout_ms: int = 15
connect_timeout_s: int = 5 connect_timeout_s: int = 5
+20 -2
View File
@@ -23,6 +23,7 @@ from typing import Any
import numpy as np import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import ( from lerobot.motors.feetech import (
FeetechMotorsBus, FeetechMotorsBus,
@@ -73,6 +74,7 @@ class LeKiwi(Robot):
self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")] self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")]
self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")] self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")]
self.cameras = make_cameras_from_configs(config.cameras) self.cameras = make_cameras_from_configs(config.cameras)
self.microphones = make_microphones_from_configs(config.microphones)
@property @property
def _state_ft(self) -> dict[str, type]: def _state_ft(self) -> dict[str, type]:
@@ -97,9 +99,16 @@ class LeKiwi(Robot):
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
} }
@property
def _microphones_ft(self) -> dict[str, tuple]:
    """Build the per-microphone feature description.

    Returns:
        A dict keyed by each name in ``self.microphones``. Every value is the
        ``(sampling_rate, channels)`` pair read from the entry of
        ``self.config.microphones`` stored under the same name.
    """
    features: dict[str, tuple] = {}
    for name in self.microphones:
        mic_config = self.config.microphones[name]
        features[name] = (mic_config.sampling_rate, mic_config.channels)
    return features
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
return {**self._state_ft, **self._cameras_ft} return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
@cached_property @cached_property
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
@@ -107,7 +116,11 @@ class LeKiwi(Robot):
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) return (
self.bus.is_connected
and all(cam.is_connected for cam in self.cameras.values())
and all(mic.is_connected for mic in self.microphones.values())
)
@check_if_already_connected @check_if_already_connected
def connect(self, calibrate: bool = True) -> None: def connect(self, calibrate: bool = True) -> None:
@@ -121,6 +134,9 @@ class LeKiwi(Robot):
for cam in self.cameras.values(): for cam in self.cameras.values():
cam.connect() cam.connect()
for mic in self.microphones.values():
mic.connect()
self.configure() self.configure()
logger.info(f"{self} connected.") logger.info(f"{self} connected.")
@@ -413,5 +429,7 @@ class LeKiwi(Robot):
self.bus.disconnect(self.config.disable_torque_on_disconnect) self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values(): for cam in self.cameras.values():
cam.disconnect() cam.disconnect()
for mic in self.microphones.values():
mic.disconnect()
logger.info(f"{self} disconnected.") logger.info(f"{self} disconnected.")
+5 -1
View File
@@ -97,9 +97,13 @@ class LeKiwiClient(Robot):
def _cameras_ft(self) -> dict[str, tuple[int, int, int]]: def _cameras_ft(self) -> dict[str, tuple[int, int, int]]:
return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()} return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
@cached_property
def _microphones_ft(self) -> dict[str, tuple]:
    """Build the per-microphone feature description from the configuration.

    Returns:
        A dict with one entry per item of ``self.config.microphones``, keyed
        by the microphone name and holding that config's
        ``(sampling_rate, channels)`` pair.
    """
    features: dict[str, tuple] = {}
    for mic_name, mic_config in self.config.microphones.items():
        features[mic_name] = (mic_config.sampling_rate, mic_config.channels)
    return features
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
return {**self._state_ft, **self._cameras_ft} return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
@cached_property @cached_property
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
@@ -18,6 +18,7 @@ from dataclasses import dataclass, field
from typing import TypeAlias from typing import TypeAlias
from lerobot.cameras import CameraConfig from lerobot.cameras import CameraConfig
from lerobot.microphones import MicrophoneConfig
from ..config import RobotConfig from ..config import RobotConfig
@@ -39,6 +40,9 @@ class SOFollowerConfig:
# cameras # cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict) cameras: dict[str, CameraConfig] = field(default_factory=dict)
# microphones
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset # Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False use_degrees: bool = False
+20 -2
View File
@@ -20,6 +20,7 @@ from functools import cached_property
from typing import TypeAlias from typing import TypeAlias
from lerobot.cameras.utils import make_cameras_from_configs from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import ( from lerobot.motors.feetech import (
FeetechMotorsBus, FeetechMotorsBus,
@@ -62,6 +63,7 @@ class SOFollower(Robot):
calibration=self.calibration, calibration=self.calibration,
) )
self.cameras = make_cameras_from_configs(config.cameras) self.cameras = make_cameras_from_configs(config.cameras)
self.microphones = make_microphones_from_configs(config.microphones)
@property @property
def _motors_ft(self) -> dict[str, type]: def _motors_ft(self) -> dict[str, type]:
@@ -73,9 +75,16 @@ class SOFollower(Robot):
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
} }
@property
def _microphones_ft(self) -> dict[str, tuple]:
    """Build the per-microphone feature description.

    Returns:
        A dict keyed by each name in ``self.microphones``. Every value is the
        ``(sampling_rate, channels)`` pair read from the entry of
        ``self.config.microphones`` stored under the same name.
    """
    features: dict[str, tuple] = {}
    for name in self.microphones:
        mic_config = self.config.microphones[name]
        features[name] = (mic_config.sampling_rate, mic_config.channels)
    return features
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft} return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
@cached_property @cached_property
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
@@ -83,7 +92,11 @@ class SOFollower(Robot):
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) return (
self.bus.is_connected
and all(cam.is_connected for cam in self.cameras.values())
and all(mic.is_connected for mic in self.microphones.values())
)
@check_if_already_connected @check_if_already_connected
def connect(self, calibrate: bool = True) -> None: def connect(self, calibrate: bool = True) -> None:
@@ -102,6 +115,9 @@ class SOFollower(Robot):
for cam in self.cameras.values(): for cam in self.cameras.values():
cam.connect() cam.connect()
for mic in self.microphones.values():
mic.connect()
self.configure() self.configure()
logger.info(f"{self} connected.") logger.info(f"{self} connected.")
@@ -226,6 +242,8 @@ class SOFollower(Robot):
self.bus.disconnect(self.config.disable_torque_on_disconnect) self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values(): for cam in self.cameras.values():
cam.disconnect() cam.disconnect()
for mic in self.microphones.values():
mic.disconnect()
logger.info(f"{self} disconnected.") logger.info(f"{self} disconnected.")