Adding support for audio data recording and broadcasting for LeKiwi

This commit is contained in:
CarolinePascal
2025-07-22 01:08:39 +02:00
parent 580008663b
commit 16de8b3f19
6 changed files with 74 additions and 32 deletions
+12 -4
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import numpy as np
from lerobot.datasets.utils import load_audio, load_image_as_numpy
from lerobot.datasets.utils import load_audio_from_path, load_image_as_numpy
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
@@ -245,13 +245,18 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def sample_audio(audio_path: str) -> np.ndarray:
data = load_audio(audio_path)
def sample_audio_from_path(audio_path: str) -> np.ndarray:
data = load_audio_from_path(audio_path)
sampled_indices = sample_indices(len(data))
return data[sampled_indices]
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
sampled_indices = sample_indices(len(data))
return data[sampled_indices]
def _reshape_stats_by_axis(
stats: dict[str, np.ndarray],
axis: int | tuple[int, ...] | None,
@@ -520,7 +525,10 @@ def compute_episode_stats(
axes_to_reduce = (0, 2, 3)
keepdims = True
elif features[key]["dtype"] == "audio":
ep_ft_array = sample_audio(data[0])
try:
ep_ft_array = sample_audio_from_path(data[0])
except TypeError: # Should only be triggered for LeKiwi robot
ep_ft_array = sample_audio_from_data(data)
axes_to_reduce = 0
keepdims = True
else:
+13 -6
View File
@@ -1279,11 +1279,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
elif self.features[key]["dtype"] == "audio":
if frame_index == 0:
audio_path = self._get_raw_audio_file_path(
episode_index=self.episode_buffer["episode_index"], audio_key=key
)
self.episode_buffer[key].append(str(audio_path))
if self.meta.robot_type == "lekiwi":
self.episode_buffer[key].append(frame[key])
else:
if frame_index == 0:
audio_path = self._get_raw_audio_file_path(
episode_index=self.episode_buffer["episode_index"], audio_key=key
)
self.episode_buffer[key].append(str(audio_path))
else:
self.episode_buffer[key].append(frame[key])
@@ -1347,7 +1350,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key, ft in self.features.items():
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif ft["dtype"] == "audio":
if self.meta.robot_type == "lekiwi":
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
continue
episode_buffer[key] = np.stack(episode_buffer[key])
+1 -1
View File
@@ -415,7 +415,7 @@ def load_image_as_numpy(
return img_array
def load_audio(fpath: str | Path) -> np.ndarray:
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
audio_data, _ = read(fpath, dtype="float32")
return audio_data
+32 -15
View File
@@ -18,6 +18,7 @@ import base64
import json
import logging
from functools import cached_property
from time import perf_counter
import cv2
import numpy as np
@@ -58,8 +59,9 @@ class LeKiwiClient(Robot):
self.zmq_observation_socket = None
self.last_frames = {}
self.last_remote_state = {}
self.last_frame_timestamp = None
self.last_frame_delay = 0.0
# Define three speed levels and a current index
self.speed_levels = [
@@ -139,6 +141,7 @@ class LeKiwiClient(Robot):
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
self.last_frame_timestamp = perf_counter()
self._is_connected = True
def calibrate(self) -> None:
@@ -171,6 +174,8 @@ class LeKiwiClient(Robot):
if last_msg is None:
logging.warning("Poller indicated data, but failed to retrieve message.")
self.last_frame_delay = perf_counter() - self.last_frame_timestamp
self.last_frame_timestamp = perf_counter()
return last_msg
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
@@ -207,14 +212,16 @@ class LeKiwiClient(Robot):
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
# Decode images
# Decode images and audio data
current_frames: dict[str, np.ndarray] = {}
for cam_name, image_b64 in observation.items():
if cam_name not in self._cameras_ft:
continue
frame = self._decode_image_from_b64(image_b64)
if frame is not None:
current_frames[cam_name] = frame
for frame_name, frame_data in observation.items():
if frame_name in self._cameras_ft:
image = self._decode_image_from_b64(frame_data)
if image is not None:
current_frames[frame_name] = image
elif frame_name in self._microphones_ft:
if frame_data is not None:
current_frames[frame_name] = frame_data
return current_frames, obs_dict
@@ -258,17 +265,27 @@ class LeKiwiClient(Robot):
"""
Capture observations from the remote robot: current follower arm positions,
present wheel speeds (converted to body-frame velocities: x, y, theta),
and a camera frame. Receives over ZMQ, translate to body-frame vel
and cameras and microphones data. Receives over ZMQ, translate to body-frame vel
"""
frames, obs_dict = self._get_data()
# Loop over each configured camera
for cam_name, frame in frames.items():
if frame is None:
logging.warning("Frame is None")
frame = np.zeros((640, 480, 3), dtype=np.uint8)
obs_dict[cam_name] = frame
# Loop over each configured camera and microphone
for frame_name, frame_data in frames.items():
if frame_data is None:
if frame_name in self._cameras_ft:
logging.warning("Image frame is None")
image = np.zeros((640, 480, 3), dtype=np.uint8)
obs_dict[frame_name] = image
elif frame_name in self._microphones_ft:
logging.warning("Audio frame is None")
obs_dict[frame_name] = np.zeros(
(
int(self._microphones_ft[frame_name][0] * self.last_frame_delay),
self._microphones_ft[frame_name][1],
),
dtype=np.float32,
)
return obs_dict
+1 -1
View File
@@ -313,7 +313,7 @@ def record_loop(
preprocessor.reset()
postprocessor.reset()
if dataset is not None:
if dataset is not None and robot.name != "lekiwi":
for microphone_key, microphone in robot.microphones.items():
dataset.add_microphone_recording(microphone, microphone_key)
else:
+15 -5
View File
@@ -26,7 +26,8 @@ from lerobot.datasets.compute_stats import (
compute_episode_stats,
estimate_num_samples,
get_feature_stats,
sample_audio,
sample_audio_from_data,
sample_audio_from_path,
sample_images,
sample_indices,
)
@@ -78,10 +79,19 @@ def test_sample_images(mock_load):
assert len(images) == estimate_num_samples(100)
@patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio)
def test_sample_audio(mock_load):
@patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
def test_sample_audio_from_path(mock_load):
audio_path = "audio.wav"
audio_samples = sample_audio(audio_path)
audio_samples = sample_audio_from_path(audio_path)
assert isinstance(audio_samples, np.ndarray)
assert audio_samples.shape[1] == 2
assert audio_samples.dtype == np.float32
assert len(audio_samples) == estimate_num_samples(16000)
def test_sample_audio_from_data(mock_load):
audio_data = np.ones((16000, 2), dtype=np.float32)
audio_samples = sample_audio_from_data(audio_data)
assert isinstance(audio_samples, np.ndarray)
assert audio_samples.shape[1] == 2
assert audio_samples.dtype == np.float32
@@ -179,7 +189,7 @@ def test_compute_episode_stats():
with (
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio),
patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
):
stats = compute_episode_stats(episode_data, features)