From 16de8b3f19d7c70f908817d0b915e4ecedb75d6c Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 22 Jul 2025 01:08:39 +0200 Subject: [PATCH] Adding support for audio data recording and broadcasting for LeKiwi --- src/lerobot/datasets/compute_stats.py | 16 ++++++-- src/lerobot/datasets/lerobot_dataset.py | 19 ++++++--- src/lerobot/datasets/utils.py | 2 +- src/lerobot/robots/lekiwi/lekiwi_client.py | 47 +++++++++++++++------- src/lerobot/scripts/lerobot_record.py | 2 +- tests/datasets/test_compute_stats.py | 20 ++++++--- 6 files changed, 74 insertions(+), 32 deletions(-) diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index bba1206a5..30c4b9262 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -15,7 +15,7 @@ # limitations under the License. import numpy as np -from lerobot.datasets.utils import load_audio, load_image_as_numpy +from lerobot.datasets.utils import load_audio_from_path, load_image_as_numpy DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99] @@ -245,13 +245,18 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images -def sample_audio(audio_path: str) -> np.ndarray: - data = load_audio(audio_path) +def sample_audio_from_path(audio_path: str) -> np.ndarray: + data = load_audio_from_path(audio_path) sampled_indices = sample_indices(len(data)) return data[sampled_indices] +def sample_audio_from_data(data: np.ndarray) -> np.ndarray: + sampled_indices = sample_indices(len(data)) + return data[sampled_indices] + + def _reshape_stats_by_axis( stats: dict[str, np.ndarray], axis: int | tuple[int, ...] 
| None, @@ -520,7 +525,10 @@ def compute_episode_stats( axes_to_reduce = (0, 2, 3) keepdims = True elif features[key]["dtype"] == "audio": - ep_ft_array = sample_audio(data[0]) + try: + ep_ft_array = sample_audio_from_path(data[0]) + except TypeError: # Should only be triggered for LeKiwi robot + ep_ft_array = sample_audio_from_data(data) axes_to_reduce = 0 keepdims = True else: diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index aa1fa6a0a..c46ca3034 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1279,11 +1279,14 @@ class LeRobotDataset(torch.utils.data.Dataset): self._save_image(frame[key], img_path, compress_level) self.episode_buffer[key].append(str(img_path)) elif self.features[key]["dtype"] == "audio": - if frame_index == 0: - audio_path = self._get_raw_audio_file_path( - episode_index=self.episode_buffer["episode_index"], audio_key=key - ) - self.episode_buffer[key].append(str(audio_path)) + if self.meta.robot_type == "lekiwi": + self.episode_buffer[key].append(frame[key]) + else: + if frame_index == 0: + audio_path = self._get_raw_audio_file_path( + episode_index=self.episode_buffer["episode_index"], audio_key=key + ) + self.episode_buffer[key].append(str(audio_path)) else: self.episode_buffer[key].append(frame[key]) @@ -1347,7 +1350,11 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, ft in self.features.items(): # index, episode_index, task_index are already processed above, and image and video # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]: + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + continue + elif ft["dtype"] == "audio": + if self.meta.robot_type == "lekiwi": + episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0) continue episode_buffer[key] = 
np.stack(episode_buffer[key]) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 8ace0de48..ed2b77233 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -415,7 +415,7 @@ def load_image_as_numpy( return img_array -def load_audio(fpath: str | Path) -> np.ndarray: +def load_audio_from_path(fpath: str | Path) -> np.ndarray: audio_data, _ = read(fpath, dtype="float32") return audio_data diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 889a0e405..cda6f7c62 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -18,6 +18,7 @@ import base64 import json import logging from functools import cached_property +from time import perf_counter import cv2 import numpy as np @@ -58,8 +59,9 @@ class LeKiwiClient(Robot): self.zmq_observation_socket = None self.last_frames = {} - self.last_remote_state = {} + self.last_frame_timestamp = None + self.last_frame_delay = 0.0 # Define three speed levels and a current index self.speed_levels = [ @@ -139,6 +141,7 @@ class LeKiwiClient(Robot): if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN: raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.") + self.last_frame_timestamp = perf_counter() self._is_connected = True def calibrate(self) -> None: @@ -171,6 +174,8 @@ class LeKiwiClient(Robot): if last_msg is None: logging.warning("Poller indicated data, but failed to retrieve message.") + self.last_frame_delay = perf_counter() - self.last_frame_timestamp + self.last_frame_timestamp = perf_counter() return last_msg def _parse_observation_json(self, obs_string: str) -> RobotObservation | None: @@ -207,14 +212,16 @@ class LeKiwiClient(Robot): obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec} - # Decode images + # Decode images and audio data current_frames: dict[str, np.ndarray] = {} - for 
cam_name, image_b64 in observation.items(): - if cam_name not in self._cameras_ft: - continue - frame = self._decode_image_from_b64(image_b64) - if frame is not None: - current_frames[cam_name] = frame + for frame_name, frame_data in observation.items(): + if frame_name in self._cameras_ft: + image = self._decode_image_from_b64(frame_data) + if image is not None: + current_frames[frame_name] = image + elif frame_name in self._microphones_ft: + if frame_data is not None: + current_frames[frame_name] = frame_data return current_frames, obs_dict @@ -258,17 +265,29 @@ class LeKiwiClient(Robot): """ Capture observations from the remote robot: current follower arm positions, present wheel speeds (converted to body-frame velocities: x, y, theta), - and a camera frame. Receives over ZMQ, translate to body-frame vel + and camera and microphone data. Receives over ZMQ, translates to body-frame velocities """ frames, obs_dict = self._get_data() - # Loop over each configured camera - for cam_name, frame in frames.items(): - if frame is None: - logging.warning("Frame is None") - frame = np.zeros((640, 480, 3), dtype=np.uint8) - obs_dict[cam_name] = frame + # Forward every received camera/microphone frame; zero-fill only the missing ones + for frame_name, frame_data in frames.items(): + if frame_data is not None: + obs_dict[frame_name] = frame_data + continue + if frame_name in self._cameras_ft: + logging.warning("Image frame is None") + image = np.zeros((640, 480, 3), dtype=np.uint8) + obs_dict[frame_name] = image + elif frame_name in self._microphones_ft: + logging.warning("Audio frame is None") + obs_dict[frame_name] = np.zeros( + ( + int(self._microphones_ft[frame_name][0] * self.last_frame_delay), + self._microphones_ft[frame_name][1], + ), + dtype=np.float32, + ) return obs_dict diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 41d78172e..42bb7ccf9 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -313,7 +313,7 @@ def record_loop( preprocessor.reset() 
postprocessor.reset() - if dataset is not None: + if dataset is not None and robot.name != "lekiwi": for microphone_key, microphone in robot.microphones.items(): dataset.add_microphone_recording(microphone, microphone_key) else: diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 62531685b..1a18e14fa 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -26,7 +26,8 @@ from lerobot.datasets.compute_stats import ( compute_episode_stats, estimate_num_samples, get_feature_stats, - sample_audio, + sample_audio_from_data, + sample_audio_from_path, sample_images, sample_indices, ) @@ -78,10 +79,19 @@ def test_sample_images(mock_load): assert len(images) == estimate_num_samples(100) -@patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio) -def test_sample_audio(mock_load): +@patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio) +def test_sample_audio_from_path(mock_load): audio_path = "audio.wav" - audio_samples = sample_audio(audio_path) + audio_samples = sample_audio_from_path(audio_path) + assert isinstance(audio_samples, np.ndarray) + assert audio_samples.shape[1] == 2 + assert audio_samples.dtype == np.float32 + assert len(audio_samples) == estimate_num_samples(16000) + + +def test_sample_audio_from_data(): + audio_data = np.ones((16000, 2), dtype=np.float32) + audio_samples = sample_audio_from_data(audio_data) assert isinstance(audio_samples, np.ndarray) assert audio_samples.shape[1] == 2 assert audio_samples.dtype == np.float32 @@ -179,7 +189,7 @@ def test_compute_episode_stats(): with ( patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy), - patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio), + patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio), ): stats = compute_episode_stats(episode_data, features)