mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
Adding support for audio data recording and broadcasting for LeKiwi
This commit is contained in:
@@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.datasets.utils import load_audio, load_image_as_numpy
|
from lerobot.datasets.utils import load_audio_from_path, load_image_as_numpy
|
||||||
|
|
||||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||||
|
|
||||||
@@ -245,13 +245,18 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
def sample_audio(audio_path: str) -> np.ndarray:
|
def sample_audio_from_path(audio_path: str) -> np.ndarray:
|
||||||
data = load_audio(audio_path)
|
data = load_audio_from_path(audio_path)
|
||||||
sampled_indices = sample_indices(len(data))
|
sampled_indices = sample_indices(len(data))
|
||||||
|
|
||||||
return data[sampled_indices]
|
return data[sampled_indices]
|
||||||
|
|
||||||
|
|
||||||
|
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
|
||||||
|
sampled_indices = sample_indices(len(data))
|
||||||
|
return data[sampled_indices]
|
||||||
|
|
||||||
|
|
||||||
def _reshape_stats_by_axis(
|
def _reshape_stats_by_axis(
|
||||||
stats: dict[str, np.ndarray],
|
stats: dict[str, np.ndarray],
|
||||||
axis: int | tuple[int, ...] | None,
|
axis: int | tuple[int, ...] | None,
|
||||||
@@ -520,7 +525,10 @@ def compute_episode_stats(
|
|||||||
axes_to_reduce = (0, 2, 3)
|
axes_to_reduce = (0, 2, 3)
|
||||||
keepdims = True
|
keepdims = True
|
||||||
elif features[key]["dtype"] == "audio":
|
elif features[key]["dtype"] == "audio":
|
||||||
ep_ft_array = sample_audio(data[0])
|
try:
|
||||||
|
ep_ft_array = sample_audio_from_path(data[0])
|
||||||
|
except TypeError: # Should only be triggered for LeKiwi robot
|
||||||
|
ep_ft_array = sample_audio_from_data(data)
|
||||||
axes_to_reduce = 0
|
axes_to_reduce = 0
|
||||||
keepdims = True
|
keepdims = True
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1279,11 +1279,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self._save_image(frame[key], img_path, compress_level)
|
self._save_image(frame[key], img_path, compress_level)
|
||||||
self.episode_buffer[key].append(str(img_path))
|
self.episode_buffer[key].append(str(img_path))
|
||||||
elif self.features[key]["dtype"] == "audio":
|
elif self.features[key]["dtype"] == "audio":
|
||||||
if frame_index == 0:
|
if self.meta.robot_type == "lekiwi":
|
||||||
audio_path = self._get_raw_audio_file_path(
|
self.episode_buffer[key].append(frame[key])
|
||||||
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
else:
|
||||||
)
|
if frame_index == 0:
|
||||||
self.episode_buffer[key].append(str(audio_path))
|
audio_path = self._get_raw_audio_file_path(
|
||||||
|
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
||||||
|
)
|
||||||
|
self.episode_buffer[key].append(str(audio_path))
|
||||||
else:
|
else:
|
||||||
self.episode_buffer[key].append(frame[key])
|
self.episode_buffer[key].append(frame[key])
|
||||||
|
|
||||||
@@ -1347,7 +1350,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
# index, episode_index, task_index are already processed above, and image and video
|
# index, episode_index, task_index are already processed above, and image and video
|
||||||
# are processed separately by storing image path and frame info as meta data
|
# are processed separately by storing image path and frame info as meta data
|
||||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
|
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||||
|
continue
|
||||||
|
elif ft["dtype"] == "audio":
|
||||||
|
if self.meta.robot_type == "lekiwi":
|
||||||
|
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
|
||||||
continue
|
continue
|
||||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||||
|
|
||||||
|
|||||||
@@ -415,7 +415,7 @@ def load_image_as_numpy(
|
|||||||
return img_array
|
return img_array
|
||||||
|
|
||||||
|
|
||||||
def load_audio(fpath: str | Path) -> np.ndarray:
|
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
|
||||||
audio_data, _ = read(fpath, dtype="float32")
|
audio_data, _ = read(fpath, dtype="float32")
|
||||||
return audio_data
|
return audio_data
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
from time import perf_counter
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -58,8 +59,9 @@ class LeKiwiClient(Robot):
|
|||||||
self.zmq_observation_socket = None
|
self.zmq_observation_socket = None
|
||||||
|
|
||||||
self.last_frames = {}
|
self.last_frames = {}
|
||||||
|
|
||||||
self.last_remote_state = {}
|
self.last_remote_state = {}
|
||||||
|
self.last_frame_timestamp = None
|
||||||
|
self.last_frame_delay = 0.0
|
||||||
|
|
||||||
# Define three speed levels and a current index
|
# Define three speed levels and a current index
|
||||||
self.speed_levels = [
|
self.speed_levels = [
|
||||||
@@ -139,6 +141,7 @@ class LeKiwiClient(Robot):
|
|||||||
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
|
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
|
||||||
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
|
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
|
||||||
|
|
||||||
|
self.last_frame_timestamp = perf_counter()
|
||||||
self._is_connected = True
|
self._is_connected = True
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
def calibrate(self) -> None:
|
||||||
@@ -171,6 +174,8 @@ class LeKiwiClient(Robot):
|
|||||||
if last_msg is None:
|
if last_msg is None:
|
||||||
logging.warning("Poller indicated data, but failed to retrieve message.")
|
logging.warning("Poller indicated data, but failed to retrieve message.")
|
||||||
|
|
||||||
|
self.last_frame_delay = perf_counter() - self.last_frame_timestamp
|
||||||
|
self.last_frame_timestamp = perf_counter()
|
||||||
return last_msg
|
return last_msg
|
||||||
|
|
||||||
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
|
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
|
||||||
@@ -207,14 +212,16 @@ class LeKiwiClient(Robot):
|
|||||||
|
|
||||||
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
|
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
|
||||||
|
|
||||||
# Decode images
|
# Decode images and audio data
|
||||||
current_frames: dict[str, np.ndarray] = {}
|
current_frames: dict[str, np.ndarray] = {}
|
||||||
for cam_name, image_b64 in observation.items():
|
for frame_name, frame_data in observation.items():
|
||||||
if cam_name not in self._cameras_ft:
|
if frame_name in self._cameras_ft:
|
||||||
continue
|
image = self._decode_image_from_b64(frame_data)
|
||||||
frame = self._decode_image_from_b64(image_b64)
|
if image is not None:
|
||||||
if frame is not None:
|
current_frames[frame_name] = image
|
||||||
current_frames[cam_name] = frame
|
elif frame_name in self._microphones_ft:
|
||||||
|
if frame_data is not None:
|
||||||
|
current_frames[frame_name] = frame_data
|
||||||
|
|
||||||
return current_frames, obs_dict
|
return current_frames, obs_dict
|
||||||
|
|
||||||
@@ -258,17 +265,27 @@ class LeKiwiClient(Robot):
|
|||||||
"""
|
"""
|
||||||
Capture observations from the remote robot: current follower arm positions,
|
Capture observations from the remote robot: current follower arm positions,
|
||||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||||
and a camera frame. Receives over ZMQ, translate to body-frame vel
|
and cameras and microphones data. Receives over ZMQ, translate to body-frame vel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frames, obs_dict = self._get_data()
|
frames, obs_dict = self._get_data()
|
||||||
|
|
||||||
# Loop over each configured camera
|
# Loop over each configured camera and microphone
|
||||||
for cam_name, frame in frames.items():
|
for frame_name, frame_data in frames.items():
|
||||||
if frame is None:
|
if frame_data is None:
|
||||||
logging.warning("Frame is None")
|
if frame_name in self._cameras_ft:
|
||||||
frame = np.zeros((640, 480, 3), dtype=np.uint8)
|
logging.warning("Image frame is None")
|
||||||
obs_dict[cam_name] = frame
|
image = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||||
|
obs_dict[frame_name] = image
|
||||||
|
elif frame_name in self._microphones_ft:
|
||||||
|
logging.warning("Audio frame is None")
|
||||||
|
obs_dict[frame_name] = np.zeros(
|
||||||
|
(
|
||||||
|
int(self._microphones_ft[frame_name][0] * self.last_frame_delay),
|
||||||
|
self._microphones_ft[frame_name][1],
|
||||||
|
),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ def record_loop(
|
|||||||
preprocessor.reset()
|
preprocessor.reset()
|
||||||
postprocessor.reset()
|
postprocessor.reset()
|
||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None and robot.name != "lekiwi":
|
||||||
for microphone_key, microphone in robot.microphones.items():
|
for microphone_key, microphone in robot.microphones.items():
|
||||||
dataset.add_microphone_recording(microphone, microphone_key)
|
dataset.add_microphone_recording(microphone, microphone_key)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ from lerobot.datasets.compute_stats import (
|
|||||||
compute_episode_stats,
|
compute_episode_stats,
|
||||||
estimate_num_samples,
|
estimate_num_samples,
|
||||||
get_feature_stats,
|
get_feature_stats,
|
||||||
sample_audio,
|
sample_audio_from_data,
|
||||||
|
sample_audio_from_path,
|
||||||
sample_images,
|
sample_images,
|
||||||
sample_indices,
|
sample_indices,
|
||||||
)
|
)
|
||||||
@@ -78,10 +79,19 @@ def test_sample_images(mock_load):
|
|||||||
assert len(images) == estimate_num_samples(100)
|
assert len(images) == estimate_num_samples(100)
|
||||||
|
|
||||||
|
|
||||||
@patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio)
|
@patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
|
||||||
def test_sample_audio(mock_load):
|
def test_sample_audio_from_path(mock_load):
|
||||||
audio_path = "audio.wav"
|
audio_path = "audio.wav"
|
||||||
audio_samples = sample_audio(audio_path)
|
audio_samples = sample_audio_from_path(audio_path)
|
||||||
|
assert isinstance(audio_samples, np.ndarray)
|
||||||
|
assert audio_samples.shape[1] == 2
|
||||||
|
assert audio_samples.dtype == np.float32
|
||||||
|
assert len(audio_samples) == estimate_num_samples(16000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_audio_from_data(mock_load):
|
||||||
|
audio_data = np.ones((16000, 2), dtype=np.float32)
|
||||||
|
audio_samples = sample_audio_from_data(audio_data)
|
||||||
assert isinstance(audio_samples, np.ndarray)
|
assert isinstance(audio_samples, np.ndarray)
|
||||||
assert audio_samples.shape[1] == 2
|
assert audio_samples.shape[1] == 2
|
||||||
assert audio_samples.dtype == np.float32
|
assert audio_samples.dtype == np.float32
|
||||||
@@ -179,7 +189,7 @@ def test_compute_episode_stats():
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
|
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
|
||||||
patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio),
|
patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
|
||||||
):
|
):
|
||||||
stats = compute_episode_stats(episode_data, features)
|
stats = compute_episode_stats(episode_data, features)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user