mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Adding support for audio data recording and broadcasting for LeKiwi
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_audio, load_image_as_numpy
|
||||
from lerobot.datasets.utils import load_audio_from_path, load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
@@ -245,13 +245,18 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def sample_audio(audio_path: str) -> np.ndarray:
|
||||
data = load_audio(audio_path)
|
||||
def sample_audio_from_path(audio_path: str) -> np.ndarray:
    """Load the audio file at ``audio_path`` and return a subsample of it.

    The recording is read with ``load_audio_from_path``; the positions to keep
    are chosen by ``sample_indices`` from the length of the loaded array.

    Args:
        audio_path: Location of the audio file to read.

    Returns:
        The loaded audio indexed by the sampled positions.
    """
    waveform = load_audio_from_path(audio_path)
    keep = sample_indices(len(waveform))
    return waveform[keep]
|
||||
|
||||
|
||||
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
    """Return a subsample of audio that is already held in memory.

    Args:
        data: Audio array to subsample; its length is passed to
            ``sample_indices`` to pick the positions to keep.

    Returns:
        ``data`` indexed by the sampled positions.
    """
    return data[sample_indices(len(data))]
|
||||
|
||||
|
||||
def _reshape_stats_by_axis(
|
||||
stats: dict[str, np.ndarray],
|
||||
axis: int | tuple[int, ...] | None,
|
||||
@@ -520,7 +525,10 @@ def compute_episode_stats(
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
elif features[key]["dtype"] == "audio":
|
||||
ep_ft_array = sample_audio(data[0])
|
||||
try:
|
||||
ep_ft_array = sample_audio_from_path(data[0])
|
||||
except TypeError: # Should only be triggered for LeKiwi robot
|
||||
ep_ft_array = sample_audio_from_data(data)
|
||||
axes_to_reduce = 0
|
||||
keepdims = True
|
||||
else:
|
||||
|
||||
@@ -1279,11 +1279,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
elif self.features[key]["dtype"] == "audio":
|
||||
if frame_index == 0:
|
||||
audio_path = self._get_raw_audio_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
||||
)
|
||||
self.episode_buffer[key].append(str(audio_path))
|
||||
if self.meta.robot_type == "lekiwi":
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
else:
|
||||
if frame_index == 0:
|
||||
audio_path = self._get_raw_audio_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
||||
)
|
||||
self.episode_buffer[key].append(str(audio_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
@@ -1347,7 +1350,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for key, ft in self.features.items():
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif ft["dtype"] == "audio":
|
||||
if self.meta.robot_type == "lekiwi":
|
||||
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
|
||||
@@ -415,7 +415,7 @@ def load_image_as_numpy(
|
||||
return img_array
|
||||
|
||||
|
||||
def load_audio(fpath: str | Path) -> np.ndarray:
|
||||
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
    """Read an audio file from disk as ``float32`` data.

    Args:
        fpath: Path of the audio file to read.

    Returns:
        The first element of the pair returned by ``read``.
    """
    # `read` yields a pair; only its first element (the audio data) is returned.
    # NOTE(review): presumably this is `soundfile.read` and the discarded value is
    # the sample rate -- confirm against the module imports.
    samples, _discarded = read(fpath, dtype="float32")
    return samples
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from time import perf_counter
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -58,8 +59,9 @@ class LeKiwiClient(Robot):
|
||||
self.zmq_observation_socket = None
|
||||
|
||||
self.last_frames = {}
|
||||
|
||||
self.last_remote_state = {}
|
||||
self.last_frame_timestamp = None
|
||||
self.last_frame_delay = 0.0
|
||||
|
||||
# Define three speed levels and a current index
|
||||
self.speed_levels = [
|
||||
@@ -139,6 +141,7 @@ class LeKiwiClient(Robot):
|
||||
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
|
||||
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
|
||||
|
||||
self.last_frame_timestamp = perf_counter()
|
||||
self._is_connected = True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
@@ -171,6 +174,8 @@ class LeKiwiClient(Robot):
|
||||
if last_msg is None:
|
||||
logging.warning("Poller indicated data, but failed to retrieve message.")
|
||||
|
||||
self.last_frame_delay = perf_counter() - self.last_frame_timestamp
|
||||
self.last_frame_timestamp = perf_counter()
|
||||
return last_msg
|
||||
|
||||
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
|
||||
@@ -207,14 +212,16 @@ class LeKiwiClient(Robot):
|
||||
|
||||
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
|
||||
|
||||
# Decode images
|
||||
# Decode images and audio data
|
||||
current_frames: dict[str, np.ndarray] = {}
|
||||
for cam_name, image_b64 in observation.items():
|
||||
if cam_name not in self._cameras_ft:
|
||||
continue
|
||||
frame = self._decode_image_from_b64(image_b64)
|
||||
if frame is not None:
|
||||
current_frames[cam_name] = frame
|
||||
for frame_name, frame_data in observation.items():
|
||||
if frame_name in self._cameras_ft:
|
||||
image = self._decode_image_from_b64(frame_data)
|
||||
if image is not None:
|
||||
current_frames[frame_name] = image
|
||||
elif frame_name in self._microphones_ft:
|
||||
if frame_data is not None:
|
||||
current_frames[frame_name] = frame_data
|
||||
|
||||
return current_frames, obs_dict
|
||||
|
||||
@@ -258,17 +265,27 @@ class LeKiwiClient(Robot):
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
and a camera frame. Receives over ZMQ, translate to body-frame vel
|
||||
and camera and microphone data. Receives over ZMQ and translates to body-frame velocities.
|
||||
"""
|
||||
|
||||
frames, obs_dict = self._get_data()
|
||||
|
||||
# Loop over each configured camera
|
||||
for cam_name, frame in frames.items():
|
||||
if frame is None:
|
||||
logging.warning("Frame is None")
|
||||
frame = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||
obs_dict[cam_name] = frame
|
||||
# Loop over each configured camera and microphone
|
||||
for frame_name, frame_data in frames.items():
|
||||
if frame_data is None:
|
||||
if frame_name in self._cameras_ft:
|
||||
logging.warning("Image frame is None")
|
||||
image = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||
obs_dict[frame_name] = image
|
||||
elif frame_name in self._microphones_ft:
|
||||
logging.warning("Audio frame is None")
|
||||
obs_dict[frame_name] = np.zeros(
|
||||
(
|
||||
int(self._microphones_ft[frame_name][0] * self.last_frame_delay),
|
||||
self._microphones_ft[frame_name][1],
|
||||
),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@@ -313,7 +313,7 @@ def record_loop(
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
if dataset is not None:
|
||||
if dataset is not None and robot.name != "lekiwi":
|
||||
for microphone_key, microphone in robot.microphones.items():
|
||||
dataset.add_microphone_recording(microphone, microphone_key)
|
||||
else:
|
||||
|
||||
@@ -26,7 +26,8 @@ from lerobot.datasets.compute_stats import (
|
||||
compute_episode_stats,
|
||||
estimate_num_samples,
|
||||
get_feature_stats,
|
||||
sample_audio,
|
||||
sample_audio_from_data,
|
||||
sample_audio_from_path,
|
||||
sample_images,
|
||||
sample_indices,
|
||||
)
|
||||
@@ -78,10 +79,19 @@ def test_sample_images(mock_load):
|
||||
assert len(images) == estimate_num_samples(100)
|
||||
|
||||
|
||||
@patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio)
|
||||
def test_sample_audio(mock_load):
|
||||
@patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
def test_sample_audio_from_path(mock_load):
    """Check that audio loaded from a path is subsampled to the expected size.

    ``load_audio_from_path`` is patched with ``mock_load_audio`` so no file is
    read from disk.
    """
    audio_path = "audio.wav"
    audio_samples = sample_audio_from_path(audio_path)

    # The patched loader must have been the source of the data, and must have
    # received the path unchanged.
    mock_load.assert_called_once_with(audio_path)

    assert isinstance(audio_samples, np.ndarray)
    assert audio_samples.shape[1] == 2
    assert audio_samples.dtype == np.float32
    assert len(audio_samples) == estimate_num_samples(16000)
|
||||
|
||||
|
||||
def test_sample_audio_from_data():
    """Check that in-memory audio is subsampled without changing type or layout.

    No loader is involved on this path, so nothing is patched and the test takes
    no arguments (an undeclared argument would be looked up as a pytest fixture).
    """
    audio_data = np.ones((16000, 2), dtype=np.float32)
    audio_samples = sample_audio_from_data(audio_data)
    assert isinstance(audio_samples, np.ndarray)
    assert audio_samples.shape[1] == 2
    assert audio_samples.dtype == np.float32
|
||||
@@ -179,7 +189,7 @@ def test_compute_episode_stats():
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
|
||||
patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio),
|
||||
patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
|
||||
):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user