mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Compare commits
94 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e6e54391bd | |||
| a201b33d20 | |||
| 9d42de328e | |||
| 8b9451b585 | |||
| ab4903e752 | |||
| 538cea6dbc | |||
| 5cd3572713 | |||
| 3399513e5e | |||
| 32fc4015ee | |||
| cc72c813bf | |||
| 606f31a86e | |||
| 4933c9dcc7 | |||
| 7e25385024 | |||
| cc70bff74d | |||
| 9f50913b9c | |||
| 4eb7694d47 | |||
| edb5559b5b | |||
| 552ec76195 | |||
| e75340b473 | |||
| 2a4c223ec7 | |||
| 1ee4d84f07 | |||
| 6bd40ca219 | |||
| b879cf3d04 | |||
| bd9e5c1a64 | |||
| 9271a0c900 | |||
| af2f044f5a | |||
| 0caba222ef | |||
| 6d73f5bfe6 | |||
| ef8f40c21b | |||
| 0232879245 | |||
| 2726b4e865 | |||
| e126d35249 | |||
| d7ae8cd699 | |||
| 2f96d8bf76 | |||
| e129c71b4f | |||
| a02d70389d | |||
| 0d4922ce49 | |||
| eaeff78924 | |||
| e2f3982e2c | |||
| a73ac2bdbb | |||
| 95de732e55 | |||
| b2383236ca | |||
| 4b98cc25c8 | |||
| 90780c4de8 | |||
| 6f6e046c53 | |||
| 8cd64eaad1 | |||
| e620395416 | |||
| 0fbcbcdb2e | |||
| 674f5dfd75 | |||
| 7d430c8067 | |||
| 5f114c1d74 | |||
| ad01ef19f4 | |||
| 59e8f4572c | |||
| 97e91698fb | |||
| af0294198a | |||
| 421fdcce96 | |||
| bb63ad9715 | |||
| 3c90a79c57 | |||
| 8e29c530ed | |||
| b573b7a052 | |||
| 926184110b | |||
| bf8ede852d | |||
| f73db4394b | |||
| bff91f9927 | |||
| 6d726266fd | |||
| 2962330bb1 | |||
| 067993bb11 | |||
| e4dd00c8f5 | |||
| e714ff22e2 | |||
| 3bbd161cfd | |||
| 6d7be63f59 | |||
| b9d0dfb9a2 | |||
| dce483060f | |||
| c32b9182d9 | |||
| a4d4ef0e7f | |||
| 9a5c96b2b1 | |||
| 0a6ca58299 | |||
| 688195fc46 | |||
| 99eb0bbafc | |||
| 16de8b3f19 | |||
| 580008663b | |||
| 52c424c5eb | |||
| 836195e59c | |||
| be09a59e05 | |||
| 373a169bd2 | |||
| 00536c6c5b | |||
| cdd3a859ef | |||
| 5276fc0d6f | |||
| 6a2882f978 | |||
| 8874547353 | |||
| 2864caad80 | |||
| d998660aa1 | |||
| 7e5f3b35e9 | |||
| 01fea7c407 |
@@ -0,0 +1,219 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from soundfile import read
|
||||||
|
|
||||||
|
from lerobot.microphones.configs import MicrophoneConfig
|
||||||
|
from lerobot.microphones.portaudio import PortAudioMicrophone, PortAudioMicrophoneConfig
|
||||||
|
from lerobot.microphones.utils import (
|
||||||
|
async_microphones_start_recording,
|
||||||
|
async_microphones_stop_recording,
|
||||||
|
make_microphones_from_configs,
|
||||||
|
)
|
||||||
|
from lerobot.utils.robot_utils import (
|
||||||
|
precise_sleep,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
microphones_configs: dict[str, MicrophoneConfig],
|
||||||
|
audio_chunks_number: int,
|
||||||
|
audio_chunks_duration: float,
|
||||||
|
repetitions: int,
|
||||||
|
multiprocessing: bool = False,
|
||||||
|
):
|
||||||
|
recording_dir = Path("outputs/audio_benchmark")
|
||||||
|
recording_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create microphones
|
||||||
|
microphones = make_microphones_from_configs(microphones_configs)
|
||||||
|
|
||||||
|
# Connect microphones
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
all_audio_chunks = []
|
||||||
|
for i in range(repetitions):
|
||||||
|
print(f"Repetition {i + 1}/{repetitions}...")
|
||||||
|
|
||||||
|
# Create audio chunks
|
||||||
|
audio_chunks = {}
|
||||||
|
for microphone_key in microphones:
|
||||||
|
audio_chunks.update({microphone_key: []})
|
||||||
|
|
||||||
|
# Start recording
|
||||||
|
async_microphones_start_recording(
|
||||||
|
microphones,
|
||||||
|
output_files=[
|
||||||
|
recording_dir / f"{microphone_key}_recording_{i}.wav" for microphone_key in microphones
|
||||||
|
],
|
||||||
|
multiprocessing=multiprocessing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record audio chunks
|
||||||
|
for j in range(audio_chunks_number):
|
||||||
|
precise_sleep(audio_chunks_duration)
|
||||||
|
|
||||||
|
for microphone_key, microphone in microphones.items():
|
||||||
|
audio_chunk = microphone.read()
|
||||||
|
print(f"{microphone_key} - repetition {i} - chunk {j} - samples {audio_chunk.shape[0]}")
|
||||||
|
audio_chunks[microphone_key].append(audio_chunk)
|
||||||
|
|
||||||
|
# Stop recording
|
||||||
|
async_microphones_stop_recording(microphones)
|
||||||
|
|
||||||
|
for microphone_key in microphones:
|
||||||
|
audio_chunks[microphone_key] = np.concatenate(audio_chunks[microphone_key], axis=0)
|
||||||
|
|
||||||
|
all_audio_chunks.append(audio_chunks)
|
||||||
|
|
||||||
|
# Disconnect microphones
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.disconnect()
|
||||||
|
|
||||||
|
# Compute statistics
|
||||||
|
cmap = plt.get_cmap("tab10")
|
||||||
|
_, ax = plt.subplots(nrows=repetitions, ncols=len(microphones))
|
||||||
|
chunk_length = np.zeros((repetitions, len(microphones)))
|
||||||
|
record_length = np.zeros((repetitions, len(microphones)))
|
||||||
|
for i in range(repetitions):
|
||||||
|
for j, (microphone_key, microphone) in enumerate(microphones.items()):
|
||||||
|
# Get recorded audio chunks
|
||||||
|
recorded_audio_chunks = all_audio_chunks[i][microphone_key]
|
||||||
|
|
||||||
|
# Load recorded file
|
||||||
|
recorded_data, _ = read(recording_dir / f"{microphone_key}_recording_{i}.wav")
|
||||||
|
if recorded_data.ndim == 1:
|
||||||
|
recorded_data = np.expand_dims(recorded_data, axis=1)
|
||||||
|
|
||||||
|
record_length[i, j] = recorded_data.shape[0]
|
||||||
|
chunk_length[i, j] = recorded_audio_chunks.shape[0]
|
||||||
|
|
||||||
|
for k, (chunk_data, record_data) in enumerate(
|
||||||
|
zip(recorded_audio_chunks.T, recorded_data.T, strict=False)
|
||||||
|
):
|
||||||
|
# Plot audio chunks and recorded data
|
||||||
|
ax[i, j].plot(
|
||||||
|
np.arange(0, len(chunk_data)) / microphone.sample_rate,
|
||||||
|
chunk_data,
|
||||||
|
label=f"audio chunks - channel {k}",
|
||||||
|
color=cmap(2 * k),
|
||||||
|
)
|
||||||
|
ax[i, j].plot(
|
||||||
|
np.arange(0, len(record_data)) / microphone.sample_rate,
|
||||||
|
record_data,
|
||||||
|
label=f"recorded data - channel {k}",
|
||||||
|
linestyle="dashed",
|
||||||
|
color=cmap(2 * k + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot absolute difference (errors should be located at the end of the recordings)
|
||||||
|
if recorded_data.shape[0] - recorded_audio_chunks.shape[0] > 0:
|
||||||
|
chunk_data = np.append(
|
||||||
|
chunk_data, np.zeros(int(recorded_data.shape[0] - recorded_audio_chunks.shape[0]))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
record_data = np.append(
|
||||||
|
record_data, np.zeros(int(-recorded_data.shape[0] + recorded_audio_chunks.shape[0]))
|
||||||
|
)
|
||||||
|
ax[i, j].plot(
|
||||||
|
np.arange(0, len(record_data)) / microphone.sample_rate,
|
||||||
|
np.abs(chunk_data - record_data),
|
||||||
|
label=f"differences - channel {k}",
|
||||||
|
color="red",
|
||||||
|
linestyle="dotted",
|
||||||
|
)
|
||||||
|
ax[i, j].set_title(f"{microphone_key} - repetition {i}")
|
||||||
|
ax[i, j].legend()
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Print statistics
|
||||||
|
differences = record_length - chunk_length
|
||||||
|
for i, (microphone_key, microphone) in enumerate(microphones.items()):
|
||||||
|
print(
|
||||||
|
f"Average recorded duration for {microphone_key} : {np.mean(record_length[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Average chunk duration for {microphone_key} : {np.mean(chunk_length[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||||
|
)
|
||||||
|
print(f"Average difference for {microphone_key} : {np.mean(differences[:, i]):.3f} samples")
|
||||||
|
print(
|
||||||
|
f"Average difference for {microphone_key} : {np.mean(differences[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--microphones_indices",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=[microphone["index"] for microphone in PortAudioMicrophone.find_microphones()],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--microphones_sample_rate",
|
||||||
|
type=float,
|
||||||
|
nargs="+",
|
||||||
|
default=[None] * len(PortAudioMicrophone.find_microphones()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--microphones_channels",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=[None] * len(PortAudioMicrophone.find_microphones()),
|
||||||
|
)
|
||||||
|
parser.add_argument("--audio_chunks_number", type=int, default=2)
|
||||||
|
parser.add_argument(
|
||||||
|
"--audio_chunks_duration",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repetitions",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--multiprocessing",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = vars(parser.parse_args())
|
||||||
|
|
||||||
|
args["microphones_configs"] = {}
|
||||||
|
for index, sample_rate, channels in zip(
|
||||||
|
args["microphones_indices"],
|
||||||
|
args["microphones_sample_rate"],
|
||||||
|
args["microphones_channels"],
|
||||||
|
strict=False,
|
||||||
|
):
|
||||||
|
microphone_config = PortAudioMicrophoneConfig(
|
||||||
|
microphone_index=index,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
channels=channels,
|
||||||
|
)
|
||||||
|
args["microphones_configs"].update({f"microphone_{index}": microphone_config})
|
||||||
|
args.pop("microphones_indices")
|
||||||
|
args.pop("microphones_sample_rate")
|
||||||
|
args.pop("microphones_channels")
|
||||||
|
|
||||||
|
main(**args)
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
from lerobot.microphones.configs import MicrophoneConfig
|
||||||
|
from lerobot.microphones.touchlab import TouchLabSensorConfig
|
||||||
|
from lerobot.microphones.utils import (
|
||||||
|
async_microphones_start_recording,
|
||||||
|
async_microphones_stop_recording,
|
||||||
|
make_microphones_from_configs,
|
||||||
|
)
|
||||||
|
from lerobot.utils.robot_utils import (
|
||||||
|
precise_sleep,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
sensors_configs: dict[str, MicrophoneConfig],
|
||||||
|
multiprocessing: bool = False,
|
||||||
|
):
|
||||||
|
recording_dir = Path("outputs/tactile_benchmark")
|
||||||
|
recording_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create microphones
|
||||||
|
sensors = make_microphones_from_configs(sensors_configs)
|
||||||
|
|
||||||
|
# Connect microphones
|
||||||
|
for sensor in sensors.values():
|
||||||
|
sensor.connect()
|
||||||
|
|
||||||
|
# Create audio chunks
|
||||||
|
data_chunks = {}
|
||||||
|
for sensor_key in sensors:
|
||||||
|
data_chunks.update({sensor_key: []})
|
||||||
|
|
||||||
|
# Start recording
|
||||||
|
async_microphones_start_recording(
|
||||||
|
sensors,
|
||||||
|
output_files=[recording_dir / f"{sensor_key}_recording.wav" for sensor_key in sensors],
|
||||||
|
multiprocessing=multiprocessing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record audio chunks
|
||||||
|
precise_sleep(10.0)
|
||||||
|
|
||||||
|
for sensor_key, sensor in sensors.items():
|
||||||
|
data_chunk = sensor.read()
|
||||||
|
print(f"{sensor_key} - samples {data_chunk.shape[0]}")
|
||||||
|
data_chunks[sensor_key].append(data_chunk)
|
||||||
|
|
||||||
|
# Stop recording
|
||||||
|
async_microphones_stop_recording(sensors)
|
||||||
|
|
||||||
|
for sensor_key in sensors:
|
||||||
|
data_chunks[sensor_key] = np.concatenate(data_chunks[sensor_key], axis=0)
|
||||||
|
|
||||||
|
# Disconnect microphones
|
||||||
|
for sensor in sensors.values():
|
||||||
|
sensor.disconnect()
|
||||||
|
|
||||||
|
for sensor_key in sensors:
|
||||||
|
data, sample_rate = sf.read(recording_dir / f"{sensor_key}_recording.wav")
|
||||||
|
print(f"{sensor_key} - samples {data.shape[0]}")
|
||||||
|
print(f"{sensor_key} - sample rate {sample_rate}")
|
||||||
|
print(f"{sensor_key} - data {data}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--sensors_ports",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sensors_baud_rate",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sensors_sample_rate",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sensors_channels",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--multiprocessing",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = vars(parser.parse_args())
|
||||||
|
|
||||||
|
args["sensors_configs"] = {}
|
||||||
|
for port, baud_rate, sample_rate, channels in zip(
|
||||||
|
args["sensors_ports"],
|
||||||
|
args["sensors_baud_rate"],
|
||||||
|
args["sensors_sample_rate"],
|
||||||
|
args["sensors_channels"],
|
||||||
|
strict=False,
|
||||||
|
):
|
||||||
|
if isinstance(channels, int):
|
||||||
|
channels = [channels]
|
||||||
|
sensor_config = TouchLabSensorConfig(
|
||||||
|
sensor_port=port,
|
||||||
|
baud_rate=baud_rate,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
channels=channels,
|
||||||
|
)
|
||||||
|
args["sensors_configs"].update({f"sensor_{port}": sensor_config})
|
||||||
|
args.pop("sensors_ports")
|
||||||
|
args.pop("sensors_baud_rate")
|
||||||
|
args.pop("sensors_sample_rate")
|
||||||
|
args.pop("sensors_channels")
|
||||||
|
|
||||||
|
main(**args)
|
||||||
@@ -43,12 +43,13 @@ def main():
|
|||||||
keyboard.connect()
|
keyboard.connect()
|
||||||
|
|
||||||
# Init rerun viewer
|
# Init rerun viewer
|
||||||
init_rerun(session_name="lekiwi_teleop")
|
init_rerun(session_name="lekiwi_teleop", robot=robot, reset_time=True)
|
||||||
|
|
||||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||||
raise ValueError("Robot or teleop is not connected!")
|
raise ValueError("Robot or teleop is not connected!")
|
||||||
|
|
||||||
print("Starting teleop loop...")
|
print("Starting teleop loop...")
|
||||||
|
start = time.perf_counter()
|
||||||
while True:
|
while True:
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
@@ -69,7 +70,7 @@ def main():
|
|||||||
_ = robot.send_action(action)
|
_ = robot.send_action(action)
|
||||||
|
|
||||||
# Visualize
|
# Visualize
|
||||||
log_rerun_data(observation=observation, action=action)
|
log_rerun_data(observation=observation, action=action, log_time=time.perf_counter() - start)
|
||||||
|
|
||||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||||
|
|
||||||
|
|||||||
@@ -90,12 +90,13 @@ def main():
|
|||||||
teleop_device.connect()
|
teleop_device.connect()
|
||||||
|
|
||||||
# Init rerun viewer
|
# Init rerun viewer
|
||||||
init_rerun(session_name="phone_so100_teleop")
|
init_rerun(session_name="phone_so100_teleop", robot=robot, reset_time=True)
|
||||||
|
|
||||||
if not robot.is_connected or not teleop_device.is_connected:
|
if not robot.is_connected or not teleop_device.is_connected:
|
||||||
raise ValueError("Robot or teleop is not connected!")
|
raise ValueError("Robot or teleop is not connected!")
|
||||||
|
|
||||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||||
|
start = time.perf_counter()
|
||||||
while True:
|
while True:
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
@@ -112,7 +113,7 @@ def main():
|
|||||||
_ = robot.send_action(joint_action)
|
_ = robot.send_action(joint_action)
|
||||||
|
|
||||||
# Visualize
|
# Visualize
|
||||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
log_rerun_data(observation=phone_obs, action=joint_action, log_time=time.perf_counter() - start)
|
||||||
|
|
||||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||||
|
|
||||||
|
|||||||
@@ -95,9 +95,10 @@ def main():
|
|||||||
leader.connect()
|
leader.connect()
|
||||||
|
|
||||||
# Init rerun viewer
|
# Init rerun viewer
|
||||||
init_rerun(session_name="so100_so100_EE_teleop")
|
init_rerun(session_name="so100_so100_EE_teleop", robot=follower, reset_time=True)
|
||||||
|
|
||||||
print("Starting teleop loop...")
|
print("Starting teleop loop...")
|
||||||
|
start = time.perf_counter()
|
||||||
while True:
|
while True:
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
@@ -117,7 +118,9 @@ def main():
|
|||||||
_ = follower.send_action(follower_joints_act)
|
_ = follower.send_action(follower_joints_act)
|
||||||
|
|
||||||
# Visualize
|
# Visualize
|
||||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
log_rerun_data(
|
||||||
|
observation=leader_ee_act, action=follower_joints_act, log_time=time.perf_counter() - start
|
||||||
|
)
|
||||||
|
|
||||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||||
|
|
||||||
|
|||||||
@@ -164,6 +164,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
|
|||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||||
|
audio = ["sounddevice>=0.5.1,<0.6.0", "soundfile>=0.13.1,<0.14.0", "librosa>=0.11.0,<0.12.0", "torchaudio>=2.6.0,<2.10.0"]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||||
@@ -198,6 +199,7 @@ all = [
|
|||||||
"lerobot[xvla]",
|
"lerobot[xvla]",
|
||||||
"lerobot[hilserl]",
|
"lerobot[hilserl]",
|
||||||
"lerobot[async]",
|
"lerobot[async]",
|
||||||
|
"lerobot[audio]",
|
||||||
"lerobot[dev]",
|
"lerobot[dev]",
|
||||||
"lerobot[test]",
|
"lerobot[test]",
|
||||||
"lerobot[video_benchmark]",
|
"lerobot[video_benchmark]",
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ Example:
|
|||||||
print(lerobot.available_policies_per_env)
|
print(lerobot.available_policies_per_env)
|
||||||
print(lerobot.available_robots)
|
print(lerobot.available_robots)
|
||||||
print(lerobot.available_cameras)
|
print(lerobot.available_cameras)
|
||||||
|
print(lerobot.available_microphones)
|
||||||
print(lerobot.available_motors)
|
print(lerobot.available_motors)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -174,6 +175,12 @@ available_cameras = [
|
|||||||
"intelrealsense",
|
"intelrealsense",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# lists all available microphones from `lerobot/microphones`
|
||||||
|
available_microphones = [
|
||||||
|
"portaudio",
|
||||||
|
"touchlab",
|
||||||
|
]
|
||||||
|
|
||||||
# lists all available motors from `lerobot/motors`
|
# lists all available motors from `lerobot/motors`
|
||||||
available_motors = [
|
available_motors = [
|
||||||
"dynamixel",
|
"dynamixel",
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ import torch
|
|||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
|
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||||
|
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
Robot,
|
Robot,
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
|
|||||||
@@ -151,6 +151,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
|||||||
return {}
|
return {}
|
||||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_features(self) -> dict[str, PolicyFeature]:
|
||||||
|
if not self.input_features:
|
||||||
|
return {}
|
||||||
|
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_feature(self) -> PolicyFeature | None:
|
def action_feature(self) -> PolicyFeature | None:
|
||||||
if not self.output_features:
|
if not self.output_features:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from enum import Enum
|
|||||||
class FeatureType(str, Enum):
|
class FeatureType(str, Enum):
|
||||||
STATE = "STATE"
|
STATE = "STATE"
|
||||||
VISUAL = "VISUAL"
|
VISUAL = "VISUAL"
|
||||||
|
AUDIO = "AUDIO"
|
||||||
ENV = "ENV"
|
ENV = "ENV"
|
||||||
ACTION = "ACTION"
|
ACTION = "ACTION"
|
||||||
REWARD = "REWARD"
|
REWARD = "REWARD"
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ from lerobot.datasets.io_utils import (
|
|||||||
write_tasks,
|
write_tasks,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
|
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_AUDIO_PATH,
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
@@ -43,7 +45,7 @@ from lerobot.datasets.utils import (
|
|||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
|
||||||
|
|
||||||
|
|
||||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||||
@@ -112,6 +114,7 @@ def update_meta_data(
|
|||||||
meta_idx,
|
meta_idx,
|
||||||
data_idx,
|
data_idx,
|
||||||
videos_idx,
|
videos_idx,
|
||||||
|
audios_idx,
|
||||||
):
|
):
|
||||||
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||||
|
|
||||||
@@ -127,7 +130,7 @@ def update_meta_data(
|
|||||||
meta_idx: Dictionary containing current metadata chunk and file indices.
|
meta_idx: Dictionary containing current metadata chunk and file indices.
|
||||||
data_idx: Dictionary containing current data chunk and file indices.
|
data_idx: Dictionary containing current data chunk and file indices.
|
||||||
videos_idx: Dictionary containing current video indices and timestamps.
|
videos_idx: Dictionary containing current video indices and timestamps.
|
||||||
|
audios_idx: Dictionary containing current audio indices and timestamps.
|
||||||
Returns:
|
Returns:
|
||||||
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||||
"""
|
"""
|
||||||
@@ -225,6 +228,36 @@ def update_meta_data(
|
|||||||
# Clean up temporary columns
|
# Clean up temporary columns
|
||||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||||
|
|
||||||
|
for key, audio_idx in audios_idx.items():
|
||||||
|
# Store original audio file indices before updating
|
||||||
|
orig_chunk_col = f"audio/{key}/chunk_index"
|
||||||
|
orig_file_col = f"audio/{key}/file_index"
|
||||||
|
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||||
|
df["_orig_file"] = df[orig_file_col].copy()
|
||||||
|
|
||||||
|
# Update chunk and file indices to point to destination
|
||||||
|
df[orig_chunk_col] = audio_idx["chunk"]
|
||||||
|
df[orig_file_col] = audio_idx["file"]
|
||||||
|
|
||||||
|
# Apply per-source-file timestamp offsets
|
||||||
|
src_to_offset = audio_idx.get("src_to_offset", {})
|
||||||
|
if src_to_offset:
|
||||||
|
# Apply offset based on original source file
|
||||||
|
for idx in df.index:
|
||||||
|
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||||
|
offset = src_to_offset.get(src_key, 0)
|
||||||
|
df.at[idx, f"audio/{key}/from_timestamp"] += offset
|
||||||
|
df.at[idx, f"audio/{key}/to_timestamp"] += offset
|
||||||
|
else:
|
||||||
|
# Fallback to simple offset (for backward compatibility)
|
||||||
|
df[f"audio/{key}/from_timestamp"] = (
|
||||||
|
df[f"audio/{key}/from_timestamp"] + audio_idx["latest_duration"]
|
||||||
|
)
|
||||||
|
df[f"audio/{key}/to_timestamp"] = df[f"audio/{key}/to_timestamp"] + audio_idx["latest_duration"]
|
||||||
|
|
||||||
|
# Clean up temporary columns
|
||||||
|
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||||
|
|
||||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||||
@@ -239,6 +272,7 @@ def aggregate_datasets(
|
|||||||
aggr_root: Path | None = None,
|
aggr_root: Path | None = None,
|
||||||
data_files_size_in_mb: float | None = None,
|
data_files_size_in_mb: float | None = None,
|
||||||
video_files_size_in_mb: float | None = None,
|
video_files_size_in_mb: float | None = None,
|
||||||
|
audio_files_size_in_mb: float | None = None,
|
||||||
chunk_size: int | None = None,
|
chunk_size: int | None = None,
|
||||||
):
|
):
|
||||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||||
@@ -256,6 +290,7 @@ def aggregate_datasets(
|
|||||||
aggr_root: Optional root path for the aggregated dataset.
|
aggr_root: Optional root path for the aggregated dataset.
|
||||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||||
|
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||||
"""
|
"""
|
||||||
logging.info("Start aggregate_datasets")
|
logging.info("Start aggregate_datasets")
|
||||||
@@ -264,6 +299,8 @@ def aggregate_datasets(
|
|||||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||||
if video_files_size_in_mb is None:
|
if video_files_size_in_mb is None:
|
||||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||||
|
if audio_files_size_in_mb is None:
|
||||||
|
audio_files_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
|
||||||
if chunk_size is None:
|
if chunk_size is None:
|
||||||
chunk_size = DEFAULT_CHUNK_SIZE
|
chunk_size = DEFAULT_CHUNK_SIZE
|
||||||
|
|
||||||
@@ -276,6 +313,7 @@ def aggregate_datasets(
|
|||||||
)
|
)
|
||||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||||
|
audio_keys = [key for key in features if features[key]["dtype"] == "audio"]
|
||||||
|
|
||||||
dst_meta = LeRobotDatasetMetadata.create(
|
dst_meta = LeRobotDatasetMetadata.create(
|
||||||
repo_id=aggr_repo_id,
|
repo_id=aggr_repo_id,
|
||||||
@@ -287,6 +325,7 @@ def aggregate_datasets(
|
|||||||
chunks_size=chunk_size,
|
chunks_size=chunk_size,
|
||||||
data_files_size_in_mb=data_files_size_in_mb,
|
data_files_size_in_mb=data_files_size_in_mb,
|
||||||
video_files_size_in_mb=video_files_size_in_mb,
|
video_files_size_in_mb=video_files_size_in_mb,
|
||||||
|
audio_files_size_in_mb=audio_files_size_in_mb,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Find all tasks")
|
logging.info("Find all tasks")
|
||||||
@@ -300,14 +339,18 @@ def aggregate_datasets(
|
|||||||
videos_idx = {
|
videos_idx = {
|
||||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||||
}
|
}
|
||||||
|
audios_idx = {
|
||||||
|
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in audio_keys
|
||||||
|
}
|
||||||
|
|
||||||
dst_meta.episodes = {}
|
dst_meta.episodes = {}
|
||||||
|
|
||||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||||
|
audios_idx = aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size)
|
||||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||||
|
|
||||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx)
|
||||||
|
|
||||||
# Clear the src_to_dst mapping after processing each source dataset
|
# Clear the src_to_dst mapping after processing each source dataset
|
||||||
# to avoid interference between different source datasets
|
# to avoid interference between different source datasets
|
||||||
@@ -375,7 +418,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
file_index=file_idx,
|
file_index=file_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
src_duration = get_video_duration_in_s(src_path)
|
src_duration = get_media_duration_in_s(src_path, media_type="video")
|
||||||
dst_key = (chunk_idx, file_idx)
|
dst_key = (chunk_idx, file_idx)
|
||||||
|
|
||||||
if not dst_path.exists():
|
if not dst_path.exists():
|
||||||
@@ -414,7 +457,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||||
concatenate_video_files(
|
concatenate_media_files(
|
||||||
[dst_path, src_path],
|
[dst_path, src_path],
|
||||||
dst_path,
|
dst_path,
|
||||||
)
|
)
|
||||||
@@ -429,6 +472,101 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
return videos_idx
|
return videos_idx
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size):
|
||||||
|
"""Aggregates audio files from a source dataset into the destination dataset.
|
||||||
|
|
||||||
|
Handles audio file concatenation and rotation based on file size limits.
|
||||||
|
Creates new audio files when size limits are exceeded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src_meta: Source dataset metadata.
|
||||||
|
dst_meta: Destination dataset metadata.
|
||||||
|
audio_idx: Dictionary tracking audio chunk and file indices.
|
||||||
|
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||||
|
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Updated audio_idx with current chunk and file indices.
|
||||||
|
"""
|
||||||
|
for key in audios_idx:
|
||||||
|
audios_idx[key]["episode_duration"] = 0
|
||||||
|
# Track offset for each source (chunk, file) pair
|
||||||
|
audios_idx[key]["src_to_offset"] = {}
|
||||||
|
|
||||||
|
for key, audio_idx in audios_idx.items():
|
||||||
|
unique_chunk_file_pairs = {
|
||||||
|
(chunk, file)
|
||||||
|
for chunk, file in zip(
|
||||||
|
src_meta.episodes[f"audio/{key}/chunk_index"],
|
||||||
|
src_meta.episodes[f"audio/{key}/file_index"],
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
|
||||||
|
|
||||||
|
chunk_idx = audio_idx["chunk"]
|
||||||
|
file_idx = audio_idx["file"]
|
||||||
|
current_offset = audio_idx["latest_duration"]
|
||||||
|
|
||||||
|
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||||
|
src_path = src_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||||
|
audio_key=key,
|
||||||
|
chunk_index=src_chunk_idx,
|
||||||
|
file_index=src_file_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||||
|
audio_key=key,
|
||||||
|
chunk_index=chunk_idx,
|
||||||
|
file_index=file_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
src_duration = get_media_duration_in_s(src_path, media_type="audio")
|
||||||
|
|
||||||
|
if not dst_path.exists():
|
||||||
|
# Store offset before incrementing
|
||||||
|
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||||
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.copy(str(src_path), str(dst_path))
|
||||||
|
audios_idx[key]["episode_duration"] += src_duration
|
||||||
|
current_offset += src_duration
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check file sizes before appending
|
||||||
|
src_size = get_file_size_in_mb(src_path)
|
||||||
|
dst_size = get_file_size_in_mb(dst_path)
|
||||||
|
|
||||||
|
if dst_size + src_size >= audio_files_size_in_mb:
|
||||||
|
# Rotate to a new file, this source becomes start of new destination
|
||||||
|
# So its offset should be 0
|
||||||
|
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||||
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||||
|
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||||
|
audio_key=key,
|
||||||
|
chunk_index=chunk_idx,
|
||||||
|
file_index=file_idx,
|
||||||
|
)
|
||||||
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.copy(str(src_path), str(dst_path))
|
||||||
|
# Reset offset for next file
|
||||||
|
current_offset = src_duration
|
||||||
|
else:
|
||||||
|
# Append to existing video file - use current accumulated offset
|
||||||
|
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||||
|
concatenate_media_files(
|
||||||
|
[dst_path, src_path],
|
||||||
|
dst_path,
|
||||||
|
)
|
||||||
|
current_offset += src_duration
|
||||||
|
|
||||||
|
audios_idx[key]["episode_duration"] += src_duration
|
||||||
|
|
||||||
|
audios_idx[key]["chunk"] = chunk_idx
|
||||||
|
audios_idx[key]["file"] = file_idx
|
||||||
|
|
||||||
|
return audios_idx
|
||||||
|
|
||||||
|
|
||||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
||||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||||
|
|
||||||
@@ -501,7 +639,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
return data_idx
|
return data_idx
|
||||||
|
|
||||||
|
|
||||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx):
|
||||||
"""Aggregates metadata from a source dataset into the destination dataset.
|
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||||
|
|
||||||
Reads source metadata files, updates all indices and timestamps,
|
Reads source metadata files, updates all indices and timestamps,
|
||||||
@@ -513,6 +651,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
meta_idx: Dictionary tracking metadata chunk and file indices.
|
meta_idx: Dictionary tracking metadata chunk and file indices.
|
||||||
data_idx: Dictionary tracking data chunk and file indices.
|
data_idx: Dictionary tracking data chunk and file indices.
|
||||||
videos_idx: Dictionary tracking video indices and timestamps.
|
videos_idx: Dictionary tracking video indices and timestamps.
|
||||||
|
audios_idx: Dictionary tracking audio indices and timestamps.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Updated meta_idx with current chunk and file indices.
|
dict: Updated meta_idx with current chunk and file indices.
|
||||||
@@ -536,6 +675,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
meta_idx,
|
meta_idx,
|
||||||
data_idx,
|
data_idx,
|
||||||
videos_idx,
|
videos_idx,
|
||||||
|
audios_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
meta_idx, _ = append_or_create_parquet_file(
|
meta_idx, _ = append_or_create_parquet_file(
|
||||||
@@ -552,7 +692,8 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
# Increment latest_duration by the total duration added from this source dataset
|
# Increment latest_duration by the total duration added from this source dataset
|
||||||
for k in videos_idx:
|
for k in videos_idx:
|
||||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||||
|
for k in audios_idx:
|
||||||
|
audios_idx[k]["latest_duration"] += audios_idx[k]["episode_duration"]
|
||||||
return meta_idx
|
return meta_idx
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,275 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import av
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import torchcodec
|
||||||
|
from numpy import ceil
|
||||||
|
|
||||||
|
CHANNELS_LAYOUTS_MAPPING = {
|
||||||
|
1: "mono",
|
||||||
|
2: "stereo",
|
||||||
|
3: "2.1",
|
||||||
|
4: "3.1",
|
||||||
|
5: "4.1",
|
||||||
|
6: "5.1",
|
||||||
|
7: "6.1",
|
||||||
|
8: "7.1",
|
||||||
|
16: "hexadecagonal",
|
||||||
|
24: "22.2",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_audio(
|
||||||
|
audio_path: Path | str,
|
||||||
|
timestamps: list[float],
|
||||||
|
duration: float,
|
||||||
|
start_time_s: float | None = 0.0,
|
||||||
|
backend: str | None = "torchcodec",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Decodes audio using the specified backend.
|
||||||
|
Args:
|
||||||
|
audio_path (Path): Path to the audio file.
|
||||||
|
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
|
||||||
|
duration (float): Duration of the audio chunks in seconds.
|
||||||
|
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Decoded audio chunks.
|
||||||
|
|
||||||
|
Currently supports torchaudio.
|
||||||
|
"""
|
||||||
|
if backend == "torchcodec":
|
||||||
|
return decode_audio_torchcodec(audio_path, timestamps, duration, start_time_s)
|
||||||
|
elif backend == "torchaudio":
|
||||||
|
return decode_audio_torchaudio(audio_path, timestamps, duration, start_time_s)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_audio_torchcodec(
|
||||||
|
audio_path: Path | str,
|
||||||
|
timestamps: list[float],
|
||||||
|
duration: float,
|
||||||
|
start_time_s: float | None = 0.0,
|
||||||
|
log_loaded_timestamps: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# TODO(CarolinePascal) : add channels selection
|
||||||
|
audio_decoder = torchcodec.decoders.AudioDecoder(audio_path)
|
||||||
|
audio_sample_rate = audio_decoder.metadata.sample_rate
|
||||||
|
audio_channels = audio_decoder.metadata.num_channels
|
||||||
|
# TODO(CarolinePascal) : assert ts < total record duration
|
||||||
|
|
||||||
|
audio_chunks = []
|
||||||
|
timestamps = [
|
||||||
|
timestamp + start_time_s for timestamp in timestamps
|
||||||
|
] # Add an offset of start_time_s to each timestamp
|
||||||
|
for ts in timestamps:
|
||||||
|
current_audio_chunk = audio_decoder.get_samples_played_in_range(
|
||||||
|
start_seconds=max(0.0, ts - duration), stop_seconds=ts
|
||||||
|
)
|
||||||
|
|
||||||
|
current_audio_chunk_data = current_audio_chunk.data
|
||||||
|
|
||||||
|
# Case where the requested audio chunk starts before the beginning of the audio stream
|
||||||
|
if ts - duration < 0:
|
||||||
|
# No useful audio sample has been recorded
|
||||||
|
if ts < 1 / audio_sample_rate:
|
||||||
|
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||||
|
current_audio_chunk_data = torch.zeros(
|
||||||
|
(audio_channels, int(ceil(duration * audio_sample_rate)))
|
||||||
|
)
|
||||||
|
# At least one useful audio sample has been recorded
|
||||||
|
else:
|
||||||
|
# Pad the beginning of the audio chunk with zeros
|
||||||
|
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||||
|
current_audio_chunk_data = torch.nn.functional.pad(
|
||||||
|
current_audio_chunk_data,
|
||||||
|
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
|
||||||
|
)
|
||||||
|
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
logging.info(
|
||||||
|
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_chunks.append(current_audio_chunk_data)
|
||||||
|
|
||||||
|
audio_chunks = torch.stack(audio_chunks)
|
||||||
|
|
||||||
|
assert len(timestamps) == len(audio_chunks)
|
||||||
|
return audio_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def decode_audio_torchaudio(
|
||||||
|
audio_path: Path | str,
|
||||||
|
timestamps: list[float],
|
||||||
|
duration: float,
|
||||||
|
start_time_s: float | None = 0.0,
|
||||||
|
log_loaded_timestamps: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# TODO(CarolinePascal) : add channels selection
|
||||||
|
audio_path = str(audio_path)
|
||||||
|
|
||||||
|
reader = torchaudio.io.StreamReader(src=audio_path)
|
||||||
|
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
|
||||||
|
audio_channels = reader.get_src_stream_info(reader.default_audio_stream).num_channels
|
||||||
|
# TODO(CarolinePascal) : assert ts < total record duration
|
||||||
|
|
||||||
|
# TODO(CarolinePascal) : sort timestamps ?
|
||||||
|
|
||||||
|
reader.add_basic_audio_stream(
|
||||||
|
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
|
||||||
|
buffer_chunk_size=-1, # No dropping frames
|
||||||
|
format="fltp", # Format as float32
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_chunks = []
|
||||||
|
timestamps = [
|
||||||
|
timestamp + start_time_s for timestamp in timestamps
|
||||||
|
] # Add an offset of start_time_s to each timestamp
|
||||||
|
for ts in timestamps:
|
||||||
|
reader.seek(max(0.0, ts - duration)) # Default to closest audio sample. Needs to be non-negative !
|
||||||
|
status = reader.fill_buffer()
|
||||||
|
if status != 0:
|
||||||
|
# Should not happen, but just in case
|
||||||
|
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
|
||||||
|
|
||||||
|
current_audio_chunk = reader.pop_chunks()[0]
|
||||||
|
current_audio_chunk_data = current_audio_chunk.t() # Channel first format
|
||||||
|
|
||||||
|
# Case where the requested audio chunk starts before the beginning of the audio stream
|
||||||
|
if ts - duration < 0:
|
||||||
|
# No useful audio sample has been recorded
|
||||||
|
if ts < 1 / audio_sample_rate:
|
||||||
|
current_audio_chunk_data = torch.zeros(
|
||||||
|
(audio_channels, int(ceil(duration * audio_sample_rate)))
|
||||||
|
)
|
||||||
|
# At least one useful audio sample has been recorded
|
||||||
|
else:
|
||||||
|
# Remove the superfluous last samples of the audio chunk
|
||||||
|
current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
|
||||||
|
# Pad the beginning of the audio chunk with zeros
|
||||||
|
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||||
|
current_audio_chunk_data = torch.nn.functional.pad(
|
||||||
|
current_audio_chunk_data,
|
||||||
|
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
|
||||||
|
)
|
||||||
|
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
logging.info(
|
||||||
|
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_chunks.append(current_audio_chunk_data)
|
||||||
|
|
||||||
|
audio_chunks = torch.stack(audio_chunks)
|
||||||
|
|
||||||
|
assert len(timestamps) == len(audio_chunks)
|
||||||
|
return audio_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def encode_audio(
|
||||||
|
input_path: Path | str,
|
||||||
|
output_path: Path | str,
|
||||||
|
codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
|
||||||
|
bit_rate: int | None = None,
|
||||||
|
sample_rate: int | None = None,
|
||||||
|
log_level: int | None = av.logging.ERROR,
|
||||||
|
overwrite: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Encodes an audio file using ffmpeg."""
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=overwrite)
|
||||||
|
|
||||||
|
# Set logging level
|
||||||
|
if log_level is not None:
|
||||||
|
# "While less efficient, it is generally preferable to modify logging with Python’s logging"
|
||||||
|
logging.getLogger("libav").setLevel(log_level)
|
||||||
|
|
||||||
|
# Open input file
|
||||||
|
with av.open(str(input_path), "r") as input:
|
||||||
|
input_stream = input.streams.audio[0] # Assuming the first stream is the audio stream to be encoded
|
||||||
|
|
||||||
|
# Define sub-sampling options
|
||||||
|
if sample_rate is None:
|
||||||
|
sample_rate = input_stream.rate
|
||||||
|
|
||||||
|
# Create and open output file (overwrite by default)
|
||||||
|
with av.open(str(output_path), "w") as output:
|
||||||
|
output_stream = output.add_stream(
|
||||||
|
codec, rate=sample_rate, layout=CHANNELS_LAYOUTS_MAPPING[input_stream.channels]
|
||||||
|
)
|
||||||
|
|
||||||
|
if bit_rate is not None:
|
||||||
|
output_stream.bit_rate = bit_rate
|
||||||
|
|
||||||
|
# Loop through input WAV packets and encode them
|
||||||
|
for input_frame in input.decode(
|
||||||
|
input_stream
|
||||||
|
): # This step handles both demuxing and decoding under the hood
|
||||||
|
packet = output_stream.encode(input_frame)
|
||||||
|
if packet:
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
# Flush the encoder
|
||||||
|
packet = output_stream.encode()
|
||||||
|
if packet:
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
# Reset logging level
|
||||||
|
if log_level is not None:
|
||||||
|
av.logging.restore_default_callback()
|
||||||
|
|
||||||
|
if not output_path.exists():
|
||||||
|
raise OSError(f"Audio encoding did not work. File not found: {output_path}.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_info(video_path: Path | str) -> dict:
|
||||||
|
# Set logging level
|
||||||
|
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||||
|
|
||||||
|
# Getting audio stream information
|
||||||
|
audio_info = {}
|
||||||
|
with av.open(str(video_path), "r") as audio_file:
|
||||||
|
try:
|
||||||
|
audio_stream = audio_file.streams.audio[0]
|
||||||
|
except IndexError:
|
||||||
|
# Reset logging level
|
||||||
|
av.logging.restore_default_callback()
|
||||||
|
return {"has_audio": False}
|
||||||
|
|
||||||
|
audio_info["audio.channels"] = audio_stream.channels
|
||||||
|
audio_info["audio.codec"] = audio_stream.codec.canonical_name
|
||||||
|
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
|
||||||
|
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
|
||||||
|
audio_info["audio.bit_rate"] = audio_stream.bit_rate
|
||||||
|
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
|
||||||
|
# In an ideal loseless case : fixed number of bits per sample.
|
||||||
|
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
|
||||||
|
audio_info["audio.bit_depth"] = audio_stream.format.bits
|
||||||
|
audio_info["audio.channel_layout"] = audio_stream.layout.name
|
||||||
|
audio_info["has_audio"] = True
|
||||||
|
|
||||||
|
# Reset logging level
|
||||||
|
av.logging.restore_default_callback()
|
||||||
|
|
||||||
|
return audio_info
|
||||||
@@ -19,8 +19,7 @@ import logging
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.datasets.io_utils import load_image_as_numpy
|
from lerobot.datasets.io_utils import load_audio_from_path, load_image_as_numpy
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
|
||||||
|
|
||||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||||
|
|
||||||
@@ -250,6 +249,20 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def sample_audio_from_path(audio_path: str) -> np.ndarray:
|
||||||
|
"""Samples audio data from an audio recording stored in a WAV file."""
|
||||||
|
data = load_audio_from_path(audio_path)
|
||||||
|
sampled_indices = sample_indices(len(data))
|
||||||
|
|
||||||
|
return data[sampled_indices]
|
||||||
|
|
||||||
|
|
||||||
|
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
|
||||||
|
"""Samples audio data from an audio recording stored in a numpy array."""
|
||||||
|
sampled_indices = sample_indices(len(data))
|
||||||
|
return data[sampled_indices]
|
||||||
|
|
||||||
|
|
||||||
def _reshape_stats_by_axis(
|
def _reshape_stats_by_axis(
|
||||||
stats: dict[str, np.ndarray],
|
stats: dict[str, np.ndarray],
|
||||||
axis: int | tuple[int, ...] | None,
|
axis: int | tuple[int, ...] | None,
|
||||||
@@ -517,6 +530,13 @@ def compute_episode_stats(
|
|||||||
ep_ft_array = sample_images(data)
|
ep_ft_array = sample_images(data)
|
||||||
axes_to_reduce = (0, 2, 3)
|
axes_to_reduce = (0, 2, 3)
|
||||||
keepdims = True
|
keepdims = True
|
||||||
|
elif features[key]["dtype"] == "audio":
|
||||||
|
try:
|
||||||
|
ep_ft_array = sample_audio_from_path(data[0])
|
||||||
|
except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||||
|
ep_ft_array = sample_audio_from_data(data)
|
||||||
|
axes_to_reduce = 0
|
||||||
|
keepdims = True
|
||||||
else:
|
else:
|
||||||
ep_ft_array = data
|
ep_ft_array = data
|
||||||
axes_to_reduce = 0
|
axes_to_reduce = 0
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import pyarrow as pa
|
|||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from lerobot.datasets.audio_utils import get_audio_info
|
||||||
from lerobot.datasets.compute_stats import aggregate_stats
|
from lerobot.datasets.compute_stats import aggregate_stats
|
||||||
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
|
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
|
||||||
from lerobot.datasets.io_utils import (
|
from lerobot.datasets.io_utils import (
|
||||||
@@ -40,6 +41,7 @@ from lerobot.datasets.io_utils import (
|
|||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
DEFAULT_FEATURES,
|
DEFAULT_FEATURES,
|
||||||
|
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
check_version_compatibility,
|
check_version_compatibility,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
@@ -269,6 +271,32 @@ class LeRobotDatasetMetadata:
|
|||||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||||
return Path(fpath)
|
return Path(fpath)
|
||||||
|
|
||||||
|
def get_audio_file_path(self, ep_index: int, audio_key: str) -> Path:
|
||||||
|
"""Return the relative audio file path for the given episode and audio key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ep_index: Zero-based episode index.
|
||||||
|
audio_key: Feature key identifying the audio stream
|
||||||
|
(e.g. ``'observation.audio.microphone'``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the audio file containing this episode's audio.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
IndexError: If ``ep_index`` is out of range.
|
||||||
|
"""
|
||||||
|
if self.episodes is None:
|
||||||
|
self.episodes = load_episodes(self.root)
|
||||||
|
if ep_index >= len(self.episodes):
|
||||||
|
raise IndexError(
|
||||||
|
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||||
|
)
|
||||||
|
ep = self.episodes[ep_index]
|
||||||
|
chunk_idx = ep[f"audio/{audio_key}/chunk_index"]
|
||||||
|
file_idx = ep[f"audio/{audio_key}/file_index"]
|
||||||
|
fpath = self.audio_path.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||||
|
return Path(fpath)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_path(self) -> str:
|
def data_path(self) -> str:
|
||||||
"""Formattable string for the parquet files."""
|
"""Formattable string for the parquet files."""
|
||||||
@@ -279,6 +307,11 @@ class LeRobotDatasetMetadata:
|
|||||||
"""Formattable string for the video files."""
|
"""Formattable string for the video files."""
|
||||||
return self.info["video_path"]
|
return self.info["video_path"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_path(self) -> str | None:
|
||||||
|
"""Formattable string for the audio files."""
|
||||||
|
return self.info["audio_path"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def robot_type(self) -> str | None:
|
def robot_type(self) -> str | None:
|
||||||
"""Robot type used in recording this dataset."""
|
"""Robot type used in recording this dataset."""
|
||||||
@@ -309,6 +342,11 @@ class LeRobotDatasetMetadata:
|
|||||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_keys(self) -> list[str]:
|
||||||
|
"""Keys to access audio modalities."""
|
||||||
|
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self) -> dict[str, list | dict]:
|
def names(self) -> dict[str, list | dict]:
|
||||||
"""Names of the various dimensions of vector modalities."""
|
"""Names of the various dimensions of vector modalities."""
|
||||||
@@ -349,6 +387,11 @@ class LeRobotDatasetMetadata:
|
|||||||
"""Max size of video file in mega bytes."""
|
"""Max size of video file in mega bytes."""
|
||||||
return self.info["video_files_size_in_mb"]
|
return self.info["video_files_size_in_mb"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_files_size_in_mb(self) -> int:
|
||||||
|
"""Max size of audio file in mega bytes."""
|
||||||
|
return self.info["audio_files_size_in_mb"]
|
||||||
|
|
||||||
def get_task_index(self, task: str) -> int | None:
|
def get_task_index(self, task: str) -> int | None:
|
||||||
"""
|
"""
|
||||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||||
@@ -515,11 +558,27 @@ class LeRobotDatasetMetadata:
|
|||||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||||
|
|
||||||
|
def update_audio_info(self, audio_key: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
|
||||||
|
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||||
|
"""
|
||||||
|
if audio_key is not None and audio_key not in self.audio_keys:
|
||||||
|
raise ValueError(f"Audio key {audio_key} not found in dataset")
|
||||||
|
|
||||||
|
audio_keys = [audio_key] if audio_key is not None else self.audio_keys
|
||||||
|
for key in audio_keys:
|
||||||
|
if not self.features[key].get("info", None):
|
||||||
|
audio_path = self.root / self.audio_path.format(audio_key=key, chunk_index=0, file_index=0)
|
||||||
|
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
||||||
|
self.info["features"][key]["info"]["start_time_s"] = DEFAULT_INITIAL_AUDIO_BUFFER_DURATION
|
||||||
|
|
||||||
def update_chunk_settings(
|
def update_chunk_settings(
|
||||||
self,
|
self,
|
||||||
chunks_size: int | None = None,
|
chunks_size: int | None = None,
|
||||||
data_files_size_in_mb: int | None = None,
|
data_files_size_in_mb: int | None = None,
|
||||||
video_files_size_in_mb: int | None = None,
|
video_files_size_in_mb: int | None = None,
|
||||||
|
audio_files_size_in_mb: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update chunk and file size settings after dataset creation.
|
"""Update chunk and file size settings after dataset creation.
|
||||||
|
|
||||||
@@ -531,6 +590,7 @@ class LeRobotDatasetMetadata:
|
|||||||
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
||||||
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
||||||
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
||||||
|
audio_files_size_in_mb: Maximum size for audio files in MB. If None, keeps current value.
|
||||||
"""
|
"""
|
||||||
if chunks_size is not None:
|
if chunks_size is not None:
|
||||||
if chunks_size <= 0:
|
if chunks_size <= 0:
|
||||||
@@ -547,6 +607,11 @@ class LeRobotDatasetMetadata:
|
|||||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||||
|
|
||||||
|
if audio_files_size_in_mb is not None:
|
||||||
|
if audio_files_size_in_mb <= 0:
|
||||||
|
raise ValueError(f"audio_files_size_in_mb must be positive, got {audio_files_size_in_mb}")
|
||||||
|
self.info["audio_files_size_in_mb"] = audio_files_size_in_mb
|
||||||
|
|
||||||
# Update the info file on disk
|
# Update the info file on disk
|
||||||
write_info(self.info, self.root)
|
write_info(self.info, self.root)
|
||||||
|
|
||||||
@@ -554,12 +619,13 @@ class LeRobotDatasetMetadata:
|
|||||||
"""Get current chunk and file size settings.
|
"""Get current chunk and file size settings.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
Dict containing chunks_size, data_files_size_in_mb, video_files_size_in_mb, and audio_files_size_in_mb.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"chunks_size": self.chunks_size,
|
"chunks_size": self.chunks_size,
|
||||||
"data_files_size_in_mb": self.data_files_size_in_mb,
|
"data_files_size_in_mb": self.data_files_size_in_mb,
|
||||||
"video_files_size_in_mb": self.video_files_size_in_mb,
|
"video_files_size_in_mb": self.video_files_size_in_mb,
|
||||||
|
"audio_files_size_in_mb": self.audio_files_size_in_mb,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -586,6 +652,7 @@ class LeRobotDatasetMetadata:
|
|||||||
chunks_size: int | None = None,
|
chunks_size: int | None = None,
|
||||||
data_files_size_in_mb: int | None = None,
|
data_files_size_in_mb: int | None = None,
|
||||||
video_files_size_in_mb: int | None = None,
|
video_files_size_in_mb: int | None = None,
|
||||||
|
audio_files_size_in_mb: int | None = None,
|
||||||
) -> "LeRobotDatasetMetadata":
|
) -> "LeRobotDatasetMetadata":
|
||||||
"""Create metadata for a new LeRobot dataset from scratch.
|
"""Create metadata for a new LeRobot dataset from scratch.
|
||||||
|
|
||||||
@@ -636,6 +703,7 @@ class LeRobotDatasetMetadata:
|
|||||||
chunks_size,
|
chunks_size,
|
||||||
data_files_size_in_mb,
|
data_files_size_in_mb,
|
||||||
video_files_size_in_mb,
|
video_files_size_in_mb,
|
||||||
|
audio_files_size_in_mb,
|
||||||
)
|
)
|
||||||
if len(obj.video_keys) > 0 and not use_videos:
|
if len(obj.video_keys) > 0 and not use_videos:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -21,12 +21,14 @@ from pathlib import Path
|
|||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.datasets.audio_utils import decode_audio
|
||||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||||
from lerobot.datasets.feature_utils import (
|
from lerobot.datasets.feature_utils import (
|
||||||
check_delta_timestamps,
|
check_delta_timestamps,
|
||||||
get_delta_indices,
|
get_delta_indices,
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
)
|
)
|
||||||
|
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||||
from lerobot.datasets.io_utils import (
|
from lerobot.datasets.io_utils import (
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
load_nested_dataset,
|
load_nested_dataset,
|
||||||
@@ -130,7 +132,7 @@ class DatasetReader:
|
|||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
def _check_cached_episodes_sufficient(self) -> bool:
|
def _check_cached_episodes_sufficient(self) -> bool:
|
||||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
"""Check if the cached dataset contains all requested episodes and their video and audio files."""
|
||||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -154,6 +156,13 @@ class DatasetReader:
|
|||||||
if not video_path.exists():
|
if not video_path.exists():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if len(self._meta.audio_keys) > 0:
|
||||||
|
for ep_idx in requested_episodes:
|
||||||
|
for audio_key in self._meta.audio_keys:
|
||||||
|
audio_path = self.root / self._meta.get_compressed_audio_file_path(ep_idx, audio_key)
|
||||||
|
if not audio_path.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_episodes_file_paths(self) -> list[Path]:
|
def get_episodes_file_paths(self) -> list[Path]:
|
||||||
@@ -170,6 +179,15 @@ class DatasetReader:
|
|||||||
for ep_idx in episodes
|
for ep_idx in episodes
|
||||||
]
|
]
|
||||||
fpaths += video_files
|
fpaths += video_files
|
||||||
|
|
||||||
|
if len(self._meta.audio_keys) > 0:
|
||||||
|
audio_files = [
|
||||||
|
str(self._meta.get_compressed_audio_file_path(ep_idx, audio_key))
|
||||||
|
for audio_key in self._meta.audio_keys
|
||||||
|
for ep_idx in episodes
|
||||||
|
]
|
||||||
|
fpaths += audio_files
|
||||||
|
|
||||||
# episodes are stored in the same files, so we return unique paths only
|
# episodes are stored in the same files, so we return unique paths only
|
||||||
fpaths = list(set(fpaths))
|
fpaths = list(set(fpaths))
|
||||||
return fpaths
|
return fpaths
|
||||||
@@ -199,7 +217,7 @@ class DatasetReader:
|
|||||||
query_indices: dict[str, list[int]] | None = None,
|
query_indices: dict[str, list[int]] | None = None,
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query_timestamps = {}
|
query_timestamps = {}
|
||||||
for key in self._meta.video_keys:
|
for key in self._meta.video_keys + self._meta.audio_keys:
|
||||||
if query_indices is not None and key in query_indices:
|
if query_indices is not None and key in query_indices:
|
||||||
if self._absolute_to_relative_idx is not None:
|
if self._absolute_to_relative_idx is not None:
|
||||||
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
|
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
|
||||||
@@ -213,10 +231,10 @@ class DatasetReader:
|
|||||||
return query_timestamps
|
return query_timestamps
|
||||||
|
|
||||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
"""Query dataset for indices across keys, skipping video keys."""
|
"""Query dataset for indices across keys, skipping video and audio keys."""
|
||||||
result: dict = {}
|
result: dict = {}
|
||||||
for key, q_idx in query_indices.items():
|
for key, q_idx in query_indices.items():
|
||||||
if key in self._meta.video_keys:
|
if key in self._meta.video_keys or key in self._meta.audio_keys:
|
||||||
continue
|
continue
|
||||||
relative_indices = (
|
relative_indices = (
|
||||||
q_idx
|
q_idx
|
||||||
@@ -246,6 +264,28 @@ class DatasetReader:
|
|||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
# TODO(CarolinePascal): add variable query durations
|
||||||
|
def _query_audio(
|
||||||
|
self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
ep = self.meta.episodes[ep_idx]
|
||||||
|
item = {}
|
||||||
|
for audio_key, query_ts in query_timestamps.items():
|
||||||
|
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
|
||||||
|
# Thus we load the start timestamp of the episode on this mp4 and,
|
||||||
|
# shift the query timestamp accordingly.
|
||||||
|
from_timestamp = ep[f"audio/{audio_key}/from_timestamp"]
|
||||||
|
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||||
|
|
||||||
|
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
|
||||||
|
start_time_s = self.meta.features[audio_key]["info"].get("start_time_s", 0.0)
|
||||||
|
audio_chunk = decode_audio(
|
||||||
|
audio_path, shifted_query_ts, query_duration, start_time_s, self.audio_backend
|
||||||
|
)
|
||||||
|
item[audio_key] = audio_chunk.squeeze(0)
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
def get_item(self, idx) -> dict:
|
def get_item(self, idx) -> dict:
|
||||||
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
|
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
|
||||||
|
|
||||||
@@ -265,11 +305,12 @@ class DatasetReader:
|
|||||||
for key, val in query_result.items():
|
for key, val in query_result.items():
|
||||||
item[key] = val
|
item[key] = val
|
||||||
|
|
||||||
if len(self._meta.video_keys) > 0:
|
if len(self._meta.video_keys) > 0 or len(self._meta.audio_keys) > 0:
|
||||||
current_ts = item["timestamp"].item()
|
current_ts = item["timestamp"].item()
|
||||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||||
item = {**video_frames, **item}
|
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
|
||||||
|
item = {**video_frames, **audio_chunks, **item}
|
||||||
|
|
||||||
if self._image_transforms is not None:
|
if self._image_transforms is not None:
|
||||||
image_keys = self._meta.camera_keys
|
image_keys = self._meta.camera_keys
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import PIL.Image
|
|||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.datasets.audio_utils import encode_audio
|
||||||
from lerobot.datasets.compute_stats import compute_episode_stats
|
from lerobot.datasets.compute_stats import compute_episode_stats
|
||||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||||
from lerobot.datasets.feature_utils import (
|
from lerobot.datasets.feature_utils import (
|
||||||
@@ -48,14 +49,17 @@ from lerobot.datasets.io_utils import (
|
|||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
DEFAULT_IMAGE_PATH,
|
DEFAULT_IMAGE_PATH,
|
||||||
|
DEFAULT_RAW_AUDIO_PATH,
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.video_utils import (
|
from lerobot.datasets.video_utils import (
|
||||||
StreamingVideoEncoder,
|
StreamingVideoEncoder,
|
||||||
concatenate_video_files,
|
concatenate_media_files,
|
||||||
encode_video_frames,
|
encode_video_frames,
|
||||||
get_video_duration_in_s,
|
get_media_duration_in_s,
|
||||||
)
|
)
|
||||||
|
from lerobot.microphones.microphone import Microphone
|
||||||
|
from lerobot.microphones.utils import async_microphones_start_recording
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -144,6 +148,10 @@ class DatasetWriter:
|
|||||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||||
|
|
||||||
|
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
||||||
|
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
|
||||||
|
return self._root / fpath
|
||||||
|
|
||||||
def _save_image(
|
def _save_image(
|
||||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
|
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -208,11 +216,43 @@ class DatasetWriter:
|
|||||||
compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6
|
compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6
|
||||||
self._save_image(frame[key], img_path, compress_level)
|
self._save_image(frame[key], img_path, compress_level)
|
||||||
self.episode_buffer[key].append(str(img_path))
|
self.episode_buffer[key].append(str(img_path))
|
||||||
|
elif self._meta.features[key]["dtype"] == "audio":
|
||||||
|
if (
|
||||||
|
self._meta.robot_type == "lekiwi"
|
||||||
|
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||||
|
self.episode_buffer[key].append(frame[key])
|
||||||
|
else: # Otherwise, only the audio file path is stored in the episode buffer
|
||||||
|
if frame_index == 0:
|
||||||
|
audio_path = self._get_raw_audio_file_path(
|
||||||
|
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
||||||
|
)
|
||||||
|
self.episode_buffer[key].append(str(audio_path))
|
||||||
else:
|
else:
|
||||||
self.episode_buffer[key].append(frame[key])
|
self.episode_buffer[key].append(frame[key])
|
||||||
|
|
||||||
self.episode_buffer["size"] += 1
|
self.episode_buffer["size"] += 1
|
||||||
|
|
||||||
|
def add_microphone_recording(self, microphone_key: str, microphone: Microphone) -> None:
|
||||||
|
"""
|
||||||
|
Starts recording audio data provided by the microphone and directly writes it in a .wav file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
audio_file = self._get_raw_audio_file_path(self._meta.total_episodes, "observation.audio." + microphone_key)
|
||||||
|
microphone.start_recording(output_file=audio_file)
|
||||||
|
|
||||||
|
def add_microphones_recordings(self, microphones: dict[str, Microphone]) -> None:
|
||||||
|
"""
|
||||||
|
Starts recording audio data provided by multiple microphones and directly writes it in appropriate .wav files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_files = []
|
||||||
|
for microphone_key in microphones:
|
||||||
|
output_files.append(
|
||||||
|
self._get_raw_audio_file_path(self._meta.total_episodes, "observation.audio." + microphone_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
async_microphones_start_recording(microphones, output_files)
|
||||||
|
|
||||||
def save_episode(
|
def save_episode(
|
||||||
self,
|
self,
|
||||||
episode_data: dict | None = None,
|
episode_data: dict | None = None,
|
||||||
@@ -241,12 +281,19 @@ class DatasetWriter:
|
|||||||
for key, ft in self._meta.features.items():
|
for key, ft in self._meta.features.items():
|
||||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||||
continue
|
continue
|
||||||
|
elif ft["dtype"] == "audio":
|
||||||
|
if (
|
||||||
|
self._meta.robot_type == "lekiwi"
|
||||||
|
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||||
|
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
|
||||||
|
continue
|
||||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||||
|
|
||||||
# Wait for image writer to end, so that episode stats over images can be computed
|
# Wait for image writer to end, so that episode stats over images can be computed
|
||||||
self._wait_image_writer()
|
self._wait_image_writer()
|
||||||
|
|
||||||
has_video_keys = len(self._meta.video_keys) > 0
|
has_video_keys = len(self._meta.video_keys) > 0
|
||||||
|
has_audio_keys = len(self._meta.audio_keys) > 0
|
||||||
use_streaming = self._streaming_encoder is not None and has_video_keys
|
use_streaming = self._streaming_encoder is not None and has_video_keys
|
||||||
use_batched_encoding = self._batch_encoding_size > 1
|
use_batched_encoding = self._batch_encoding_size > 1
|
||||||
|
|
||||||
@@ -273,7 +320,7 @@ class DatasetWriter:
|
|||||||
for k, v in video_stats.items()
|
for k, v in video_stats.items()
|
||||||
}
|
}
|
||||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||||
elif has_video_keys and not use_batched_encoding:
|
elif (has_video_keys or has_audio_keys) and not use_batched_encoding:
|
||||||
num_cameras = len(self._meta.video_keys)
|
num_cameras = len(self._meta.video_keys)
|
||||||
if parallel_encoding and num_cameras > 1:
|
if parallel_encoding and num_cameras > 1:
|
||||||
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
|
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
|
||||||
@@ -309,19 +356,28 @@ class DatasetWriter:
|
|||||||
for video_key in self._meta.video_keys:
|
for video_key in self._meta.video_keys:
|
||||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||||
|
|
||||||
|
# TODO(Caroline): add parallel encoding for audio as well
|
||||||
|
for audio_key in self._meta.audio_keys:
|
||||||
|
ep_metadata.update(self._save_episode_audio(audio_key, episode_index))
|
||||||
|
|
||||||
# `meta.save_episode` need to be executed after encoding the videos
|
# `meta.save_episode` need to be executed after encoding the videos
|
||||||
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||||
|
|
||||||
if has_video_keys and use_batched_encoding:
|
if (has_video_keys or has_audio_keys) and use_batched_encoding:
|
||||||
self._episodes_since_last_encoding += 1
|
self._episodes_since_last_encoding += 1
|
||||||
if self._episodes_since_last_encoding == self._batch_encoding_size:
|
if self._episodes_since_last_encoding == self._batch_encoding_size:
|
||||||
start_ep = self._meta.total_episodes - self._batch_encoding_size
|
start_ep = self._meta.total_episodes - self._batch_encoding_size
|
||||||
end_ep = self._meta.total_episodes
|
end_ep = self._meta.total_episodes
|
||||||
self._batch_save_episode_video(start_ep, end_ep)
|
if has_video_keys:
|
||||||
|
self._batch_save_episode_video(start_ep, end_ep)
|
||||||
|
if has_audio_keys:
|
||||||
|
self._batch_save_episode_audio(start_ep, end_ep)
|
||||||
self._episodes_since_last_encoding = 0
|
self._episodes_since_last_encoding = 0
|
||||||
|
|
||||||
if episode_data is None:
|
if episode_data is None:
|
||||||
self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
|
self.clear_episode_buffer(
|
||||||
|
delete_images=len(self._meta.image_keys) > 0, delete_audio=len(self._meta.audio_keys) > 0
|
||||||
|
)
|
||||||
|
|
||||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
|
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||||
"""Batch save videos for multiple episodes."""
|
"""Batch save videos for multiple episodes."""
|
||||||
@@ -368,6 +424,59 @@ class DatasetWriter:
|
|||||||
episode_df.to_parquet(episode_df_path)
|
episode_df.to_parquet(episode_df_path)
|
||||||
self._meta.episodes = load_episodes(self._root)
|
self._meta.episodes = load_episodes(self._root)
|
||||||
|
|
||||||
|
def _batch_save_episode_audio(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Batch save audio for multiple episodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_episode: Starting episode index (inclusive)
|
||||||
|
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
|
||||||
|
"""
|
||||||
|
if end_episode is None:
|
||||||
|
end_episode = self._meta.total_episodes
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Batch encoding {self.batch_encoding_size} audio for episodes {start_episode} to {end_episode - 1}"
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"]
|
||||||
|
file_idx = self._meta.episodes[start_episode]["data/file_index"]
|
||||||
|
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
|
episode_df = pd.read_parquet(episode_df_path)
|
||||||
|
|
||||||
|
for ep_idx in range(start_episode, end_episode):
|
||||||
|
logging.info(f"Encoding audio for episode {ep_idx}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
|
||||||
|
or self._meta.episodes[ep_idx]["data/file_index"] != file_idx
|
||||||
|
):
|
||||||
|
# The current episode is in a new chunk or file.
|
||||||
|
# Save previous episode dataframe and update the Hugging Face dataset by reloading it.
|
||||||
|
episode_df.to_parquet(episode_df_path)
|
||||||
|
self._meta.episodes = load_episodes(self._root)
|
||||||
|
|
||||||
|
# Load new episode dataframe
|
||||||
|
chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"]
|
||||||
|
file_idx = self._meta.episodes[ep_idx]["data/file_index"]
|
||||||
|
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
|
||||||
|
chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
episode_df = pd.read_parquet(episode_df_path)
|
||||||
|
|
||||||
|
# Save the current episode's video metadata to the dataframe
|
||||||
|
audio_ep_metadata = {}
|
||||||
|
for audio_key in self._meta.audio_keys:
|
||||||
|
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
|
||||||
|
audio_ep_metadata.pop("episode_index")
|
||||||
|
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
|
||||||
|
dtype_backend="pyarrow"
|
||||||
|
) # allows NaN values along with integers
|
||||||
|
|
||||||
|
episode_df = episode_df.combine_first(audio_ep_df)
|
||||||
|
episode_df.to_parquet(episode_df_path)
|
||||||
|
self._meta.episodes = load_episodes(self._root)
|
||||||
|
|
||||||
def _save_episode_data(self, episode_buffer: dict) -> dict:
|
def _save_episode_data(self, episode_buffer: dict) -> dict:
|
||||||
"""Save episode data to a parquet file."""
|
"""Save episode data to a parquet file."""
|
||||||
# Use metadata features as the authoritative schema
|
# Use metadata features as the authoritative schema
|
||||||
@@ -445,7 +554,7 @@ class DatasetWriter:
|
|||||||
ep_path = temp_path
|
ep_path = temp_path
|
||||||
|
|
||||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
episode_index == 0
|
episode_index == 0
|
||||||
@@ -485,7 +594,7 @@ class DatasetWriter:
|
|||||||
shutil.move(str(ep_path), str(new_path))
|
shutil.move(str(ep_path), str(new_path))
|
||||||
latest_duration_in_s = 0.0
|
latest_duration_in_s = 0.0
|
||||||
else:
|
else:
|
||||||
concatenate_video_files(
|
concatenate_media_files(
|
||||||
[latest_path, ep_path],
|
[latest_path, ep_path],
|
||||||
latest_path,
|
latest_path,
|
||||||
)
|
)
|
||||||
@@ -507,7 +616,91 @@ class DatasetWriter:
|
|||||||
}
|
}
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
def _encode_temporary_episode_audio(self, audio_key: str, episode_index: int) -> Path:
|
||||||
|
"""
|
||||||
|
Use ffmpeg to convert raw audio files into m4a audio files.
|
||||||
|
Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||||
|
since audio encoding with ffmpeg is already using multithreading.
|
||||||
|
"""
|
||||||
|
temp_path = Path(tempfile.mkdtemp(dir=self._root)) / f"{audio_key}_{episode_index:03d}.m4a"
|
||||||
|
raw_audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
|
||||||
|
encode_audio(raw_audio_file, temp_path, overwrite=True)
|
||||||
|
raw_audio_file.unlink()
|
||||||
|
return temp_path
|
||||||
|
|
||||||
|
def _save_episode_audio(self, audio_key: str, episode_index: int) -> dict:
|
||||||
|
# Encode episode audio into a temporary audio file
|
||||||
|
ep_path = self._encode_temporary_episode_audio(audio_key, episode_index)
|
||||||
|
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||||
|
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")
|
||||||
|
|
||||||
|
if (
|
||||||
|
episode_index == 0
|
||||||
|
or self._meta.latest_episode is None
|
||||||
|
or f"audio/{audio_key}/chunk_index" not in self._meta.latest_episode
|
||||||
|
):
|
||||||
|
# Initialize indices for a new dataset made of the first episode data
|
||||||
|
chunk_idx, file_idx = 0, 0
|
||||||
|
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
|
||||||
|
# It means we are resuming recording, so we need to load the latest episode
|
||||||
|
# Update the indices to avoid overwriting the latest episode
|
||||||
|
old_chunk_idx = self._meta.episodes[-1][f"audio/{audio_key}/chunk_index"]
|
||||||
|
old_file_idx = self._meta.episodes[-1][f"audio/{audio_key}/file_index"]
|
||||||
|
chunk_idx, file_idx = update_chunk_file_indices(
|
||||||
|
old_chunk_idx, old_file_idx, self._meta.chunks_size
|
||||||
|
)
|
||||||
|
latest_duration_in_s = 0.0
|
||||||
|
new_path = self._root / self._meta.audio_path.format(
|
||||||
|
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.move(str(ep_path), str(new_path))
|
||||||
|
else:
|
||||||
|
# Retrieve information from the latest updated audio file using latest_episode
|
||||||
|
latest_ep = self._meta.latest_episode
|
||||||
|
chunk_idx = latest_ep[f"audio/{audio_key}/chunk_index"][0]
|
||||||
|
file_idx = latest_ep[f"audio/{audio_key}/file_index"][0]
|
||||||
|
|
||||||
|
latest_path = self._root / self._meta.audio_path.format(
|
||||||
|
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||||
|
latest_duration_in_s = latest_ep[f"audio/{audio_key}/to_timestamp"][0]
|
||||||
|
|
||||||
|
if latest_size_in_mb + ep_size_in_mb >= self._meta.audio_files_size_in_mb:
|
||||||
|
# Move temporary episode audio to a new audio file in the dataset
|
||||||
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
|
||||||
|
new_path = self._root / self._meta.audio_path.format(
|
||||||
|
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.move(str(ep_path), str(new_path))
|
||||||
|
latest_duration_in_s = 0.0
|
||||||
|
else:
|
||||||
|
# Update latest audio file
|
||||||
|
concatenate_media_files(
|
||||||
|
[latest_path, ep_path],
|
||||||
|
latest_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove temporary directory
|
||||||
|
shutil.rmtree(str(ep_path.parent))
|
||||||
|
|
||||||
|
# Update audio info (only needed when first episode is encoded since it reads from episode 0)
|
||||||
|
if episode_index == 0:
|
||||||
|
self._meta.update_audio_info(audio_key)
|
||||||
|
write_info(self._meta.info, self._meta.root) # ensure audio info always written properly
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"episode_index": episode_index,
|
||||||
|
f"audio/{audio_key}/chunk_index": chunk_idx,
|
||||||
|
f"audio/{audio_key}/file_index": file_idx,
|
||||||
|
f"audio/{audio_key}/from_timestamp": latest_duration_in_s,
|
||||||
|
f"audio/{audio_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def clear_episode_buffer(self, delete_images: bool = True, delete_audio: bool = True) -> None:
|
||||||
"""Discard the current episode buffer and optionally delete temp images.
|
"""Discard the current episode buffer and optionally delete temp images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -531,6 +724,15 @@ class DatasetWriter:
|
|||||||
if img_dir.is_dir():
|
if img_dir.is_dir():
|
||||||
shutil.rmtree(img_dir)
|
shutil.rmtree(img_dir)
|
||||||
|
|
||||||
|
if delete_audio:
|
||||||
|
episode_index = self.episode_buffer["episode_index"]
|
||||||
|
if isinstance(episode_index, np.ndarray):
|
||||||
|
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||||
|
for audio_key in self._meta.audio_keys:
|
||||||
|
audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
|
||||||
|
if audio_file.is_file():
|
||||||
|
audio_file.unlink()
|
||||||
|
|
||||||
self.episode_buffer = self._create_episode_buffer()
|
self.episode_buffer = self._create_episode_buffer()
|
||||||
|
|
||||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
||||||
@@ -596,7 +798,7 @@ class DatasetWriter:
|
|||||||
self._streaming_encoder.cancel_episode()
|
self._streaming_encoder.cancel_episode()
|
||||||
|
|
||||||
def cleanup_interrupted_episode(self, episode_index: int) -> None:
|
def cleanup_interrupted_episode(self, episode_index: int) -> None:
|
||||||
"""Remove temporary image directories for an interrupted episode."""
|
"""Remove temporary image and audio directories for an interrupted episode."""
|
||||||
for key in self._meta.video_keys:
|
for key in self._meta.video_keys:
|
||||||
img_dir = self._get_image_file_path(
|
img_dir = self._get_image_file_path(
|
||||||
episode_index=episode_index, image_key=key, frame_index=0
|
episode_index=episode_index, image_key=key, frame_index=0
|
||||||
@@ -607,6 +809,14 @@ class DatasetWriter:
|
|||||||
)
|
)
|
||||||
shutil.rmtree(img_dir)
|
shutil.rmtree(img_dir)
|
||||||
|
|
||||||
|
for key in self._meta.audio_keys:
|
||||||
|
audio_file = self._get_raw_audio_file_path(episode_index=episode_index, audio_key=key)
|
||||||
|
if audio_file.exists():
|
||||||
|
logger.debug(
|
||||||
|
f"Cleaning up interrupted episode audio for episode {episode_index}, microphone {key}"
|
||||||
|
)
|
||||||
|
audio_file.unlink()
|
||||||
|
|
||||||
def finalize(self) -> None:
|
def finalize(self) -> None:
|
||||||
"""Flush all pending work and release all resources.
|
"""Flush all pending work and release all resources.
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ from PIL import Image as PILImage
|
|||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
|
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_AUDIO_PATH,
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
@@ -47,7 +49,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
|||||||
"""
|
"""
|
||||||
hf_features = {}
|
hf_features = {}
|
||||||
for key, ft in features.items():
|
for key, ft in features.items():
|
||||||
if ft["dtype"] == "video":
|
if ft["dtype"] == "video" or ft["dtype"] == "audio":
|
||||||
continue
|
continue
|
||||||
elif ft["dtype"] == "image":
|
elif ft["dtype"] == "image":
|
||||||
hf_features[key] = datasets.Image()
|
hf_features[key] = datasets.Image()
|
||||||
@@ -110,7 +112,12 @@ def hw_to_dataset_features(
|
|||||||
for key, ftype in hw_features.items()
|
for key, ftype in hw_features.items()
|
||||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||||
}
|
}
|
||||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
cam_fts = {
|
||||||
|
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 3
|
||||||
|
}
|
||||||
|
mic_fts = {
|
||||||
|
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 2
|
||||||
|
}
|
||||||
|
|
||||||
if joint_fts and prefix == ACTION:
|
if joint_fts and prefix == ACTION:
|
||||||
features[prefix] = {
|
features[prefix] = {
|
||||||
@@ -133,6 +140,14 @@ def hw_to_dataset_features(
|
|||||||
"names": ["height", "width", "channels"],
|
"names": ["height", "width", "channels"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for key, parameters in mic_fts.items():
|
||||||
|
features[f"{prefix}.audio.{key}"] = {
|
||||||
|
"dtype": "audio",
|
||||||
|
"shape": (len(parameters[1]),),
|
||||||
|
"names": ["channels"],
|
||||||
|
"info": {"sample_rate": parameters[0]},
|
||||||
|
}
|
||||||
|
|
||||||
_validate_feature_names(features)
|
_validate_feature_names(features)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
@@ -162,6 +177,8 @@ def build_dataset_frame(
|
|||||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||||
elif ft["dtype"] in ["image", "video"]:
|
elif ft["dtype"] in ["image", "video"]:
|
||||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||||
|
elif ft["dtype"] == "audio":
|
||||||
|
frame[key] = values[key.removeprefix(f"{prefix}.audio.")]
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
@@ -195,6 +212,10 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
shape = (shape[2], shape[0], shape[1])
|
shape = (shape[2], shape[0], shape[1])
|
||||||
|
elif ft["dtype"] == "audio":
|
||||||
|
type = FeatureType.AUDIO
|
||||||
|
if len(shape) != 2:
|
||||||
|
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
|
||||||
elif key == OBS_ENV_STATE:
|
elif key == OBS_ENV_STATE:
|
||||||
type = FeatureType.ENV
|
type = FeatureType.ENV
|
||||||
elif key.startswith(OBS_STR):
|
elif key.startswith(OBS_STR):
|
||||||
@@ -273,6 +294,7 @@ def create_empty_dataset_info(
|
|||||||
chunks_size: int | None = None,
|
chunks_size: int | None = None,
|
||||||
data_files_size_in_mb: int | None = None,
|
data_files_size_in_mb: int | None = None,
|
||||||
video_files_size_in_mb: int | None = None,
|
video_files_size_in_mb: int | None = None,
|
||||||
|
audio_files_size_in_mb: int | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Create a template dictionary for a new dataset's `info.json`.
|
"""Create a template dictionary for a new dataset's `info.json`.
|
||||||
|
|
||||||
@@ -282,7 +304,10 @@ def create_empty_dataset_info(
|
|||||||
features (dict): The LeRobot features dictionary for the dataset.
|
features (dict): The LeRobot features dictionary for the dataset.
|
||||||
use_videos (bool): Whether the dataset will store videos.
|
use_videos (bool): Whether the dataset will store videos.
|
||||||
robot_type (str | None): The type of robot used, if any.
|
robot_type (str | None): The type of robot used, if any.
|
||||||
|
chunks_size (int | None): The number of files per chunk.
|
||||||
|
data_files_size_in_mb (int | None): The maximum size per data file in MB.
|
||||||
|
video_files_size_in_mb (int | None): The maximum size per video file in MB.
|
||||||
|
audio_files_size_in_mb (int | None): The maximum size per audio file in MB.
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary with the initial dataset metadata.
|
dict: A dictionary with the initial dataset metadata.
|
||||||
"""
|
"""
|
||||||
@@ -295,10 +320,12 @@ def create_empty_dataset_info(
|
|||||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
|
"audio_files_size_in_mb": audio_files_size_in_mb or DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"splits": {},
|
"splits": {},
|
||||||
"data_path": DEFAULT_DATA_PATH,
|
"data_path": DEFAULT_DATA_PATH,
|
||||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||||
|
"audio_path": DEFAULT_AUDIO_PATH,
|
||||||
"features": features,
|
"features": features,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -435,6 +462,8 @@ def validate_feature_dtype_and_shape(
|
|||||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||||
elif expected_dtype in ["image", "video"]:
|
elif expected_dtype in ["image", "video"]:
|
||||||
return validate_feature_image_or_video(name, expected_shape, value)
|
return validate_feature_image_or_video(name, expected_shape, value)
|
||||||
|
elif expected_dtype == "audio":
|
||||||
|
return validate_feature_audio(name, expected_shape, value)
|
||||||
elif expected_dtype == "string":
|
elif expected_dtype == "string":
|
||||||
return validate_feature_string(name, value)
|
return validate_feature_string(name, value)
|
||||||
else:
|
else:
|
||||||
@@ -501,6 +530,33 @@ def validate_feature_image_or_video(
|
|||||||
return error_message
|
return error_message
|
||||||
|
|
||||||
|
|
||||||
|
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
|
||||||
|
"""Validate a feature that is expected to be an audio frame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the feature.
|
||||||
|
expected_shape (list[str]): The expected shape (C,).
|
||||||
|
value: The audio data to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: An error message if validation fails, otherwise an empty string.
|
||||||
|
"""
|
||||||
|
error_message = ""
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
actual_shape = value.shape
|
||||||
|
c = expected_shape
|
||||||
|
if (len(actual_shape) != 2 and len(actual_shape) != 1) or actual_shape[-1] != c[
|
||||||
|
-1
|
||||||
|
]: # The number of frames might be different
|
||||||
|
error_message += (
|
||||||
|
f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
|
||||||
|
|
||||||
|
return error_message
|
||||||
|
|
||||||
|
|
||||||
def validate_feature_string(name: str, value: str) -> str:
|
def validate_feature_string(name: str, value: str) -> str:
|
||||||
"""Validate a feature that is expected to be a string.
|
"""Validate a feature that is expected to be a string.
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import pandas
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pyarrow.dataset as pa_ds
|
import pyarrow.dataset as pa_ds
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from datasets.table import embed_table_storage
|
from datasets.table import embed_table_storage
|
||||||
@@ -280,6 +281,24 @@ def load_image_as_numpy(
|
|||||||
return img_array
|
return img_array
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
|
||||||
|
"""Load an audio file from a path into a numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fpath (str | Path): Path to the audio file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: The audio as a numpy array.
|
||||||
|
"""
|
||||||
|
audio_data, _ = sf.read(fpath, dtype="float32")
|
||||||
|
|
||||||
|
# Fill missing channel dimension when loading mono audio data
|
||||||
|
if audio_data.ndim == 1:
|
||||||
|
audio_data = np.expand_dims(audio_data, axis=1)
|
||||||
|
|
||||||
|
return audio_data
|
||||||
|
|
||||||
|
|
||||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
force_cache_sync: bool = False,
|
force_cache_sync: bool = False,
|
||||||
download_videos: bool = True,
|
download_videos: bool = True,
|
||||||
|
download_audio: bool = True,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
|
audio_backend: str | None = None,
|
||||||
batch_encoding_size: int = 1,
|
batch_encoding_size: int = 1,
|
||||||
vcodec: str = "libsvtav1",
|
vcodec: str = "libsvtav1",
|
||||||
streaming_encoding: bool = False,
|
streaming_encoding: bool = False,
|
||||||
@@ -91,6 +93,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
task-conditioned training.
|
task-conditioned training.
|
||||||
- data (backed by datasets.Dataset), which reads values from parquet files.
|
- data (backed by datasets.Dataset), which reads values from parquet files.
|
||||||
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
||||||
|
- audio (optional) from which audio is loaded to be synchronous with data from parquet files.
|
||||||
|
|
||||||
A typical LeRobotDataset looks like this from its root path:
|
A typical LeRobotDataset looks like this from its root path:
|
||||||
.
|
.
|
||||||
@@ -116,19 +119,37 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
│ ├── info.json
|
│ ├── info.json
|
||||||
│ ├── stats.json
|
│ ├── stats.json
|
||||||
│ └── tasks.parquet
|
│ └── tasks.parquet
|
||||||
└── videos
|
├── videos
|
||||||
├── observation.images.laptop
|
│ ├── observation.images.laptop
|
||||||
|
│ │ ├── chunk-000
|
||||||
|
│ │ │ ├── file-000.mp4
|
||||||
|
│ │ │ ├── file-001.mp4
|
||||||
|
│ │ │ └── ...
|
||||||
|
│ │ ├── chunk-001
|
||||||
|
│ │ │ └── ...
|
||||||
|
│ │ └── ...
|
||||||
|
│ ├── observation.images.phone
|
||||||
|
│ │ ├── chunk-000
|
||||||
|
│ │ │ ├── file-000.mp4
|
||||||
|
│ │ │ ├── file-001.mp4
|
||||||
|
│ │ │ └── ...
|
||||||
|
│ │ ├── chunk-001
|
||||||
|
│ │ │ └── ...
|
||||||
|
│ │ └── ...
|
||||||
|
│ └── ...
|
||||||
|
└── audio
|
||||||
|
├── observation.audio.laptop
|
||||||
│ ├── chunk-000
|
│ ├── chunk-000
|
||||||
│ │ ├── file-000.mp4
|
│ │ ├── file-000.m4a
|
||||||
│ │ ├── file-001.mp4
|
│ │ ├── file-001.m4a
|
||||||
│ │ └── ...
|
│ │ └── ...
|
||||||
│ ├── chunk-001
|
│ ├── chunk-001
|
||||||
│ │ └── ...
|
│ │ └── ...
|
||||||
│ └── ...
|
│ └── ...
|
||||||
├── observation.images.phone
|
├── observation.audio.phone
|
||||||
│ ├── chunk-000
|
│ ├── chunk-000
|
||||||
│ │ ├── file-000.mp4
|
│ │ ├── file-000.m4a
|
||||||
│ │ ├── file-001.mp4
|
│ │ ├── file-001.m4a
|
||||||
│ │ └── ...
|
│ │ └── ...
|
||||||
│ ├── chunk-001
|
│ ├── chunk-001
|
||||||
│ │ └── ...
|
│ │ └── ...
|
||||||
@@ -169,8 +190,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||||
True.
|
True.
|
||||||
|
download_audio (bool, optional): Flag to download the audio. Defaults to True.
|
||||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||||
|
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchcodec'.
|
||||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||||
@@ -198,6 +221,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
|
self._audio_backend = audio_backend if audio_backend else "torchcodec"
|
||||||
self._batch_encoding_size = batch_encoding_size
|
self._batch_encoding_size = batch_encoding_size
|
||||||
self._vcodec = resolve_vcodec(vcodec)
|
self._vcodec = resolve_vcodec(vcodec)
|
||||||
self._encoder_threads = encoder_threads
|
self._encoder_threads = encoder_threads
|
||||||
@@ -219,6 +243,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
episodes=episodes,
|
episodes=episodes,
|
||||||
tolerance_s=tolerance_s,
|
tolerance_s=tolerance_s,
|
||||||
video_backend=self._video_backend,
|
video_backend=self._video_backend,
|
||||||
|
audio_backend=self._audio_backend,
|
||||||
delta_timestamps=delta_timestamps,
|
delta_timestamps=delta_timestamps,
|
||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
)
|
)
|
||||||
@@ -227,7 +252,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
if force_cache_sync or not self.reader.try_load():
|
if force_cache_sync or not self.reader.try_load():
|
||||||
if is_valid_version(self.revision):
|
if is_valid_version(self.revision):
|
||||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||||
self._download(download_videos)
|
self._download(download_videos, download_audio)
|
||||||
self.reader.load_and_activate()
|
self.reader.load_and_activate()
|
||||||
|
|
||||||
# Detect write-mode params for backward compatibility
|
# Detect write-mode params for backward compatibility
|
||||||
@@ -281,6 +306,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
episodes=self.episodes,
|
episodes=self.episodes,
|
||||||
tolerance_s=self.tolerance_s,
|
tolerance_s=self.tolerance_s,
|
||||||
video_backend=self._video_backend,
|
video_backend=self._video_backend,
|
||||||
|
audio_backend=self._audio_backend,
|
||||||
delta_timestamps=self.delta_timestamps,
|
delta_timestamps=self.delta_timestamps,
|
||||||
image_transforms=self.image_transforms,
|
image_transforms=self.image_transforms,
|
||||||
)
|
)
|
||||||
@@ -360,6 +386,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self._require_writer("add_frame")
|
self._require_writer("add_frame")
|
||||||
self.writer.add_frame(frame)
|
self.writer.add_frame(frame)
|
||||||
|
|
||||||
|
def add_microphones_recordings(self, microphones: dict) -> None:
|
||||||
|
"""Add microphone recordings to the current episode buffer.
|
||||||
|
|
||||||
|
Delegates to :meth:`DatasetWriter.add_microphones_recordings`.
|
||||||
|
"""
|
||||||
|
self._require_writer("add_microphones_recordings")
|
||||||
|
self.writer.add_microphones_recordings(microphones)
|
||||||
|
|
||||||
def save_episode(self, episode_data: dict | None = None, parallel_encoding: bool = True) -> None:
|
def save_episode(self, episode_data: dict | None = None, parallel_encoding: bool = True) -> None:
|
||||||
"""Save the current episode buffer to disk.
|
"""Save the current episode buffer to disk.
|
||||||
|
|
||||||
@@ -484,6 +518,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
license: str | None = "apache-2.0",
|
license: str | None = "apache-2.0",
|
||||||
tag_version: bool = True,
|
tag_version: bool = True,
|
||||||
push_videos: bool = True,
|
push_videos: bool = True,
|
||||||
|
push_audio: bool = True,
|
||||||
private: bool = False,
|
private: bool = False,
|
||||||
allow_patterns: list[str] | str | None = None,
|
allow_patterns: list[str] | str | None = None,
|
||||||
upload_large_folder: bool = False,
|
upload_large_folder: bool = False,
|
||||||
@@ -513,6 +548,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
ignore_patterns = ["images/"]
|
ignore_patterns = ["images/"]
|
||||||
if not push_videos:
|
if not push_videos:
|
||||||
ignore_patterns.append("videos/")
|
ignore_patterns.append("videos/")
|
||||||
|
if not push_audio:
|
||||||
|
ignore_patterns.append("audio/")
|
||||||
|
|
||||||
hub_api = HfApi()
|
hub_api = HfApi()
|
||||||
hub_api.create_repo(
|
hub_api.create_repo(
|
||||||
@@ -553,10 +590,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||||
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||||
|
|
||||||
def _download(self, download_videos: bool = True) -> None:
|
def _download(self, download_videos: bool = True, download_audio: bool = True) -> None:
|
||||||
"""Downloads the dataset from the given 'repo_id' at the provided version."""
|
"""Downloads the dataset from the given 'repo_id' at the provided version."""
|
||||||
ignore_patterns = None if download_videos else "videos/"
|
ignore_patterns = None if download_videos else "videos/"
|
||||||
files = None
|
files = None
|
||||||
|
ignore_patterns = []
|
||||||
|
if not download_videos:
|
||||||
|
ignore_patterns.append("videos/")
|
||||||
|
if not download_audio:
|
||||||
|
ignore_patterns.append("audio/")
|
||||||
if self.episodes is not None:
|
if self.episodes is not None:
|
||||||
# Reader is guaranteed to exist here (created in __init__ before _download)
|
# Reader is guaranteed to exist here (created in __init__ before _download)
|
||||||
files = self.reader.get_episodes_file_paths()
|
files = self.reader.get_episodes_file_paths()
|
||||||
@@ -603,6 +645,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
image_writer_processes: int = 0,
|
image_writer_processes: int = 0,
|
||||||
image_writer_threads: int = 0,
|
image_writer_threads: int = 0,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
|
audio_backend: str | None = None,
|
||||||
batch_encoding_size: int = 1,
|
batch_encoding_size: int = 1,
|
||||||
vcodec: str = "libsvtav1",
|
vcodec: str = "libsvtav1",
|
||||||
metadata_buffer_size: int = 10,
|
metadata_buffer_size: int = 10,
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ class ForwardCompatibilityError(CompatibilityError):
|
|||||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||||
|
DEFAULT_AUDIO_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||||
|
|
||||||
INFO_PATH = "meta/info.json"
|
INFO_PATH = "meta/info.json"
|
||||||
STATS_PATH = "meta/stats.json"
|
STATS_PATH = "meta/stats.json"
|
||||||
@@ -80,6 +81,7 @@ STATS_PATH = "meta/stats.json"
|
|||||||
EPISODES_DIR = "meta/episodes"
|
EPISODES_DIR = "meta/episodes"
|
||||||
DATA_DIR = "data"
|
DATA_DIR = "data"
|
||||||
VIDEO_DIR = "videos"
|
VIDEO_DIR = "videos"
|
||||||
|
AUDIO_DIR = "audio"
|
||||||
|
|
||||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||||
@@ -87,7 +89,12 @@ DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
|||||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||||
|
DEFAULT_AUDIO_PATH = AUDIO_DIR + "/{audio_key}/" + CHUNK_FILE_PATTERN + ".m4a"
|
||||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||||
|
DEFAULT_RAW_AUDIO_PATH = "raw_audio/{audio_key}/episode_{episode_index:06d}.wav"
|
||||||
|
|
||||||
|
DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds
|
||||||
|
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION = 1.0 # seconds
|
||||||
|
|
||||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||||
|
|||||||
@@ -486,42 +486,42 @@ def encode_video_frames(
|
|||||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||||
|
|
||||||
|
|
||||||
def concatenate_video_files(
|
def concatenate_media_files(
|
||||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
input_media_paths: list[Path | str], output_media_path: Path, overwrite: bool = True
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Concatenate multiple video files into a single video file using pyav.
|
Concatenate multiple media files (video & audio) into a single media file using pyav.
|
||||||
|
|
||||||
This function takes a list of video input file paths and concatenates them into a single
|
This function takes a list of input media file paths and concatenates them into a single
|
||||||
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
output media file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
||||||
concatenation without re-encoding.
|
concatenation without re-encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
input_media_paths: Ordered list of input media file paths to concatenate.
|
||||||
output_video_path: Path to the output video file.
|
output_media_path: Path to the output media file.
|
||||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
overwrite: Whether to overwrite the output media file if it already exists. Default is True.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
- Creates a temporary .ffconcat file and container audio/video file that are cleaned up after use.
|
||||||
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
|
- Uses ffmpeg's concat demuxer which requires all input media files to have the same
|
||||||
codec, resolution, and frame rate for proper concatenation.
|
codec, resolution, and frame rate for proper concatenation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_video_path = Path(output_video_path)
|
output_media_path = Path(output_media_path)
|
||||||
|
|
||||||
if output_video_path.exists() and not overwrite:
|
if output_media_path.exists() and not overwrite:
|
||||||
logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
logging.warning(f"Media file already exists: {output_media_path}. Skipping concatenation.")
|
||||||
return
|
return
|
||||||
|
|
||||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
output_media_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if len(input_video_paths) == 0:
|
if len(input_media_paths) == 0:
|
||||||
raise FileNotFoundError("No input video paths provided.")
|
raise FileNotFoundError("No input media paths provided.")
|
||||||
|
|
||||||
# Create a temporary .ffconcat file to list the input video paths
|
# Create a temporary .ffconcat file to list the input media paths
|
||||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||||
for input_path in input_video_paths:
|
for input_path in input_media_paths:
|
||||||
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
|
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
|
||||||
tmp_concatenate_file.flush()
|
tmp_concatenate_file.flush()
|
||||||
tmp_concatenate_path = tmp_concatenate_file.name
|
tmp_concatenate_path = tmp_concatenate_file.name
|
||||||
@@ -531,11 +531,12 @@ def concatenate_video_files(
|
|||||||
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
||||||
) # safe = 0 allows absolute paths as well as relative paths
|
) # safe = 0 allows absolute paths as well as relative paths
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
# Using an intermediate container to store the concatenated media file is necessary to avoid inplace concatenation read-write race conditions.
|
||||||
tmp_output_video_path = tmp_named_file.name
|
with tempfile.NamedTemporaryFile(suffix=output_media_path.suffix, delete=False) as tmp_named_file:
|
||||||
|
tmp_output_media_path = tmp_named_file.name
|
||||||
|
|
||||||
output_container = av.open(
|
output_container = av.open(
|
||||||
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
tmp_output_media_path, mode="w", options={"movflags": "faststart"}
|
||||||
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||||
|
|
||||||
# Replicate input streams in output container
|
# Replicate input streams in output container
|
||||||
@@ -550,6 +551,7 @@ def concatenate_video_files(
|
|||||||
stream_map[input_stream.index].time_base = input_stream.time_base
|
stream_map[input_stream.index].time_base = input_stream.time_base
|
||||||
|
|
||||||
# Demux + remux packets (no re-encode)
|
# Demux + remux packets (no re-encode)
|
||||||
|
last_dts = None
|
||||||
for packet in input_container.demux():
|
for packet in input_container.demux():
|
||||||
# Skip packets from un-mapped streams
|
# Skip packets from un-mapped streams
|
||||||
if packet.stream.index not in stream_map:
|
if packet.stream.index not in stream_map:
|
||||||
@@ -558,6 +560,16 @@ def concatenate_video_files(
|
|||||||
# Skip demux flushing packets
|
# Skip demux flushing packets
|
||||||
if packet.dts is None:
|
if packet.dts is None:
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
# Enforce strictly increasing decoding timestamps (DTS)
|
||||||
|
if last_dts is not None and packet.dts <= last_dts:
|
||||||
|
shift = last_dts - packet.dts + 1
|
||||||
|
packet.dts += shift
|
||||||
|
packet.pts += shift # Presenting timestamps (PTS) are the same as DTS here
|
||||||
|
logging.warning(
|
||||||
|
f"Non-monotonic DTS; previous: {last_dts}, current: {packet.dts - shift}; changing to {packet.dts}. This may result in incorrect timestamps in the output file."
|
||||||
|
)
|
||||||
|
last_dts = packet.dts
|
||||||
|
|
||||||
output_stream = stream_map[packet.stream.index]
|
output_stream = stream_map[packet.stream.index]
|
||||||
packet.stream = output_stream
|
packet.stream = output_stream
|
||||||
@@ -565,7 +577,7 @@ def concatenate_video_files(
|
|||||||
|
|
||||||
input_container.close()
|
input_container.close()
|
||||||
output_container.close()
|
output_container.close()
|
||||||
shutil.move(tmp_output_video_path, output_video_path)
|
shutil.move(tmp_output_media_path, output_media_path)
|
||||||
Path(tmp_concatenate_path).unlink()
|
Path(tmp_concatenate_path).unlink()
|
||||||
|
|
||||||
|
|
||||||
@@ -947,38 +959,6 @@ with warnings.catch_warnings():
|
|||||||
register_feature(VideoFrame, "VideoFrame")
|
register_feature(VideoFrame, "VideoFrame")
|
||||||
|
|
||||||
|
|
||||||
def get_audio_info(video_path: Path | str) -> dict:
|
|
||||||
# Set logging level
|
|
||||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
|
||||||
|
|
||||||
# Getting audio stream information
|
|
||||||
audio_info = {}
|
|
||||||
with av.open(str(video_path), "r") as audio_file:
|
|
||||||
try:
|
|
||||||
audio_stream = audio_file.streams.audio[0]
|
|
||||||
except IndexError:
|
|
||||||
# Reset logging level
|
|
||||||
av.logging.restore_default_callback()
|
|
||||||
return {"has_audio": False}
|
|
||||||
|
|
||||||
audio_info["audio.channels"] = audio_stream.channels
|
|
||||||
audio_info["audio.codec"] = audio_stream.codec.canonical_name
|
|
||||||
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
|
|
||||||
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
|
|
||||||
audio_info["audio.bit_rate"] = audio_stream.bit_rate
|
|
||||||
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
|
|
||||||
# In an ideal loseless case : fixed number of bits per sample.
|
|
||||||
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
|
|
||||||
audio_info["audio.bit_depth"] = audio_stream.format.bits
|
|
||||||
audio_info["audio.channel_layout"] = audio_stream.layout.name
|
|
||||||
audio_info["has_audio"] = True
|
|
||||||
|
|
||||||
# Reset logging level
|
|
||||||
av.logging.restore_default_callback()
|
|
||||||
|
|
||||||
return audio_info
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_info(video_path: Path | str) -> dict:
|
def get_video_info(video_path: Path | str) -> dict:
|
||||||
# Set logging level
|
# Set logging level
|
||||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||||
@@ -1008,9 +988,6 @@ def get_video_info(video_path: Path | str) -> dict:
|
|||||||
# Reset logging level
|
# Reset logging level
|
||||||
av.logging.restore_default_callback()
|
av.logging.restore_default_callback()
|
||||||
|
|
||||||
# Adding audio stream information
|
|
||||||
video_info.update(**get_audio_info(video_path))
|
|
||||||
|
|
||||||
return video_info
|
return video_info
|
||||||
|
|
||||||
|
|
||||||
@@ -1025,22 +1002,22 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
|
|||||||
raise ValueError("Unknown format")
|
raise ValueError("Unknown format")
|
||||||
|
|
||||||
|
|
||||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -> float:
|
||||||
"""
|
"""
|
||||||
Get the duration of a video file in seconds using PyAV.
|
Get the duration of a media file (video & audio) in seconds using PyAV.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path: Path to the video file.
|
media_path: Path to the media file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Duration of the video in seconds.
|
Duration of the media file in seconds.
|
||||||
"""
|
"""
|
||||||
with av.open(str(video_path)) as container:
|
with av.open(str(media_path)) as container:
|
||||||
# Get the first video stream
|
# Get the first stream
|
||||||
video_stream = container.streams.video[0]
|
stream = container.streams.video[0] if media_type == "video" else container.streams.audio[0]
|
||||||
# Calculate duration: stream.duration * stream.time_base gives duration in seconds
|
# Calculate duration: stream.duration * stream.time_base gives duration in seconds
|
||||||
if video_stream.duration is not None:
|
if stream.duration is not None:
|
||||||
duration = float(video_stream.duration * video_stream.time_base)
|
duration = float(stream.duration * stream.time_base)
|
||||||
else:
|
else:
|
||||||
# Fallback to container duration if stream duration is not available
|
# Fallback to container duration if stream duration is not available
|
||||||
duration = float(container.duration / av.time_base)
|
duration = float(container.duration / av.time_base)
|
||||||
@@ -1049,12 +1026,12 @@ def get_video_duration_in_s(video_path: Path | str) -> float:
|
|||||||
|
|
||||||
class VideoEncodingManager:
|
class VideoEncodingManager:
|
||||||
"""
|
"""
|
||||||
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
|
Context manager that ensures proper video and audio encoding and data cleanup even if exceptions occur.
|
||||||
|
|
||||||
This manager handles:
|
This manager handles:
|
||||||
- Batch encoding for any remaining episodes when recording interrupted
|
- Batch encoding for any remaining episodes when recording interrupted
|
||||||
- Cleaning up temporary image files from interrupted episodes
|
- Cleaning up temporary image and audio files from interrupted episodes
|
||||||
- Removing empty image directories
|
- Removing empty image and audio directories
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: The LeRobotDataset instance
|
dataset: The LeRobotDataset instance
|
||||||
@@ -1091,4 +1068,16 @@ class VideoEncodingManager:
|
|||||||
else:
|
else:
|
||||||
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||||
|
|
||||||
|
# Clean up any remaining audio directory if it's empty
|
||||||
|
audio_dir = self.dataset.root / "raw_audio"
|
||||||
|
# Check for any remaining WAV files
|
||||||
|
wav_files = list(audio_dir.rglob("*.wav"))
|
||||||
|
if len(wav_files) == 0:
|
||||||
|
# Only remove the raw_audio directory if no WAV files remain
|
||||||
|
if audio_dir.exists():
|
||||||
|
shutil.rmtree(audio_dir)
|
||||||
|
logging.debug("Cleaned up empty audio directory")
|
||||||
|
else:
|
||||||
|
logging.debug(f"Audio directory is not empty, containing {len(wav_files)} WAV files")
|
||||||
|
|
||||||
return False # Don't suppress the original exception
|
return False # Don't suppress the original exception
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .configs import MicrophoneConfig
|
||||||
|
from .microphone import Microphone
|
||||||
|
from .utils import make_microphones_from_configs
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import draccus
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class MicrophoneConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||||
|
sample_rate: int | None = None
|
||||||
|
channels: list[int] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
return self.get_choice_name(self.__class__)
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Barrier
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .configs import MicrophoneConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Microphone(abc.ABC):
|
||||||
|
"""Base class for microphone implementations.
|
||||||
|
|
||||||
|
Defines a standard interface for microphone operations across different backends.
|
||||||
|
Subclasses must implement all abstract methods.
|
||||||
|
|
||||||
|
Manages basic microphone properties (sample rate, channels) and core operations:
|
||||||
|
- Connection/disconnection
|
||||||
|
- Start/stop recording
|
||||||
|
- Audio chunk reading
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sample_rate (int | None): Configured sample rate in Hz
|
||||||
|
channels (list[int] | None): List of channel numbers to record
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class MyMicrophone(Microphone):
|
||||||
|
def __init__(self, config): ...
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool: ...
|
||||||
|
def connect(self): ...
|
||||||
|
# Plus other required methods
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: MicrophoneConfig):
|
||||||
|
"""Initialize the microphone with the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Microphone configuration containing sample rate and channels.
|
||||||
|
"""
|
||||||
|
self.sample_rate: int | None = config.sample_rate
|
||||||
|
self.channels: list[int] | None = config.channels
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if the microphone is currently connected.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the microphone is connected and ready to start recording,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def is_recording(self) -> bool:
|
||||||
|
"""Check if the microphone is currently recording.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the microphone is recording, False otherwise.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def is_writing(self) -> bool:
|
||||||
|
"""Check if the microphone is currently writing to a file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the microphone is writing to a file, False otherwise.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
def find_microphones() -> list[dict[str, Any]]:
|
||||||
|
"""Detects available microphones connected to the system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: A list of dictionaries,
|
||||||
|
where each dictionary contains information about a detected microphone.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""Establish connection to the microphone."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def start_recording(
|
||||||
|
self,
|
||||||
|
output_file: str | Path | None = None,
|
||||||
|
multiprocessing: bool | None = False,
|
||||||
|
overwrite: bool | None = True,
|
||||||
|
barrier: Barrier | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Start recording audio from the microphone.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Optional path to save the recorded audio.
|
||||||
|
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
|
||||||
|
overwrite: If True, overwrites existing files at output_file path.
|
||||||
|
barrier: If not None, ensures that multiple microphones start recording at the same time.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def read(self) -> np.ndarray:
|
||||||
|
"""Capture and return a single audio chunk from the microphone.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Captured audio chunk as a numpy array.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def stop_recording(self) -> None:
|
||||||
|
"""Stop recording audio from the microphone."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
"""Disconnect the microphone and release any resources."""
|
||||||
|
pass
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .configuration_portaudio import PortAudioMicrophoneConfig
|
||||||
|
from .microphone_portaudio import PortAudioMicrophone
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ..configs import MicrophoneConfig
|
||||||
|
|
||||||
|
|
||||||
|
@MicrophoneConfig.register_subclass("portaudio")
|
||||||
|
@dataclass
|
||||||
|
class PortAudioMicrophoneConfig(MicrophoneConfig):
|
||||||
|
"""Configuration class for PortAudio-based microphone devices.
|
||||||
|
|
||||||
|
This class provides configuration options for microphones accessed through PortAudio with the sounddevice Python package.
|
||||||
|
including device index, sample rate and channels.
|
||||||
|
|
||||||
|
Example configurations:
|
||||||
|
```python
|
||||||
|
# Basic configurations
|
||||||
|
PortAudioMicrophoneConfig(0, 16000, [1]) # Device index 0, 16000Hz, mono
|
||||||
|
PortAudioMicrophoneConfig(1, 44100, [1, 2]) # Device index 1, 44100Hz, stereo
|
||||||
|
```
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
microphone_index: Device index for the microphone.
|
||||||
|
sample_rate: Sample rate in Hz for the microphone.
|
||||||
|
channels: List of channel numbers to use for the microphone.
|
||||||
|
"""
|
||||||
|
|
||||||
|
microphone_index: int
|
||||||
@@ -0,0 +1,394 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from threading import Event, Thread
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sounddevice import PortAudioError
|
||||||
|
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
|
||||||
|
|
||||||
|
# --- Interface definitions for InputStream ---
|
||||||
|
class IInputStream(abc.ABC):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
samplerate: float | None = None,
|
||||||
|
blocksize: int | None = None,
|
||||||
|
device: int | str | None = None,
|
||||||
|
channels: int | None = None,
|
||||||
|
dtype: str | np.dtype | None = None,
|
||||||
|
latency: float | str | None = None,
|
||||||
|
callback: Callable[[Any, int, Any, Any], None] | None = None,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def start(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def stop(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ISounddeviceSDK(abc.ABC):
|
||||||
|
"""Interface defining the contract for the Sounddevice SDK."""
|
||||||
|
|
||||||
|
InputStream: type[IInputStream]
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# --- Real SDK Adapter ---
|
||||||
|
|
||||||
|
|
||||||
|
class SounddeviceSDKAdapter(ISounddeviceSDK):
|
||||||
|
"""Adapts the real sounddevice library to the ISounddeviceSDK interface."""
|
||||||
|
|
||||||
|
_sounddevice = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
try:
|
||||||
|
import sounddevice
|
||||||
|
|
||||||
|
SounddeviceSDKAdapter._sounddevice = sounddevice
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("sounddevice library not found") from e
|
||||||
|
|
||||||
|
# --- Inner Class Implementation ---
|
||||||
|
class RealInputStream(IInputStream):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
samplerate: int | None = None,
|
||||||
|
blocksize: int | None = None,
|
||||||
|
device: int | None = None,
|
||||||
|
channels: int | None = None,
|
||||||
|
dtype: str | np.dtype | None = None,
|
||||||
|
latency: float | str | None = None,
|
||||||
|
callback: Callable[[Any, int, Any, Any], None] | None = None,
|
||||||
|
):
|
||||||
|
import sounddevice
|
||||||
|
|
||||||
|
self._input_stream = sounddevice.InputStream(
|
||||||
|
samplerate=samplerate,
|
||||||
|
blocksize=blocksize,
|
||||||
|
device=device,
|
||||||
|
channels=channels,
|
||||||
|
dtype=dtype,
|
||||||
|
latency=latency,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
self._input_stream.start()
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self._input_stream.stop()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self._input_stream.close()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self._input_stream.stop()
|
||||||
|
self._input_stream.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active(self) -> bool:
|
||||||
|
return self._input_stream.active
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stopped(self) -> bool:
|
||||||
|
return self._input_stream.stopped
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self) -> bool:
|
||||||
|
return self._input_stream.closed
|
||||||
|
|
||||||
|
InputStream = RealInputStream
|
||||||
|
|
||||||
|
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
|
||||||
|
return SounddeviceSDKAdapter._sounddevice.query_devices(device, kind)
|
||||||
|
|
||||||
|
|
||||||
|
# Emulates a 48kHz stereo microphone
|
||||||
|
VALID_DTYPE = {
|
||||||
|
"float32",
|
||||||
|
"int32",
|
||||||
|
"int16",
|
||||||
|
"int8",
|
||||||
|
"uint8",
|
||||||
|
np.float32,
|
||||||
|
np.int32,
|
||||||
|
np.int16,
|
||||||
|
np.int8,
|
||||||
|
np.uint8,
|
||||||
|
}
|
||||||
|
VALID_LATENCY = {"low", "high"}
|
||||||
|
|
||||||
|
VALID_DEVICES = [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"name": "Built-in Microphone",
|
||||||
|
"hostapi": 0,
|
||||||
|
"max_input_channels": 2,
|
||||||
|
"max_output_channels": 0,
|
||||||
|
"default_low_input_latency": 0.01,
|
||||||
|
"default_low_output_latency": 0.001,
|
||||||
|
"default_high_input_latency": 0.1,
|
||||||
|
"default_high_output_latency": 0.01,
|
||||||
|
"default_samplerate": 48000.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"name": "Built-in Output",
|
||||||
|
"hostapi": 0,
|
||||||
|
"max_input_channels": 0,
|
||||||
|
"max_output_channels": 2,
|
||||||
|
"default_low_input_latency": 0.04,
|
||||||
|
"default_low_output_latency": 0.04,
|
||||||
|
"default_high_input_latency": 0.12,
|
||||||
|
"default_high_output_latency": 0.12,
|
||||||
|
"default_samplerate": 48000.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 2,
|
||||||
|
"name": "USB Audio Device",
|
||||||
|
"hostapi": 0,
|
||||||
|
"max_input_channels": 1,
|
||||||
|
"max_output_channels": 0,
|
||||||
|
"default_low_input_latency": 0.03,
|
||||||
|
"default_low_output_latency": 0.01,
|
||||||
|
"default_high_input_latency": 0.04,
|
||||||
|
"default_high_output_latency": 0.03,
|
||||||
|
"default_samplerate": 16000.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# -- Fake SDK Adapter ---
|
||||||
|
|
||||||
|
|
||||||
|
class FakeSounddeviceSDKAdapter(ISounddeviceSDK):
|
||||||
|
"""Implements the ISounddeviceSDK interface with fake behaviour for testing."""
|
||||||
|
|
||||||
|
# --- Inner Class Implementation ---
|
||||||
|
class FakeInputStream(IInputStream):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
samplerate: float | None = None,
|
||||||
|
blocksize: int | None = None,
|
||||||
|
device: int | str | None = None,
|
||||||
|
channels: int | None = None,
|
||||||
|
dtype: str | None = None,
|
||||||
|
latency: str | None = None,
|
||||||
|
callback: Callable[[Any, int, Any, Any], None] | None = None,
|
||||||
|
):
|
||||||
|
self.samplerate = samplerate
|
||||||
|
self.blocksize = blocksize
|
||||||
|
self.device = device
|
||||||
|
self.channels = channels
|
||||||
|
self.dtype = dtype
|
||||||
|
self.latency = latency
|
||||||
|
self.callback = callback
|
||||||
|
|
||||||
|
self._validate_settings()
|
||||||
|
|
||||||
|
self._active = False
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
|
if self.callback is not None:
|
||||||
|
self._streaming_thread = Thread(target=self._streaming_loop, daemon=True)
|
||||||
|
self._streaming_thread_stop_event = Event()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active(self) -> bool:
|
||||||
|
"""True when the stream is active, False otherwise."""
|
||||||
|
return self._active
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stopped(self) -> bool:
|
||||||
|
"""True when the stream is stopped, False otherwise."""
|
||||||
|
return not self._active
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self) -> bool:
|
||||||
|
"""True after a call to close(), False otherwise."""
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def _get_device_info(self):
|
||||||
|
"""Returns the device info for the device."""
|
||||||
|
for device in VALID_DEVICES:
|
||||||
|
if (isinstance(self.device, int) and device["index"] == self.device) or (
|
||||||
|
isinstance(self.device, str) and device["name"] == self.device
|
||||||
|
):
|
||||||
|
return device
|
||||||
|
raise PortAudioError(f"No input device matching {self.device}")
|
||||||
|
|
||||||
|
def _validate_device(self):
|
||||||
|
"""Validates the device against the valid devices."""
|
||||||
|
valid_device_indices = [device["index"] for device in VALID_DEVICES]
|
||||||
|
valid_device_names = [device["name"] for device in VALID_DEVICES]
|
||||||
|
|
||||||
|
if self.device is not None:
|
||||||
|
if isinstance(self.device, (int, str)):
|
||||||
|
# Check if device index is valid
|
||||||
|
if isinstance(self.device, int) and self.device not in valid_device_indices:
|
||||||
|
raise PortAudioError(f"Error querying device {self.device}")
|
||||||
|
|
||||||
|
# Check if device name is valid
|
||||||
|
if isinstance(self.device, str) and self.device not in valid_device_names:
|
||||||
|
raise PortAudioError(f"No input device matching {self.device}")
|
||||||
|
else:
|
||||||
|
raise PortAudioError(f"Device must be int or str, got {type(self.device)}")
|
||||||
|
else:
|
||||||
|
# Default to first input device
|
||||||
|
input_devices = [d for d in VALID_DEVICES if d["max_input_channels"] > 0]
|
||||||
|
if input_devices:
|
||||||
|
self.device = input_devices[0]["index"]
|
||||||
|
|
||||||
|
def _validate_samplerate(self):
|
||||||
|
"""Validates the samplerate against the device's maximum samplerate."""
|
||||||
|
device_info = self._get_device_info()
|
||||||
|
if self.samplerate is None:
|
||||||
|
self.samplerate = device_info["default_samplerate"]
|
||||||
|
elif self.samplerate > device_info["default_samplerate"] or self.samplerate < 1000:
|
||||||
|
raise PortAudioError("Error opening InputStream: Invalid sample rate")
|
||||||
|
|
||||||
|
def _validate_channels(self):
|
||||||
|
"""Validates the channels against the device's maximum channels."""
|
||||||
|
device_info = self._get_device_info()
|
||||||
|
if self.channels is None:
|
||||||
|
self.channels = device_info["max_input_channels"]
|
||||||
|
elif self.channels > device_info["max_input_channels"] or self.channels < 1:
|
||||||
|
raise PortAudioError("Error opening InputStream: Invalid number of channels")
|
||||||
|
|
||||||
|
def _validate_dtype(self):
|
||||||
|
"""Validates the dtype against the valid dtypes."""
|
||||||
|
if self.dtype is not None:
|
||||||
|
if self.dtype not in VALID_DTYPE:
|
||||||
|
raise PortAudioError("Invalid input sample format")
|
||||||
|
else:
|
||||||
|
self.dtype = "float32" # Default dtype
|
||||||
|
|
||||||
|
def _validate_latency(self):
|
||||||
|
"""Validates the latency against the valid latencies."""
|
||||||
|
if self.latency is not None:
|
||||||
|
if self.latency not in VALID_LATENCY:
|
||||||
|
raise PortAudioError("Invalid latency")
|
||||||
|
else:
|
||||||
|
self.latency = "low" # Default latency
|
||||||
|
|
||||||
|
if isinstance(self.latency, str):
|
||||||
|
device_info = self._get_device_info()
|
||||||
|
if self.latency == "low":
|
||||||
|
self.latency = device_info["default_low_input_latency"]
|
||||||
|
elif self.latency == "high":
|
||||||
|
self.latency = device_info["default_high_input_latency"]
|
||||||
|
|
||||||
|
def _validate_settings(self):
|
||||||
|
"""Validates the input parameters against available devices and valid options."""
|
||||||
|
self._validate_device()
|
||||||
|
self._validate_samplerate()
|
||||||
|
self._validate_channels()
|
||||||
|
self._validate_dtype()
|
||||||
|
self._validate_latency()
|
||||||
|
|
||||||
|
def _simulated_audio_data(self) -> np.ndarray:
|
||||||
|
"""Generates a simulated audio signal for testing purposes with proper value ranges."""
|
||||||
|
duration_samples = int(self.samplerate * self.latency)
|
||||||
|
|
||||||
|
# Generate output according to dtype
|
||||||
|
if self.dtype in {"float32", np.float32}:
|
||||||
|
# Generate values between -1 and 1 for float32
|
||||||
|
data = np.random.uniform(-1.0, 1.0, (duration_samples, self.channels)).astype(self.dtype)
|
||||||
|
else:
|
||||||
|
# Use np.iinfo to get proper range for integer types
|
||||||
|
info = np.iinfo(self.dtype)
|
||||||
|
data = np.random.randint(
|
||||||
|
info.min, info.max + 1, (duration_samples, self.channels), dtype=self.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _streaming_loop(self):
|
||||||
|
if self.callback is not None:
|
||||||
|
while not self._streaming_thread_stop_event.is_set():
|
||||||
|
precise_sleep(self.latency)
|
||||||
|
tmp_data = self._simulated_audio_data()
|
||||||
|
self.callback(
|
||||||
|
tmp_data,
|
||||||
|
len(tmp_data),
|
||||||
|
time.perf_counter(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
"""Start the fake input stream."""
|
||||||
|
if not self.active and self.callback is not None:
|
||||||
|
self._streaming_thread.start()
|
||||||
|
self._active = True
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop the fake input stream."""
|
||||||
|
if self.callback is not None:
|
||||||
|
self._streaming_thread_stop_event.set()
|
||||||
|
self._streaming_thread.join()
|
||||||
|
self._active = False
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the fake input stream."""
|
||||||
|
if self.active and self.callback is not None:
|
||||||
|
self.stop()
|
||||||
|
self._active = False
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
InputStream = FakeInputStream
|
||||||
|
|
||||||
|
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
|
||||||
|
"""Returns a realistic list of audio devices including speakers and microphones."""
|
||||||
|
if device is not None:
|
||||||
|
# Return specific device
|
||||||
|
for valid_device in VALID_DEVICES:
|
||||||
|
if (isinstance(device, int) and valid_device["index"] == device) or (
|
||||||
|
isinstance(device, str) and valid_device["name"] == device
|
||||||
|
):
|
||||||
|
return valid_device
|
||||||
|
raise PortAudioError(f"Error querying device {device}")
|
||||||
|
|
||||||
|
elif kind is not None:
|
||||||
|
for valid_device in VALID_DEVICES:
|
||||||
|
if (
|
||||||
|
valid_device["max_input_channels"] > 0
|
||||||
|
and kind == "input"
|
||||||
|
or valid_device["max_output_channels"] > 0
|
||||||
|
and kind == "output"
|
||||||
|
):
|
||||||
|
return valid_device
|
||||||
|
raise PortAudioError(f"No {kind} device found")
|
||||||
|
|
||||||
|
return VALID_DEVICES
|
||||||
@@ -0,0 +1,566 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Provides the PortAudioMicrophone class for capturing audio from microphones using the PortAudio library through the sounddevice Python package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from multiprocessing import (
|
||||||
|
Event as process_Event,
|
||||||
|
JoinableQueue as process_Queue,
|
||||||
|
Process,
|
||||||
|
)
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty
|
||||||
|
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from soundfile import SoundFile
|
||||||
|
|
||||||
|
from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter
|
||||||
|
from lerobot.utils.errors import (
|
||||||
|
DeviceAlreadyConnectedError,
|
||||||
|
DeviceAlreadyRecordingError,
|
||||||
|
DeviceNotConnectedError,
|
||||||
|
DeviceNotRecordingError,
|
||||||
|
)
|
||||||
|
from lerobot.utils.shared_array import SharedArray
|
||||||
|
|
||||||
|
from ..microphone import Microphone
|
||||||
|
from .configuration_portaudio import PortAudioMicrophoneConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PortAudioMicrophone(Microphone):
|
||||||
|
"""
|
||||||
|
The PortAudioMicrophone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
|
||||||
|
|
||||||
|
A PortAudioMicrophone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.
|
||||||
|
|
||||||
|
Example of usage:
|
||||||
|
```python
|
||||||
|
from lerobot.common.robot_devices.microphones.configs import PortAudioMicrophoneConfig
|
||||||
|
|
||||||
|
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
|
||||||
|
microphone = PortAudioMicrophone(config)
|
||||||
|
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording("some/output/file.wav")
|
||||||
|
...
|
||||||
|
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
|
||||||
|
...
|
||||||
|
microphone.stop_recording()
|
||||||
|
microphone.disconnect()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: PortAudioMicrophoneConfig, sounddevice_sdk: ISounddeviceSDK = None):
|
||||||
|
"""
|
||||||
|
Initializes the PortAudioMicrophone instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration settings for the microphone.
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
if sounddevice_sdk is None:
|
||||||
|
self.sounddevice_sdk = SounddeviceSDKAdapter()
|
||||||
|
else:
|
||||||
|
self.sounddevice_sdk = sounddevice_sdk
|
||||||
|
|
||||||
|
# Microphone index
|
||||||
|
self.microphone_index = config.microphone_index
|
||||||
|
|
||||||
|
# Input audio recording process and events
|
||||||
|
self.record_process = None
|
||||||
|
self.record_stop_event = process_Event()
|
||||||
|
self.record_start_event = process_Event()
|
||||||
|
self.record_close_event = process_Event()
|
||||||
|
self.record_is_started_event = process_Event()
|
||||||
|
self.audio_callback_start_event = process_Event()
|
||||||
|
|
||||||
|
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
|
||||||
|
self.write_queue = process_Queue()
|
||||||
|
|
||||||
|
# SharedArray to store audio from the recording process.
|
||||||
|
self.read_shared_array = None
|
||||||
|
self.local_read_shared_array = None
|
||||||
|
# Thread/Process to handle data writing in a separate thread/process (safely)
|
||||||
|
self.write_thread = None
|
||||||
|
self.write_stop_event = None
|
||||||
|
self.write_is_started_event = None
|
||||||
|
|
||||||
|
self.logs = {}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}({self.microphone_index})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.record_process is not None and self.record_process.is_alive()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_recording(self) -> bool:
|
||||||
|
return self.record_is_started_event.is_set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_writing(self) -> bool:
|
||||||
|
return self.write_thread is not None and self.write_is_started_event.is_set()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_microphones(
|
||||||
|
device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
|
||||||
|
) -> list[dict[str, Any]] | dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Detects available microphones connected to the system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: The device to find microphones for. If None, all microphones are found.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: A list of dictionaries,
|
||||||
|
where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if sounddevice_sdk is None:
|
||||||
|
sounddevice_sdk = SounddeviceSDKAdapter()
|
||||||
|
|
||||||
|
found_microphones_info = []
|
||||||
|
|
||||||
|
devices = sounddevice_sdk.query_devices()
|
||||||
|
for d in devices:
|
||||||
|
if d["max_input_channels"] > 0:
|
||||||
|
microphone_info = {
|
||||||
|
"index": d["index"],
|
||||||
|
"name": d["name"],
|
||||||
|
"sample_rate": int(d["default_samplerate"]),
|
||||||
|
"channels": np.arange(1, d["max_input_channels"] + 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
if device is None or (
|
||||||
|
(isinstance(device, int) and d["index"] == device)
|
||||||
|
or (isinstance(device, str) and d["name"] == device)
|
||||||
|
):
|
||||||
|
found_microphones_info.append(microphone_info)
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
if len(found_microphones_info) == 0:
|
||||||
|
raise RuntimeError(f"No microphone found for device {device}")
|
||||||
|
else:
|
||||||
|
return found_microphones_info[0]
|
||||||
|
|
||||||
|
if len(found_microphones_info) == 0:
|
||||||
|
logger.warning("No microphone found !")
|
||||||
|
|
||||||
|
return found_microphones_info
|
||||||
|
|
||||||
|
def _configure_capture_settings(self) -> None:
|
||||||
|
"""
|
||||||
|
Validates the microphone index, sample rate and channels settings specified in the constructor's config to the un-connected microphone.
|
||||||
|
|
||||||
|
This method actually checks the specified settings and fills the sample rate and channels settings if not specified before attempting to start a PortAudio stream.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If one of the specified settings is not compatible with the microphone.
|
||||||
|
DeviceAlreadyConnectedError: If the microphone is connected when attempting to configure settings.
|
||||||
|
"""
|
||||||
|
if self.is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(
|
||||||
|
f"Cannot configure settings for {self} as it is already connected."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._validate_microphone_index()
|
||||||
|
self._validate_sample_rate()
|
||||||
|
self._validate_channels()
|
||||||
|
|
||||||
|
def _validate_microphone_index(self) -> None:
|
||||||
|
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _validate_sample_rate(self) -> None:
|
||||||
|
"""Validates the sample rate against the actual microphone's default sample rate."""
|
||||||
|
|
||||||
|
actual_sample_rate = PortAudioMicrophone.find_microphones(
|
||||||
|
self.microphone_index, self.sounddevice_sdk
|
||||||
|
)["sample_rate"]
|
||||||
|
|
||||||
|
if self.sample_rate is not None:
|
||||||
|
try:
|
||||||
|
self.sample_rate = int(self.sample_rate)
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.sample_rate < actual_sample_rate:
|
||||||
|
logger.warning(
|
||||||
|
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.sample_rate = actual_sample_rate
|
||||||
|
|
||||||
|
def _validate_channels(self) -> None:
|
||||||
|
"""Validates the channels against the actual microphone's maximum input channels."""
|
||||||
|
|
||||||
|
actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
|
||||||
|
"channels"
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.channels is not None and len(self.channels) > 0:
|
||||||
|
if not all(channel in actual_channels for channel in self.channels):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.channels = actual_channels
|
||||||
|
|
||||||
|
# Get channels index instead of number for slicing
|
||||||
|
self.channels_index = np.array(self.channels) - 1
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""
|
||||||
|
Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
|
||||||
|
"""
|
||||||
|
if self.is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
|
||||||
|
|
||||||
|
self._configure_capture_settings()
|
||||||
|
|
||||||
|
# Create or reset queue and shared array
|
||||||
|
self.read_shared_array = SharedArray(
|
||||||
|
shape=(self.sample_rate * 10, len(self.channels)),
|
||||||
|
dtype=np.dtype("float32"),
|
||||||
|
)
|
||||||
|
self.local_read_shared_array = self.read_shared_array.get_local_array()
|
||||||
|
self.write_queue = process_Queue()
|
||||||
|
|
||||||
|
# Reset events
|
||||||
|
self.record_start_event.clear()
|
||||||
|
self.record_stop_event.clear()
|
||||||
|
self.record_close_event.clear()
|
||||||
|
self.record_is_started_event.clear()
|
||||||
|
self.audio_callback_start_event.clear()
|
||||||
|
|
||||||
|
# Create and start an audio input stream with a recording callback
|
||||||
|
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
|
||||||
|
process_init_event = process_Event()
|
||||||
|
self.record_process = Process(
|
||||||
|
target=self._record_process,
|
||||||
|
args=(
|
||||||
|
self.microphone_index,
|
||||||
|
self.sample_rate,
|
||||||
|
self.channels,
|
||||||
|
process_init_event,
|
||||||
|
self.record_start_event,
|
||||||
|
self.record_stop_event,
|
||||||
|
self.record_close_event,
|
||||||
|
self.record_is_started_event,
|
||||||
|
self.audio_callback_start_event,
|
||||||
|
self.write_queue,
|
||||||
|
self.read_shared_array,
|
||||||
|
self.sounddevice_sdk,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.record_process.daemon = True
|
||||||
|
self.record_process.start()
|
||||||
|
|
||||||
|
is_init = process_init_event.wait(
|
||||||
|
timeout=5.0
|
||||||
|
) # Wait for the recording process to be started, and to potentially raise an error on failure.
|
||||||
|
if not self.is_connected or not is_init:
|
||||||
|
raise RuntimeError(f"Error connecting microphone {self.microphone_index}.")
|
||||||
|
|
||||||
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
"""
|
||||||
|
Disconnects the microphone and stops the recording.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||||
|
|
||||||
|
if self.is_recording:
|
||||||
|
self.stop_recording()
|
||||||
|
|
||||||
|
self.record_close_event.set()
|
||||||
|
self.read_shared_array.delete()
|
||||||
|
self.write_queue.close()
|
||||||
|
self.record_process.join()
|
||||||
|
|
||||||
|
if self.is_connected:
|
||||||
|
raise RuntimeError(f"Error disconnecting microphone {self.microphone_index}.")
|
||||||
|
|
||||||
|
logger.info(f"{self} disconnected.")
|
||||||
|
|
||||||
|
def _read(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Thread/Process-safe callback to read available audio data
|
||||||
|
"""
|
||||||
|
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
|
||||||
|
|
||||||
|
def read(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Reads the last audio chunk recorded by the microphone, e.g. all samples recorded since the last read or since the beginning of the recording.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||||
|
if not self.is_recording:
|
||||||
|
raise RuntimeError(f"Microphone {self.microphone_index} is not recording.")
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
audio_readings = self._read()
|
||||||
|
|
||||||
|
# log the number of seconds it took to read the audio chunk
|
||||||
|
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
# log the utc time at which the audio chunk was received
|
||||||
|
self.logs["timestamp_utc"] = time.perf_counter()
|
||||||
|
|
||||||
|
return audio_readings
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _record_process(
|
||||||
|
microphone_index,
|
||||||
|
sample_rate,
|
||||||
|
channels,
|
||||||
|
process_init_event,
|
||||||
|
record_start_event,
|
||||||
|
record_stop_event,
|
||||||
|
record_close_event,
|
||||||
|
record_is_started_event,
|
||||||
|
audio_callback_start_event,
|
||||||
|
write_queue,
|
||||||
|
read_shared_array,
|
||||||
|
sounddevice_sdk,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Process callback used to create an unpickable sounddevice audio input stream with a recording callback and start, stop and close it based on multiprocessing events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
channels_index = np.array(channels) - 1
|
||||||
|
local_read_shared_array = read_shared_array.get_local_array()
|
||||||
|
|
||||||
|
def audio_callback(indata, frames, timestamp, status) -> None:
|
||||||
|
"""
|
||||||
|
Low-level sounddevice callback.
|
||||||
|
"""
|
||||||
|
if status:
|
||||||
|
logger.warning(status)
|
||||||
|
if audio_callback_start_event.is_set():
|
||||||
|
write_queue.put_nowait(indata[:, channels_index])
|
||||||
|
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
|
||||||
|
|
||||||
|
# Create the audio stream
|
||||||
|
# InputStream must be instantiated in the process as it is not pickable.
|
||||||
|
stream = sounddevice_sdk.InputStream(
|
||||||
|
device=microphone_index,
|
||||||
|
samplerate=sample_rate,
|
||||||
|
channels=max(channels),
|
||||||
|
dtype="float32",
|
||||||
|
blocksize=0, # Varying input buffer length, but no additional latency
|
||||||
|
latency="low", # Low latency mode (not enabled by default !)
|
||||||
|
# never_drop_input=True, # Disabled as it generates an error for some devices
|
||||||
|
callback=audio_callback,
|
||||||
|
)
|
||||||
|
process_init_event.set()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
start_flag = record_start_event.wait(timeout=0.1)
|
||||||
|
if record_close_event.is_set():
|
||||||
|
break
|
||||||
|
elif not start_flag:
|
||||||
|
continue
|
||||||
|
stream.start()
|
||||||
|
record_is_started_event.set()
|
||||||
|
record_stop_event.wait()
|
||||||
|
stream.stop() # stream.stop() waits for all buffers to be processed, stream.abort() flushes the buffers !
|
||||||
|
record_is_started_event.clear()
|
||||||
|
stream.close()
|
||||||
|
|
||||||
|
def start_recording(
|
||||||
|
self,
|
||||||
|
output_file: str | None = None,
|
||||||
|
multiprocessing: bool | None = False,
|
||||||
|
overwrite: bool | None = True,
|
||||||
|
barrier: Barrier | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||||
|
if self.is_recording:
|
||||||
|
raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
|
||||||
|
|
||||||
|
# Reset queue and shared memory
|
||||||
|
self.read_shared_array.reset()
|
||||||
|
self._clear_queue(self.write_queue)
|
||||||
|
|
||||||
|
# Reset stop event
|
||||||
|
self.record_stop_event.clear()
|
||||||
|
|
||||||
|
# Write recordings into a file if output_file is provided
|
||||||
|
if output_file is not None:
|
||||||
|
output_file = Path(output_file)
|
||||||
|
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if output_file.exists():
|
||||||
|
if overwrite:
|
||||||
|
output_file.unlink()
|
||||||
|
else:
|
||||||
|
raise FileExistsError(
|
||||||
|
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
|
||||||
|
)
|
||||||
|
|
||||||
|
if multiprocessing:
|
||||||
|
self.write_stop_event = process_Event()
|
||||||
|
self.write_is_started_event = process_Event()
|
||||||
|
self.write_thread = Process(
|
||||||
|
target=PortAudioMicrophone._write_loop,
|
||||||
|
args=(
|
||||||
|
self.write_queue,
|
||||||
|
self.write_stop_event,
|
||||||
|
self.write_is_started_event,
|
||||||
|
self.sample_rate,
|
||||||
|
self.channels,
|
||||||
|
output_file,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.write_stop_event = thread_Event()
|
||||||
|
self.write_is_started_event = thread_Event()
|
||||||
|
self.write_thread = Thread(
|
||||||
|
target=PortAudioMicrophone._write_loop,
|
||||||
|
args=(
|
||||||
|
self.write_queue,
|
||||||
|
self.write_stop_event,
|
||||||
|
self.write_is_started_event,
|
||||||
|
self.sample_rate,
|
||||||
|
self.channels,
|
||||||
|
output_file,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.write_thread.daemon = True
|
||||||
|
self.write_thread.start()
|
||||||
|
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
|
||||||
|
|
||||||
|
self.record_start_event.set() # Start the input audio stream process
|
||||||
|
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
|
||||||
|
|
||||||
|
if barrier is not None:
|
||||||
|
barrier.wait() # Wait for multiple input audio streams to be started at the same time
|
||||||
|
|
||||||
|
self.audio_callback_start_event.set()
|
||||||
|
|
||||||
|
if not self.is_recording:
|
||||||
|
raise RuntimeError(f"Error starting recording for microphone {self.microphone_index}.")
|
||||||
|
if output_file is not None and not self.is_writing:
|
||||||
|
raise RuntimeError(f"Error starting writing for microphone {self.microphone_index}.")
|
||||||
|
|
||||||
|
def stop_recording(self) -> None:
|
||||||
|
"""
|
||||||
|
Stops the recording of the microphones.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||||
|
if not self.is_recording:
|
||||||
|
raise DeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
|
||||||
|
|
||||||
|
self.audio_callback_start_event.clear()
|
||||||
|
self.record_start_event.clear() # Ensures the audio stream is not started again !
|
||||||
|
self.record_stop_event.set()
|
||||||
|
|
||||||
|
# Wait for the stream to be stopped (might lead to race condition if the stream is not properly stopped on array reset and queue clearing)
|
||||||
|
timeout = 1.0
|
||||||
|
while self.is_recording and timeout > 0:
|
||||||
|
time.sleep(0.01)
|
||||||
|
timeout -= 0.01
|
||||||
|
|
||||||
|
self.read_shared_array.reset()
|
||||||
|
self._clear_queue(self.write_queue, join_queue=True)
|
||||||
|
|
||||||
|
if self.is_writing:
|
||||||
|
self.write_stop_event.set()
|
||||||
|
self.write_thread.join()
|
||||||
|
|
||||||
|
if self.is_recording:
|
||||||
|
raise RuntimeError(f"Error stopping recording for microphone {self.microphone_index}.")
|
||||||
|
if self.is_writing:
|
||||||
|
raise RuntimeError(f"Error stopping writing for microphone {self.microphone_index}.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _write_loop(
|
||||||
|
queue,
|
||||||
|
write_stop_event: Event,
|
||||||
|
write_is_started_event: Event,
|
||||||
|
sample_rate: int,
|
||||||
|
channels: list[int],
|
||||||
|
output_file: Path,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Thread/Process-safe loop to write audio data into a file.
|
||||||
|
"""
|
||||||
|
# Can only be run on a single process/thread for file writing safety
|
||||||
|
with SoundFile(
|
||||||
|
output_file,
|
||||||
|
mode="w",
|
||||||
|
samplerate=sample_rate,
|
||||||
|
channels=len(channels),
|
||||||
|
format="WAV",
|
||||||
|
subtype="FLOAT", # By default, a much lower quality WAV file is created !
|
||||||
|
) as file:
|
||||||
|
write_is_started_event.set()
|
||||||
|
while not write_stop_event.is_set():
|
||||||
|
try:
|
||||||
|
file.write(
|
||||||
|
queue.get(timeout=0.005)
|
||||||
|
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
|
||||||
|
queue.task_done()
|
||||||
|
except Empty:
|
||||||
|
continue
|
||||||
|
write_is_started_event.clear()
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
if self.is_connected:
|
||||||
|
self.disconnect()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clear_queue(queue, join_queue: bool = False):
|
||||||
|
"""
|
||||||
|
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
queue.get_nowait()
|
||||||
|
queue.task_done()
|
||||||
|
except Empty:
|
||||||
|
if join_queue:
|
||||||
|
queue.join()
|
||||||
|
return
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .configuration_touchlab import TouchLabSensorConfig
|
||||||
|
from .sensor_touchlab import TouchLabSensor
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ..configs import MicrophoneConfig
|
||||||
|
|
||||||
|
|
||||||
|
@MicrophoneConfig.register_subclass("touchlab")
|
||||||
|
@dataclass
|
||||||
|
class TouchLabSensorConfig(MicrophoneConfig):
|
||||||
|
"""Configuration class for TouchLab tactile sensors (technically not a microphone, but behaves like one acquisition-wise).
|
||||||
|
|
||||||
|
This class provides configuration options for TouchLab tactile sensors, including serial port, sample rate and channels.
|
||||||
|
|
||||||
|
Example configurations:
|
||||||
|
```python
|
||||||
|
# Basic configurations
|
||||||
|
TouchLabSensorConfig("/dev/ttyACM0", 16000) # Serial port /dev/ttyACM0, 16000Hz
|
||||||
|
TouchLabSensorConfig("/dev/ttyACM1", 44100) # Serial port /dev/ttyACM1, 44100Hz
|
||||||
|
```
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sensor_port: Serial port of the tactile sensor.
|
||||||
|
baud_rate: Baud rate of the tactile sensor.
|
||||||
|
sample_rate: Sample rate in Hz for the tactile sensor.
|
||||||
|
channels: List of channel numbers to use for the tactile sensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sensor_port: str
|
||||||
|
baud_rate: int = 115_200
|
||||||
@@ -0,0 +1,469 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Provides the TouchLabSensor class for capturing tactile data from TouchLab tactile sensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from multiprocessing import (
|
||||||
|
Event as process_Event,
|
||||||
|
JoinableQueue as process_Queue,
|
||||||
|
Process,
|
||||||
|
)
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty
|
||||||
|
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from serial import Serial
|
||||||
|
from soundfile import SoundFile
|
||||||
|
|
||||||
|
from lerobot.utils.errors import (
|
||||||
|
DeviceAlreadyConnectedError,
|
||||||
|
DeviceAlreadyRecordingError,
|
||||||
|
DeviceNotConnectedError,
|
||||||
|
DeviceNotRecordingError,
|
||||||
|
)
|
||||||
|
from lerobot.utils.shared_array import SharedArray
|
||||||
|
|
||||||
|
from ..microphone import Microphone
|
||||||
|
from .configuration_touchlab import TouchLabSensorConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_SERIAL_READ_SIZE = 512
|
||||||
|
|
||||||
|
|
||||||
|
class TouchLabSensor(Microphone):
|
||||||
|
"""
|
||||||
|
The TouchLabSensor class handles all TouchLab tactile sensors.
|
||||||
|
|
||||||
|
A TouchLabSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.
|
||||||
|
|
||||||
|
Example of usage:
|
||||||
|
```python
|
||||||
|
from lerobot.common.robot_devices.microphones.configs import TouchLabSensorConfig
|
||||||
|
|
||||||
|
config = TouchLabSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
|
||||||
|
microphone = TouchLabSensor(config)
|
||||||
|
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording("some/output/file.wav")
|
||||||
|
...
|
||||||
|
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
|
||||||
|
...
|
||||||
|
microphone.stop_recording()
|
||||||
|
microphone.disconnect()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: TouchLabSensorConfig):
|
||||||
|
""" "
|
||||||
|
Initializes the TouchLabSensor instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration settings for the sensor.
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
# Sensor port
|
||||||
|
self.sensor_port = config.sensor_port
|
||||||
|
|
||||||
|
# Baud rate
|
||||||
|
self.baud_rate = config.baud_rate
|
||||||
|
|
||||||
|
# Input audio recording process and events
|
||||||
|
self.record_process = None
|
||||||
|
self.record_stop_event = process_Event()
|
||||||
|
self.record_start_event = process_Event()
|
||||||
|
self.record_close_event = process_Event()
|
||||||
|
self.record_is_started_event = process_Event()
|
||||||
|
self.audio_callback_start_event = process_Event()
|
||||||
|
|
||||||
|
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
|
||||||
|
self.write_queue = process_Queue()
|
||||||
|
|
||||||
|
# SharedArray to store audio from the recording process.
|
||||||
|
self.read_shared_array = None
|
||||||
|
self.local_read_shared_array = None
|
||||||
|
# Thread/Process to handle data writing in a separate thread/process (safely)
|
||||||
|
self.write_thread = None
|
||||||
|
self.write_stop_event = None
|
||||||
|
self.write_is_started_event = None
|
||||||
|
|
||||||
|
self.logs = {}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}({self.sensor_port})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if the sensor is currently connected.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the sensor is connected and ready to start recording,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
return self.record_process is not None and self.record_process.is_alive()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_recording(self) -> bool:
|
||||||
|
"""Check if the sensor is currently recording.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the sensor is recording, False otherwise.
|
||||||
|
"""
|
||||||
|
return self.record_is_started_event.is_set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_writing(self) -> bool:
|
||||||
|
"""Check if the sensor is currently writing to a file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the sensor is writing to a file, False otherwise.
|
||||||
|
"""
|
||||||
|
return self.write_thread is not None and self.write_is_started_event.is_set()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_microphones() -> list[dict[str, Any]]:
|
||||||
|
"""Detects available sensors connected to the system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: A list of dictionaries,
|
||||||
|
where each dictionary contains information about a detected sensor.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""
|
||||||
|
Establish connection to the sensor.
|
||||||
|
"""
|
||||||
|
if self.is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")
|
||||||
|
|
||||||
|
# Create or reset queue and shared array
|
||||||
|
self.read_shared_array = SharedArray(
|
||||||
|
shape=(self.sample_rate * 10, len(self.channels)),
|
||||||
|
dtype=np.dtype("int16"),
|
||||||
|
)
|
||||||
|
self.local_read_shared_array = self.read_shared_array.get_local_array()
|
||||||
|
self.write_queue = process_Queue()
|
||||||
|
|
||||||
|
# Reset events
|
||||||
|
self.record_start_event.clear()
|
||||||
|
self.record_stop_event.clear()
|
||||||
|
self.record_close_event.clear()
|
||||||
|
self.record_is_started_event.clear()
|
||||||
|
self.audio_callback_start_event.clear()
|
||||||
|
|
||||||
|
# Create and start an audio input stream with a recording callback
|
||||||
|
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
|
||||||
|
process_init_event = process_Event()
|
||||||
|
self.record_process = Process(
|
||||||
|
target=self._record_process,
|
||||||
|
args=(
|
||||||
|
self.sensor_port,
|
||||||
|
self.baud_rate,
|
||||||
|
self.channels,
|
||||||
|
process_init_event,
|
||||||
|
self.record_start_event,
|
||||||
|
self.record_stop_event,
|
||||||
|
self.record_close_event,
|
||||||
|
self.record_is_started_event,
|
||||||
|
self.audio_callback_start_event,
|
||||||
|
self.write_queue,
|
||||||
|
self.read_shared_array,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.record_process.daemon = True
|
||||||
|
self.record_process.start()
|
||||||
|
|
||||||
|
is_init = process_init_event.wait(
|
||||||
|
timeout=5.0
|
||||||
|
) # Wait for the recording process to be started, and to potentially raise an error on failure.
|
||||||
|
if not self.is_connected or not is_init:
|
||||||
|
raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")
|
||||||
|
|
||||||
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _record_process(
|
||||||
|
sensor_port,
|
||||||
|
baud_rate,
|
||||||
|
channels,
|
||||||
|
process_init_event,
|
||||||
|
record_start_event,
|
||||||
|
record_stop_event,
|
||||||
|
record_close_event,
|
||||||
|
record_is_started_event,
|
||||||
|
audio_callback_start_event,
|
||||||
|
write_queue,
|
||||||
|
read_shared_array,
|
||||||
|
) -> None:
|
||||||
|
channels_index = np.array(channels) - 1
|
||||||
|
local_read_shared_array = read_shared_array.get_local_array()
|
||||||
|
|
||||||
|
def tactile_callback(serial_connection):
|
||||||
|
"""
|
||||||
|
Parse the tactile data from the raw input data.
|
||||||
|
"""
|
||||||
|
buffer = serial_connection.readline()
|
||||||
|
|
||||||
|
if audio_callback_start_event.is_set():
|
||||||
|
strings = buffer.decode("utf8").split(",")
|
||||||
|
num_taxels = len(strings)
|
||||||
|
|
||||||
|
if num_taxels > 0 and num_taxels < MAX_SERIAL_READ_SIZE: # Make sure we didn't read rubbish
|
||||||
|
indata = np.empty((1, num_taxels))
|
||||||
|
for i in range(num_taxels):
|
||||||
|
indata[0, i] = int(strings[i])
|
||||||
|
|
||||||
|
write_queue.put_nowait(indata[:, channels_index])
|
||||||
|
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
|
||||||
|
|
||||||
|
process_init_event.set()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
start_flag = record_start_event.wait(timeout=0.1)
|
||||||
|
if record_close_event.is_set():
|
||||||
|
break
|
||||||
|
elif not start_flag:
|
||||||
|
continue
|
||||||
|
|
||||||
|
with Serial(sensor_port, baud_rate, timeout=0.5) as serial_connection:
|
||||||
|
serial_connection.flush()
|
||||||
|
record_is_started_event.set()
|
||||||
|
while not record_stop_event.is_set():
|
||||||
|
tactile_callback(serial_connection)
|
||||||
|
record_is_started_event.clear()
|
||||||
|
serial_connection.close()
|
||||||
|
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
"""
|
||||||
|
Disconnect the sensor and release any resources.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||||
|
|
||||||
|
if self.is_recording:
|
||||||
|
self.stop_recording()
|
||||||
|
|
||||||
|
self.record_close_event.set()
|
||||||
|
self.read_shared_array.delete()
|
||||||
|
self.write_queue.close()
|
||||||
|
self.record_process.join()
|
||||||
|
|
||||||
|
if self.is_connected:
|
||||||
|
raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")
|
||||||
|
|
||||||
|
logger.info(f"{self} disconnected.")
|
||||||
|
|
||||||
|
def start_recording(
|
||||||
|
self,
|
||||||
|
output_file: str | Path | None = None,
|
||||||
|
multiprocessing: bool | None = False,
|
||||||
|
overwrite: bool | None = True,
|
||||||
|
barrier: Barrier | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Start recording tactile data from the sensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Optional path to save the recorded tactile data.
|
||||||
|
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
|
||||||
|
overwrite: If True, overwrites existing files at output_file path.
|
||||||
|
barrier: If not None, ensures that multiple sensors start recording at the same time.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||||
|
if self.is_recording:
|
||||||
|
raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")
|
||||||
|
|
||||||
|
# Reset queue and shared memory
|
||||||
|
self.read_shared_array.reset()
|
||||||
|
self._clear_queue(self.write_queue)
|
||||||
|
|
||||||
|
# Reset stop event
|
||||||
|
self.record_stop_event.clear()
|
||||||
|
|
||||||
|
# Write recordings into a file if output_file is provided
|
||||||
|
if output_file is not None:
|
||||||
|
output_file = Path(output_file)
|
||||||
|
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if output_file.exists():
|
||||||
|
if overwrite:
|
||||||
|
output_file.unlink()
|
||||||
|
else:
|
||||||
|
raise FileExistsError(
|
||||||
|
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
|
||||||
|
)
|
||||||
|
|
||||||
|
if multiprocessing:
|
||||||
|
self.write_stop_event = process_Event()
|
||||||
|
self.write_is_started_event = process_Event()
|
||||||
|
self.write_thread = Process(
|
||||||
|
target=TouchLabSensor._write_loop,
|
||||||
|
args=(
|
||||||
|
self.write_queue,
|
||||||
|
self.write_stop_event,
|
||||||
|
self.write_is_started_event,
|
||||||
|
self.sample_rate,
|
||||||
|
self.channels,
|
||||||
|
output_file,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.write_stop_event = thread_Event()
|
||||||
|
self.write_is_started_event = thread_Event()
|
||||||
|
self.write_thread = Thread(
|
||||||
|
target=TouchLabSensor._write_loop,
|
||||||
|
args=(
|
||||||
|
self.write_queue,
|
||||||
|
self.write_stop_event,
|
||||||
|
self.write_is_started_event,
|
||||||
|
self.sample_rate,
|
||||||
|
self.channels,
|
||||||
|
output_file,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.write_thread.daemon = True
|
||||||
|
self.write_thread.start()
|
||||||
|
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
|
||||||
|
|
||||||
|
self.record_start_event.set() # Start the input audio stream process
|
||||||
|
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
|
||||||
|
|
||||||
|
if barrier is not None:
|
||||||
|
barrier.wait() # Wait for multiple input audio streams to be started at the same time
|
||||||
|
|
||||||
|
self.audio_callback_start_event.set()
|
||||||
|
|
||||||
|
if not self.is_recording:
|
||||||
|
raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
|
||||||
|
if output_file is not None and not self.is_writing:
|
||||||
|
raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")
|
||||||
|
|
||||||
|
def _read(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Thread/Process-safe callback to read available audio data
|
||||||
|
"""
|
||||||
|
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
|
||||||
|
|
||||||
|
def read(self) -> np.ndarray:
|
||||||
|
"""Capture and return a single audio chunk from the sensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Captured audio chunk as a numpy array.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||||
|
if not self.is_recording:
|
||||||
|
raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
tactile_readings = self._read()
|
||||||
|
|
||||||
|
# log the number of seconds it took to read the audio chunk
|
||||||
|
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
# log the utc time at which the audio chunk was received
|
||||||
|
self.logs["timestamp_utc"] = time.perf_counter()
|
||||||
|
|
||||||
|
return tactile_readings
|
||||||
|
|
||||||
|
def _read_loop(self) -> None:
|
||||||
|
"""Internal loop run by the background thread for asynchronous reading."""
|
||||||
|
|
||||||
|
def stop_recording(self) -> None:
|
||||||
|
"""Stop recording audio from the sensor."""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||||
|
if not self.is_recording:
|
||||||
|
raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||||
|
|
||||||
|
self.audio_callback_start_event.clear()
|
||||||
|
self.record_start_event.clear() # Ensures the audio stream is not started again !
|
||||||
|
self.record_stop_event.set()
|
||||||
|
|
||||||
|
self.read_shared_array.reset()
|
||||||
|
self._clear_queue(self.write_queue, join_queue=True)
|
||||||
|
|
||||||
|
if self.is_writing:
|
||||||
|
self.write_stop_event.set()
|
||||||
|
self.write_thread.join()
|
||||||
|
|
||||||
|
timeout = 1.0
|
||||||
|
while self.is_recording and timeout > 0:
|
||||||
|
time.sleep(0.01)
|
||||||
|
timeout -= 0.01
|
||||||
|
|
||||||
|
if self.is_recording:
|
||||||
|
raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
|
||||||
|
if self.is_writing:
|
||||||
|
raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
if self.is_connected:
|
||||||
|
self.disconnect()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clear_queue(queue, join_queue: bool = False):
|
||||||
|
"""
|
||||||
|
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
queue.get_nowait()
|
||||||
|
queue.task_done()
|
||||||
|
except Empty:
|
||||||
|
if join_queue:
|
||||||
|
queue.join()
|
||||||
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _write_loop(
|
||||||
|
queue,
|
||||||
|
write_stop_event: Event,
|
||||||
|
write_is_started_event: Event,
|
||||||
|
sample_rate: int,
|
||||||
|
channels: list[int],
|
||||||
|
output_file: Path,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Thread/Process-safe loop to write audio data into a file.
|
||||||
|
"""
|
||||||
|
# Can only be run on a single process/thread for file writing safety
|
||||||
|
with SoundFile(
|
||||||
|
output_file,
|
||||||
|
mode="w",
|
||||||
|
samplerate=sample_rate,
|
||||||
|
channels=len(channels),
|
||||||
|
format="WAV",
|
||||||
|
subtype="PCM_16", # Subtype for int16 values
|
||||||
|
) as file:
|
||||||
|
write_is_started_event.set()
|
||||||
|
while not write_stop_event.is_set():
|
||||||
|
try:
|
||||||
|
file.write(
|
||||||
|
queue.get(timeout=0.005)
|
||||||
|
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
|
||||||
|
queue.task_done()
|
||||||
|
except Empty:
|
||||||
|
continue
|
||||||
|
write_is_started_event.clear()
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from multiprocessing import Barrier
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from .configs import MicrophoneConfig
|
||||||
|
from .microphone import Microphone
|
||||||
|
|
||||||
|
|
||||||
|
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfig]) -> dict[str, Microphone]:
|
||||||
|
microphones = {}
|
||||||
|
|
||||||
|
for key, cfg in microphone_configs.items():
|
||||||
|
if cfg.type == "portaudio":
|
||||||
|
from .portaudio import PortAudioMicrophone
|
||||||
|
|
||||||
|
microphones[key] = PortAudioMicrophone(cfg)
|
||||||
|
elif cfg.type == "touchlab":
|
||||||
|
from .touchlab import TouchLabSensor
|
||||||
|
|
||||||
|
microphones[key] = TouchLabSensor(cfg)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
|
return microphones
|
||||||
|
|
||||||
|
|
||||||
|
def async_microphones_start_recording(
|
||||||
|
microphones: dict[str, Microphone],
|
||||||
|
output_files: list[str | None] | None = None,
|
||||||
|
multiprocessing: bool = False,
|
||||||
|
overwrite: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Starts recording on multiple microphones asynchronously to avoid delays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
microphones: A dictionary of microphones.
|
||||||
|
output_files: A list of output files.
|
||||||
|
multiprocessing: If True, enables multiprocessing for recording.
|
||||||
|
overwrite: If True, overwrites existing files at output_file path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
start_recording_threads = []
|
||||||
|
if output_files is None:
|
||||||
|
output_files = [None] * len(microphones)
|
||||||
|
|
||||||
|
barrier = Barrier(len(microphones))
|
||||||
|
|
||||||
|
for microphone, output_file in zip(microphones.values(), output_files, strict=False):
|
||||||
|
start_recording_threads.append(
|
||||||
|
Thread(target=microphone.start_recording, args=(output_file, multiprocessing, overwrite, barrier))
|
||||||
|
)
|
||||||
|
|
||||||
|
for thread in start_recording_threads:
|
||||||
|
thread.start()
|
||||||
|
for thread in start_recording_threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
|
||||||
|
def async_microphones_stop_recording(microphones: dict[str, Microphone]) -> None:
|
||||||
|
"""
|
||||||
|
Stops recording on multiple microphones asynchronously to avoid delays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
microphones: A dictionary of microphones.
|
||||||
|
"""
|
||||||
|
|
||||||
|
stop_recording_threads = []
|
||||||
|
|
||||||
|
for microphone in microphones.values():
|
||||||
|
stop_recording_threads.append(Thread(target=microphone.stop_recording))
|
||||||
|
|
||||||
|
for thread in stop_recording_threads:
|
||||||
|
thread.start()
|
||||||
|
for thread in stop_recording_threads:
|
||||||
|
thread.join()
|
||||||
@@ -89,6 +89,7 @@ class ACTConfig(PreTrainedConfig):
|
|||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"VISUAL": NormalizationMode.MEAN_STD,
|
"VISUAL": NormalizationMode.MEAN_STD,
|
||||||
|
"AUDIO": NormalizationMode.IDENTITY,
|
||||||
"STATE": NormalizationMode.MEAN_STD,
|
"STATE": NormalizationMode.MEAN_STD,
|
||||||
"ACTION": NormalizationMode.MEAN_STD,
|
"ACTION": NormalizationMode.MEAN_STD,
|
||||||
}
|
}
|
||||||
@@ -99,6 +100,10 @@ class ACTConfig(PreTrainedConfig):
|
|||||||
vision_backbone: str = "resnet18"
|
vision_backbone: str = "resnet18"
|
||||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||||
replace_final_stride_with_dilation: int = False
|
replace_final_stride_with_dilation: int = False
|
||||||
|
# Audio backbone.
|
||||||
|
audio_backbone: str = vision_backbone
|
||||||
|
pretrained_backbone_weights_audio: str | None = None
|
||||||
|
replace_final_stride_with_dilation_audio: int = False
|
||||||
# Transformer layers.
|
# Transformer layers.
|
||||||
pre_norm: bool = False
|
pre_norm: bool = False
|
||||||
dim_model: int = 512
|
dim_model: int = 512
|
||||||
@@ -161,8 +166,10 @@ class ACTConfig(PreTrainedConfig):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
if not self.image_features and not self.env_state_feature:
|
if not (self.image_features or self.audio_features) and not self.env_state_feature:
|
||||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
raise ValueError(
|
||||||
|
"You must provide at least one image/audio or the environment state among the inputs."
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_delta_indices(self) -> None:
|
def observation_delta_indices(self) -> None:
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
|||||||
|
|
||||||
from lerobot.policies.act.configuration_act import ACTConfig
|
from lerobot.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
class ACTPolicy(PreTrainedPolicy):
|
class ACTPolicy(PreTrainedPolicy):
|
||||||
@@ -106,6 +106,8 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
"""
|
"""
|
||||||
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
|
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
|
||||||
|
|
||||||
|
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||||
|
# we are ensembling over.
|
||||||
if self.config.temporal_ensemble_coeff is not None:
|
if self.config.temporal_ensemble_coeff is not None:
|
||||||
actions = self.predict_action_chunk(batch)
|
actions = self.predict_action_chunk(batch)
|
||||||
action = self.temporal_ensembler.update(actions)
|
action = self.temporal_ensembler.update(actions)
|
||||||
@@ -331,12 +333,26 @@ class ACT(nn.Module):
|
|||||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||||
|
|
||||||
|
# Backbone for audio feature extraction.
|
||||||
|
if self.config.audio_features:
|
||||||
|
audio_backbone_model = getattr(torchvision.models, config.audio_backbone)(
|
||||||
|
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio],
|
||||||
|
weights=config.pretrained_backbone_weights_audio,
|
||||||
|
norm_layer=FrozenBatchNorm2d,
|
||||||
|
)
|
||||||
|
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||||
|
# feature map).
|
||||||
|
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||||
|
self.audio_backbone = IntermediateLayerGetter(
|
||||||
|
audio_backbone_model, return_layers={"layer4": "feature_map"}
|
||||||
|
)
|
||||||
|
|
||||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||||
self.encoder = ACTEncoder(config)
|
self.encoder = ACTEncoder(config)
|
||||||
self.decoder = ACTDecoder(config)
|
self.decoder = ACTDecoder(config)
|
||||||
|
|
||||||
# Transformer encoder input projections. The tokens will be structured like
|
# Transformer encoder input projections. The tokens will be structured like
|
||||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
# [latent, (robot_state), (env_state), (image_feature_map_pixels), (audio_feature)].
|
||||||
if self.config.robot_state_feature:
|
if self.config.robot_state_feature:
|
||||||
self.encoder_robot_state_input_proj = nn.Linear(
|
self.encoder_robot_state_input_proj = nn.Linear(
|
||||||
self.config.robot_state_feature.shape[0], config.dim_model
|
self.config.robot_state_feature.shape[0], config.dim_model
|
||||||
@@ -350,6 +366,10 @@ class ACT(nn.Module):
|
|||||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||||
)
|
)
|
||||||
|
if self.config.audio_features:
|
||||||
|
self.encoder_audio_feat_input_proj = nn.Conv2d(
|
||||||
|
audio_backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||||
|
)
|
||||||
# Transformer encoder positional embeddings.
|
# Transformer encoder positional embeddings.
|
||||||
n_1d_tokens = 1 # for the latent
|
n_1d_tokens = 1 # for the latent
|
||||||
if self.config.robot_state_feature:
|
if self.config.robot_state_feature:
|
||||||
@@ -359,6 +379,8 @@ class ACT(nn.Module):
|
|||||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||||
|
if self.config.audio_features:
|
||||||
|
self.encoder_audio_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||||
|
|
||||||
# Transformer decoder.
|
# Transformer decoder.
|
||||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||||
@@ -483,6 +505,21 @@ class ACT(nn.Module):
|
|||||||
encoder_in_tokens.extend(list(cam_features))
|
encoder_in_tokens.extend(list(cam_features))
|
||||||
encoder_in_pos_embed.extend(list(cam_pos_embed))
|
encoder_in_pos_embed.extend(list(cam_pos_embed))
|
||||||
|
|
||||||
|
if self.config.audio_features:
|
||||||
|
for audio in batch[OBS_AUDIO]:
|
||||||
|
audio_features = self.audio_backbone(audio)["feature_map"]
|
||||||
|
audio_pos_embed = self.encoder_audio_feat_pos_embed(audio_features).to(
|
||||||
|
dtype=audio_features.dtype
|
||||||
|
)
|
||||||
|
audio_features = self.encoder_audio_feat_input_proj(audio_features)
|
||||||
|
|
||||||
|
# Rearrange features to (sequence, batch, dim).
|
||||||
|
audio_features = einops.rearrange(audio_features, "b c h w -> (h w) b c")
|
||||||
|
audio_pos_embed = einops.rearrange(audio_pos_embed, "b c h w -> (h w) b c")
|
||||||
|
|
||||||
|
encoder_in_tokens.extend(list(audio_features))
|
||||||
|
encoder_in_pos_embed.extend(list(audio_pos_embed))
|
||||||
|
|
||||||
# Stack all tokens along the sequence dimension.
|
# Stack all tokens along the sequence dimension.
|
||||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||||
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
|
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
|
||||||
|
|||||||
@@ -17,9 +17,11 @@ from typing import Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||||
from lerobot.policies.act.configuration_act import ACTConfig
|
from lerobot.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
|
AudioProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
PolicyAction,
|
PolicyAction,
|
||||||
@@ -63,6 +65,15 @@ def make_act_pre_post_processors(
|
|||||||
stats=dataset_stats,
|
stats=dataset_stats,
|
||||||
device=config.device,
|
device=config.device,
|
||||||
),
|
),
|
||||||
|
AudioProcessorStep(
|
||||||
|
output_height=224,
|
||||||
|
output_width=224,
|
||||||
|
output_channels=3,
|
||||||
|
input_audio_chunk_duration=DEFAULT_AUDIO_CHUNK_DURATION,
|
||||||
|
input_sample_rate=48000,
|
||||||
|
intermediate_sample_rate=16000,
|
||||||
|
n_fft=1024,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
output_steps = [
|
output_steps = [
|
||||||
UnnormalizerProcessorStep(
|
UnnormalizerProcessorStep(
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ def prepare_observation_for_inference(
|
|||||||
This function takes a dictionary of NumPy arrays, performs necessary
|
This function takes a dictionary of NumPy arrays, performs necessary
|
||||||
preprocessing, and prepares it for model inference. The steps include:
|
preprocessing, and prepares it for model inference. The steps include:
|
||||||
1. Converting NumPy arrays to PyTorch tensors.
|
1. Converting NumPy arrays to PyTorch tensors.
|
||||||
2. Normalizing and permuting image data (if any).
|
2. Normalizing and permuting image data and audio data (if any).
|
||||||
3. Adding a batch dimension to each tensor.
|
3. Adding a batch dimension to each tensor.
|
||||||
4. Moving all tensors to the specified compute device.
|
4. Moving all tensors to the specified compute device.
|
||||||
5. Adding task and robot type information to the dictionary.
|
5. Adding task and robot type information to the dictionary.
|
||||||
@@ -129,6 +129,9 @@ def prepare_observation_for_inference(
|
|||||||
if "image" in name:
|
if "image" in name:
|
||||||
observation[name] = observation[name].type(torch.float32) / 255
|
observation[name] = observation[name].type(torch.float32) / 255
|
||||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||||
|
elif "audio" in name:
|
||||||
|
observation[name] = observation[name].type(torch.float32)
|
||||||
|
observation[name] = observation[name].permute(1, 0).contiguous()
|
||||||
observation[name] = observation[name].unsqueeze(0)
|
observation[name] = observation[name].unsqueeze(0)
|
||||||
observation[name] = observation[name].to(device)
|
observation[name] = observation[name].to(device)
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from lerobot.types import (
|
|||||||
TransitionKey,
|
TransitionKey,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .audio_processor import AudioProcessorStep
|
||||||
from .batch_processor import AddBatchDimensionProcessorStep
|
from .batch_processor import AddBatchDimensionProcessorStep
|
||||||
from .converters import (
|
from .converters import (
|
||||||
batch_to_transition,
|
batch_to_transition,
|
||||||
@@ -88,6 +89,7 @@ __all__ = [
|
|||||||
"ActionProcessorStep",
|
"ActionProcessorStep",
|
||||||
"AddTeleopActionAsComplimentaryDataStep",
|
"AddTeleopActionAsComplimentaryDataStep",
|
||||||
"AddTeleopEventsAsInfoStep",
|
"AddTeleopEventsAsInfoStep",
|
||||||
|
"AudioProcessorStep",
|
||||||
"ComplementaryDataProcessorStep",
|
"ComplementaryDataProcessorStep",
|
||||||
"batch_to_transition",
|
"batch_to_transition",
|
||||||
"create_transition",
|
"create_transition",
|
||||||
|
|||||||
@@ -0,0 +1,130 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from torchaudio.functional import amplitude_to_DB
|
||||||
|
from torchaudio.transforms import MelSpectrogram, Resample
|
||||||
|
from torchvision.transforms import Compose, Lambda, Resize
|
||||||
|
|
||||||
|
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||||
|
from lerobot.utils.constants import OBS_AUDIO
|
||||||
|
|
||||||
|
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="audio_processor")
|
||||||
|
class AudioProcessorStep(ObservationProcessorStep):
|
||||||
|
"""
|
||||||
|
Processes audio waveform data into a mel-spectrogram image representation.
|
||||||
|
|
||||||
|
**Audio Processing:**
|
||||||
|
- Averages waveform data over all channels.
|
||||||
|
- Resamples the waveform to 16kHz.
|
||||||
|
- Converts the waveform to a mel-spectrogram.
|
||||||
|
- Converts the mel-spectrogram to decibels.
|
||||||
|
- Resizes the mel-spectrogram to 224×224.
|
||||||
|
- Converts the mel-spectrogram to a channel-first, normalized tensor.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
output_height: Height of the output mel-spectrogram image in pixels.
|
||||||
|
output_width: Width of the output mel-spectrogram image in pixels.
|
||||||
|
output_channels: Number of channels in the output image (3 for RGB-like format).
|
||||||
|
input_audio_chunk_duration: Duration of the input audio chunk in seconds.
|
||||||
|
input_sample_rate: Original sample rate of the input audio in Hz.
|
||||||
|
|
||||||
|
intermediate_sample_rate: Reduced intermediate sample rate in Hz.
|
||||||
|
Downsampling improves the temporal resolution but reduces the frequency range.
|
||||||
|
n_fft: Size of the FFT window for spectrogram computation.
|
||||||
|
Increasing the window size increases the frequency resolution but decreases the temporal resolution.
|
||||||
|
|
||||||
|
hop_length: Number of samples between successive frames, computed automatically to match the output_width.
|
||||||
|
Decreasing the hop length increases the temporal resolution but decreases the frequency resolution.
|
||||||
|
n_mels: Number of mel filter banks, computed automatically to match the output_height.
|
||||||
|
Increasing the number of banks increases the number of rows in the spectrogram and the frequency resolution.
|
||||||
|
mel_spectrogram_transform: The complete audio processing pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_height: int = 224
|
||||||
|
output_width: int = 224
|
||||||
|
output_channels: int = 3
|
||||||
|
input_audio_chunk_duration: float = DEFAULT_AUDIO_CHUNK_DURATION
|
||||||
|
|
||||||
|
input_sample_rate: int = 48000
|
||||||
|
intermediate_sample_rate: int = 16000
|
||||||
|
|
||||||
|
n_fft: int = 1024
|
||||||
|
|
||||||
|
# Parameters computed from other parameters at initialization
|
||||||
|
hop_length: int = field(init=False)
|
||||||
|
n_mels: int = field(init=False)
|
||||||
|
mel_spectrogram_transform: Compose = field(init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.hop_length = int(
|
||||||
|
self.intermediate_sample_rate * self.input_audio_chunk_duration
|
||||||
|
- self.n_fft // self.output_width
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
self.n_mels = self.output_height
|
||||||
|
|
||||||
|
self.mel_spectrogram_transform = Compose(
|
||||||
|
[
|
||||||
|
Lambda(lambda x: x.mean(dim=1)), # Average over all channels (second dimension after batch)
|
||||||
|
Resample(orig_freq=self.input_sample_rate, new_freq=self.intermediate_sample_rate),
|
||||||
|
MelSpectrogram(
|
||||||
|
sample_rate=self.intermediate_sample_rate,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
n_mels=self.n_mels,
|
||||||
|
power=2, # Power spectrum
|
||||||
|
),
|
||||||
|
Lambda(
|
||||||
|
lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0)
|
||||||
|
), # Convert to decibels
|
||||||
|
Resize(
|
||||||
|
(self.output_height, self.output_width)
|
||||||
|
), # Resize spectrogram to output_height×output_width
|
||||||
|
Lambda(
|
||||||
|
lambda x: x.unsqueeze(1).expand(-1, self.output_channels, -1, -1)
|
||||||
|
), # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width].
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
"""
|
||||||
|
Processes audio data contained in the provided observation.
|
||||||
|
"""
|
||||||
|
processed_obs = observation.copy()
|
||||||
|
|
||||||
|
# Process single audio observation
|
||||||
|
if OBS_AUDIO in processed_obs:
|
||||||
|
audio_data = processed_obs[OBS_AUDIO]
|
||||||
|
if isinstance(audio_data, Tensor) and audio_data.dim() == 3: # Batch, Channels, Samples
|
||||||
|
processed_obs[OBS_AUDIO] = self.mel_spectrogram_transform(audio_data)
|
||||||
|
|
||||||
|
# Process multiple audio observations
|
||||||
|
for key, value in processed_obs.items():
|
||||||
|
if (
|
||||||
|
key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 3
|
||||||
|
): # Batch, Channels, Samples
|
||||||
|
processed_obs[key] = self.mel_spectrogram_transform(value)
|
||||||
|
|
||||||
|
return processed_obs
|
||||||
|
|
||||||
|
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
return self._process_observation(observation)
|
||||||
@@ -25,8 +25,7 @@ from dataclasses import dataclass, field
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.types import EnvTransition, PolicyAction
|
from lerobot.utils.constants import OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
|
||||||
|
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
ComplementaryDataProcessorStep,
|
ComplementaryDataProcessorStep,
|
||||||
@@ -36,6 +35,7 @@ from .pipeline import (
|
|||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
TransitionKey,
|
TransitionKey,
|
||||||
)
|
)
|
||||||
|
from lerobot.types import PolicyAction, EnvTransition
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -88,6 +88,8 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
|||||||
- State vectors (1D tensors).
|
- State vectors (1D tensors).
|
||||||
- Single images (3D tensors).
|
- Single images (3D tensors).
|
||||||
- Dictionaries of multiple images (3D tensors).
|
- Dictionaries of multiple images (3D tensors).
|
||||||
|
- Single audio waveforms (2D tensors).
|
||||||
|
- Dictionaries of multiple audio waveforms (2D tensors).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
@@ -117,6 +119,18 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
|||||||
for key, value in observation.items():
|
for key, value in observation.items():
|
||||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||||
observation[key] = value.unsqueeze(0)
|
observation[key] = value.unsqueeze(0)
|
||||||
|
|
||||||
|
# Process single audio observation - add batch dim if 2D
|
||||||
|
if OBS_AUDIO in observation:
|
||||||
|
audio_value = observation[OBS_AUDIO]
|
||||||
|
if isinstance(audio_value, Tensor) and audio_value.dim() == 2:
|
||||||
|
observation[OBS_AUDIO] = audio_value.unsqueeze(0)
|
||||||
|
|
||||||
|
# Process multiple audio observations - add batch dim if 2D
|
||||||
|
for key, value in observation.items():
|
||||||
|
if key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 2:
|
||||||
|
observation[key] = value.unsqueeze(0)
|
||||||
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
def transform_features(
|
def transform_features(
|
||||||
|
|||||||
@@ -34,6 +34,13 @@ class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Specifying '{attr}' is required for the camera to be used in a robot"
|
f"Specifying '{attr}' is required for the camera to be used in a robot"
|
||||||
)
|
)
|
||||||
|
if hasattr(self, "microphones") and self.microphones:
|
||||||
|
for _, config in self.microphones.items():
|
||||||
|
for attr in ["sample_rate", "channels"]:
|
||||||
|
if getattr(config, attr) is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Specifying '{attr}' is required for the microphone to be used in a robot"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig
|
from lerobot.cameras import CameraConfig
|
||||||
|
from lerobot.microphones import MicrophoneConfig
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
|
|
||||||
@@ -35,5 +36,8 @@ class KochFollowerConfig(RobotConfig):
|
|||||||
# cameras
|
# cameras
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# microphones
|
||||||
|
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||||
|
|
||||||
# Set to `True` for backward compatibility with previous policies/dataset
|
# Set to `True` for backward compatibility with previous policies/dataset
|
||||||
use_degrees: bool = False
|
use_degrees: bool = False
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import time
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.microphones.utils import make_microphones_from_configs
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
from lerobot.motors.dynamixel import (
|
from lerobot.motors.dynamixel import (
|
||||||
DynamixelMotorsBus,
|
DynamixelMotorsBus,
|
||||||
@@ -61,6 +62,7 @@ class KochFollower(Robot):
|
|||||||
calibration=self.calibration,
|
calibration=self.calibration,
|
||||||
)
|
)
|
||||||
self.cameras = make_cameras_from_configs(config.cameras)
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
self.microphones = make_microphones_from_configs(config.microphones)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _motors_ft(self) -> dict[str, type]:
|
def _motors_ft(self) -> dict[str, type]:
|
||||||
@@ -72,9 +74,16 @@ class KochFollower(Robot):
|
|||||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _microphones_ft(self) -> dict[str, tuple]:
|
||||||
|
return {
|
||||||
|
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||||
|
for mic in self.microphones
|
||||||
|
}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
return {**self._motors_ft, **self._cameras_ft}
|
return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
@@ -82,7 +91,11 @@ class KochFollower(Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
return (
|
||||||
|
self.bus.is_connected
|
||||||
|
and all(cam.is_connected for cam in self.cameras.values())
|
||||||
|
and all(mic.is_connected for mic in self.microphones.values())
|
||||||
|
)
|
||||||
|
|
||||||
@check_if_already_connected
|
@check_if_already_connected
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
@@ -101,6 +114,9 @@ class KochFollower(Robot):
|
|||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.connect()
|
cam.connect()
|
||||||
|
|
||||||
|
for mic in self.microphones.values():
|
||||||
|
mic.connect()
|
||||||
|
|
||||||
self.configure()
|
self.configure()
|
||||||
logger.info(f"{self} connected.")
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
@@ -197,6 +213,13 @@ class KochFollower(Robot):
|
|||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
|
# Read audio frames from microphones
|
||||||
|
for mic_key, mic in self.microphones.items():
|
||||||
|
start = time.perf_counter()
|
||||||
|
obs_dict[mic_key] = mic.read()
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
@@ -232,5 +255,7 @@ class KochFollower(Robot):
|
|||||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.disconnect()
|
cam.disconnect()
|
||||||
|
for mic in self.microphones.values():
|
||||||
|
mic.disconnect()
|
||||||
|
|
||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
from lerobot.cameras.configs import CameraConfig, Cv2Rotation
|
from lerobot.cameras.configs import CameraConfig, Cv2Rotation
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.microphones import MicrophoneConfig
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
|
|
||||||
@@ -45,6 +46,8 @@ class LeKiwiConfig(RobotConfig):
|
|||||||
|
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||||
|
|
||||||
|
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||||
|
|
||||||
# Set to `True` for backward compatibility with previous policies/dataset
|
# Set to `True` for backward compatibility with previous policies/dataset
|
||||||
use_degrees: bool = False
|
use_degrees: bool = False
|
||||||
|
|
||||||
@@ -92,5 +95,7 @@ class LeKiwiClientConfig(RobotConfig):
|
|||||||
|
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||||
|
|
||||||
|
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||||
|
|
||||||
polling_timeout_ms: int = 15
|
polling_timeout_ms: int = 15
|
||||||
connect_timeout_s: int = 5
|
connect_timeout_s: int = 5
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from typing import Any
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.microphones.utils import make_microphones_from_configs
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
from lerobot.motors.feetech import (
|
from lerobot.motors.feetech import (
|
||||||
FeetechMotorsBus,
|
FeetechMotorsBus,
|
||||||
@@ -73,6 +74,7 @@ class LeKiwi(Robot):
|
|||||||
self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")]
|
self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")]
|
||||||
self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")]
|
self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")]
|
||||||
self.cameras = make_cameras_from_configs(config.cameras)
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
self.microphones = make_microphones_from_configs(config.microphones)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _state_ft(self) -> dict[str, type]:
|
def _state_ft(self) -> dict[str, type]:
|
||||||
@@ -97,9 +99,16 @@ class LeKiwi(Robot):
|
|||||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _microphones_ft(self) -> dict[str, tuple]:
|
||||||
|
return {
|
||||||
|
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||||
|
for mic in self.microphones
|
||||||
|
}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
return {**self._state_ft, **self._cameras_ft}
|
return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
@@ -107,7 +116,11 @@ class LeKiwi(Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
return (
|
||||||
|
self.bus.is_connected
|
||||||
|
and all(cam.is_connected for cam in self.cameras.values())
|
||||||
|
and all(mic.is_connected for mic in self.microphones.values())
|
||||||
|
)
|
||||||
|
|
||||||
@check_if_already_connected
|
@check_if_already_connected
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
@@ -121,6 +134,9 @@ class LeKiwi(Robot):
|
|||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.connect()
|
cam.connect()
|
||||||
|
|
||||||
|
for mic in self.microphones.values():
|
||||||
|
mic.connect()
|
||||||
|
|
||||||
self.configure()
|
self.configure()
|
||||||
logger.info(f"{self} connected.")
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
@@ -364,6 +380,13 @@ class LeKiwi(Robot):
|
|||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
|
# Read audio frames from microphones
|
||||||
|
for mic_key, mic in self.microphones.items():
|
||||||
|
start = time.perf_counter()
|
||||||
|
obs_dict[mic_key] = mic.read()
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
@@ -413,5 +436,7 @@ class LeKiwi(Robot):
|
|||||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.disconnect()
|
cam.disconnect()
|
||||||
|
for mic in self.microphones.values():
|
||||||
|
mic.disconnect()
|
||||||
|
|
||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
from time import perf_counter
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -58,8 +59,9 @@ class LeKiwiClient(Robot):
|
|||||||
self.zmq_observation_socket = None
|
self.zmq_observation_socket = None
|
||||||
|
|
||||||
self.last_frames = {}
|
self.last_frames = {}
|
||||||
|
|
||||||
self.last_remote_state = {}
|
self.last_remote_state = {}
|
||||||
|
self.last_frame_timestamp = None
|
||||||
|
self.last_frame_delay = 0.0
|
||||||
|
|
||||||
# Define three speed levels and a current index
|
# Define three speed levels and a current index
|
||||||
self.speed_levels = [
|
self.speed_levels = [
|
||||||
@@ -97,9 +99,13 @@ class LeKiwiClient(Robot):
|
|||||||
def _cameras_ft(self) -> dict[str, tuple[int, int, int]]:
|
def _cameras_ft(self) -> dict[str, tuple[int, int, int]]:
|
||||||
return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
|
return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _microphones_ft(self) -> dict[str, tuple]:
|
||||||
|
return {name: (cfg.sample_rate, cfg.channels) for name, cfg in self.config.microphones.items()}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
return {**self._state_ft, **self._cameras_ft}
|
return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
@@ -135,6 +141,7 @@ class LeKiwiClient(Robot):
|
|||||||
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
|
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
|
||||||
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
|
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
|
||||||
|
|
||||||
|
self.last_frame_timestamp = perf_counter()
|
||||||
self._is_connected = True
|
self._is_connected = True
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
def calibrate(self) -> None:
|
||||||
@@ -167,6 +174,8 @@ class LeKiwiClient(Robot):
|
|||||||
if last_msg is None:
|
if last_msg is None:
|
||||||
logging.warning("Poller indicated data, but failed to retrieve message.")
|
logging.warning("Poller indicated data, but failed to retrieve message.")
|
||||||
|
|
||||||
|
self.last_frame_delay = perf_counter() - self.last_frame_timestamp
|
||||||
|
self.last_frame_timestamp = perf_counter()
|
||||||
return last_msg
|
return last_msg
|
||||||
|
|
||||||
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
|
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
|
||||||
@@ -203,14 +212,16 @@ class LeKiwiClient(Robot):
|
|||||||
|
|
||||||
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
|
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
|
||||||
|
|
||||||
# Decode images
|
# Decode images and audio data
|
||||||
current_frames: dict[str, np.ndarray] = {}
|
current_frames: dict[str, np.ndarray] = {}
|
||||||
for cam_name, image_b64 in observation.items():
|
for frame_name, frame_data in observation.items():
|
||||||
if cam_name not in self._cameras_ft:
|
if frame_name in self._cameras_ft:
|
||||||
continue
|
image = self._decode_image_from_b64(frame_data)
|
||||||
frame = self._decode_image_from_b64(image_b64)
|
if image is not None:
|
||||||
if frame is not None:
|
current_frames[frame_name] = image
|
||||||
current_frames[cam_name] = frame
|
elif frame_name in self._microphones_ft:
|
||||||
|
if frame_data is not None:
|
||||||
|
current_frames[frame_name] = frame_data
|
||||||
|
|
||||||
return current_frames, obs_dict
|
return current_frames, obs_dict
|
||||||
|
|
||||||
@@ -254,17 +265,27 @@ class LeKiwiClient(Robot):
|
|||||||
"""
|
"""
|
||||||
Capture observations from the remote robot: current follower arm positions,
|
Capture observations from the remote robot: current follower arm positions,
|
||||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||||
and a camera frame. Receives over ZMQ, translate to body-frame vel
|
and cameras and microphones data. Receives over ZMQ, translate to body-frame vel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frames, obs_dict = self._get_data()
|
frames, obs_dict = self._get_data()
|
||||||
|
|
||||||
# Loop over each configured camera
|
# Loop over each configured camera and microphone
|
||||||
for cam_name, frame in frames.items():
|
for frame_name, frame_data in frames.items():
|
||||||
if frame is None:
|
if frame_data is None:
|
||||||
logging.warning("Frame is None")
|
if frame_name in self._cameras_ft:
|
||||||
frame = np.zeros((640, 480, 3), dtype=np.uint8)
|
logging.warning("Image frame is None")
|
||||||
obs_dict[cam_name] = frame
|
image = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||||
|
obs_dict[frame_name] = image
|
||||||
|
elif frame_name in self._microphones_ft:
|
||||||
|
logging.warning("Audio frame is None")
|
||||||
|
obs_dict[frame_name] = np.zeros(
|
||||||
|
(
|
||||||
|
int(self._microphones_ft[frame_name][0] * self.last_frame_delay),
|
||||||
|
self._microphones_ft[frame_name][1],
|
||||||
|
),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig
|
from lerobot.cameras import CameraConfig
|
||||||
|
from lerobot.microphones import MicrophoneConfig
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
|
|
||||||
@@ -38,6 +39,9 @@ class SOFollowerConfig:
|
|||||||
# cameras
|
# cameras
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# microphones
|
||||||
|
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||||
|
|
||||||
# Set to `True` for backward compatibility with previous policies/dataset
|
# Set to `True` for backward compatibility with previous policies/dataset
|
||||||
use_degrees: bool = True
|
use_degrees: bool = True
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import time
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.microphones.utils import make_microphones_from_configs
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
from lerobot.motors.feetech import (
|
from lerobot.motors.feetech import (
|
||||||
FeetechMotorsBus,
|
FeetechMotorsBus,
|
||||||
@@ -61,6 +62,7 @@ class SOFollower(Robot):
|
|||||||
calibration=self.calibration,
|
calibration=self.calibration,
|
||||||
)
|
)
|
||||||
self.cameras = make_cameras_from_configs(config.cameras)
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
self.microphones = make_microphones_from_configs(config.microphones)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _motors_ft(self) -> dict[str, type]:
|
def _motors_ft(self) -> dict[str, type]:
|
||||||
@@ -72,9 +74,16 @@ class SOFollower(Robot):
|
|||||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _microphones_ft(self) -> dict[str, tuple]:
|
||||||
|
return {
|
||||||
|
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||||
|
for mic in self.microphones
|
||||||
|
}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
return {**self._motors_ft, **self._cameras_ft}
|
return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
@@ -82,7 +91,11 @@ class SOFollower(Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
return (
|
||||||
|
self.bus.is_connected
|
||||||
|
and all(cam.is_connected for cam in self.cameras.values())
|
||||||
|
and all(mic.is_connected for mic in self.microphones.values())
|
||||||
|
)
|
||||||
|
|
||||||
@check_if_already_connected
|
@check_if_already_connected
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
@@ -101,6 +114,9 @@ class SOFollower(Robot):
|
|||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.connect()
|
cam.connect()
|
||||||
|
|
||||||
|
for mic in self.microphones.values():
|
||||||
|
mic.connect()
|
||||||
|
|
||||||
self.configure()
|
self.configure()
|
||||||
logger.info(f"{self} connected.")
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
@@ -190,6 +206,13 @@ class SOFollower(Robot):
|
|||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
|
# Read audio frames from microphones
|
||||||
|
for mic_key, mic in self.microphones.items():
|
||||||
|
start = time.perf_counter()
|
||||||
|
obs_dict[mic_key] = mic.read()
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
@@ -225,6 +248,8 @@ class SOFollower(Robot):
|
|||||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.disconnect()
|
cam.disconnect()
|
||||||
|
for mic in self.microphones.values():
|
||||||
|
mic.disconnect()
|
||||||
|
|
||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ from lerobot.datasets.utils import (
|
|||||||
flatten_dict,
|
flatten_dict,
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
|
||||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
@@ -318,12 +318,12 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
|||||||
|
|
||||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
||||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
|
||||||
|
|
||||||
# Check if adding this episode would exceed the limit
|
# Check if adding this episode would exceed the limit
|
||||||
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
|
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
|
||||||
# Size limit would be exceeded, save current accumulation WITHOUT this episode
|
# Size limit would be exceeded, save current accumulation WITHOUT this episode
|
||||||
concatenate_video_files(
|
concatenate_media_files(
|
||||||
paths_to_cat,
|
paths_to_cat,
|
||||||
new_root
|
new_root
|
||||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||||
@@ -359,7 +359,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
|||||||
|
|
||||||
# Write remaining videos if any
|
# Write remaining videos if any
|
||||||
if paths_to_cat:
|
if paths_to_cat:
|
||||||
concatenate_video_files(
|
concatenate_media_files(
|
||||||
paths_to_cat,
|
paths_to_cat,
|
||||||
new_root
|
new_root
|
||||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||||
@@ -402,7 +402,12 @@ def generate_episode_metadata_dict(
|
|||||||
if len(ep_ids_set) != 1:
|
if len(ep_ids_set) != 1:
|
||||||
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
|
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
|
||||||
|
|
||||||
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
|
ep_dict = {
|
||||||
|
**ep_metadata,
|
||||||
|
**ep_video,
|
||||||
|
**ep_legacy_metadata,
|
||||||
|
**flatten_dict({"stats": ep_stats}),
|
||||||
|
}
|
||||||
ep_dict["meta/episodes/chunk_index"] = 0
|
ep_dict["meta/episodes/chunk_index"] = 0
|
||||||
ep_dict["meta/episodes/file_index"] = 0
|
ep_dict["meta/episodes/file_index"] = 0
|
||||||
yield ep_dict
|
yield ep_dict
|
||||||
@@ -423,7 +428,10 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
|||||||
|
|
||||||
ds_episodes = Dataset.from_generator(
|
ds_episodes = Dataset.from_generator(
|
||||||
lambda: generate_episode_metadata_dict(
|
lambda: generate_episode_metadata_dict(
|
||||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
|
episodes_legacy_metadata,
|
||||||
|
episodes_metadata,
|
||||||
|
episodes_stats,
|
||||||
|
episodes_video_metadata,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
write_episodes(ds_episodes, new_root)
|
write_episodes(ds_episodes, new_root)
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ import draccus
|
|||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
|
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||||
|
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
Robot,
|
Robot,
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
|
|||||||
@@ -69,11 +69,14 @@ lerobot-record \
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from copy import copy
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.cameras import ( # noqa: F401
|
from lerobot.cameras import ( # noqa: F401
|
||||||
CameraConfig, # noqa: F401
|
CameraConfig, # noqa: F401
|
||||||
)
|
)
|
||||||
@@ -87,7 +90,20 @@ from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_
|
|||||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
|
from lerobot.datasets.utils import (
|
||||||
|
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||||
|
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
|
||||||
|
)
|
||||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||||
|
from lerobot.microphones import (
|
||||||
|
MicrophoneConfig, # noqa: F401
|
||||||
|
)
|
||||||
|
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||||
|
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||||
|
from lerobot.microphones.utils import (
|
||||||
|
async_microphones_start_recording,
|
||||||
|
async_microphones_stop_recording,
|
||||||
|
)
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import make_robot_action
|
from lerobot.policies.utils import make_robot_action
|
||||||
@@ -131,6 +147,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
unitree_g1,
|
unitree_g1,
|
||||||
)
|
)
|
||||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||||
|
from lerobot.utils.audio_utils import rolling_vstack
|
||||||
from lerobot.utils.constants import ACTION, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
from lerobot.utils.control_utils import (
|
from lerobot.utils.control_utils import (
|
||||||
init_keyboard_listener,
|
init_keyboard_listener,
|
||||||
@@ -300,6 +317,13 @@ def record_loop(
|
|||||||
display_data: bool = False,
|
display_data: bool = False,
|
||||||
display_compressed_images: bool = False,
|
display_compressed_images: bool = False,
|
||||||
):
|
):
|
||||||
|
if display_data:
|
||||||
|
init_rerun(
|
||||||
|
session_name="recording",
|
||||||
|
robot=robot,
|
||||||
|
reset_time=True,
|
||||||
|
)
|
||||||
|
|
||||||
if dataset is not None and dataset.fps != fps:
|
if dataset is not None and dataset.fps != fps:
|
||||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||||
|
|
||||||
@@ -334,6 +358,36 @@ def record_loop(
|
|||||||
preprocessor.reset()
|
preprocessor.reset()
|
||||||
postprocessor.reset()
|
postprocessor.reset()
|
||||||
|
|
||||||
|
# Create a buffer for audio observations (shifting window of fixed size over audio samples)
|
||||||
|
if robot.microphones and (policy is not None or dataset is not None):
|
||||||
|
audio_buffer = {
|
||||||
|
microphone_name: np.zeros(
|
||||||
|
(int(microphone.sample_rate * DEFAULT_AUDIO_CHUNK_DURATION), len(microphone.channels))
|
||||||
|
)
|
||||||
|
for microphone_name, microphone in robot.microphones.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
dataset is not None and robot.name != "lekiwi"
|
||||||
|
): # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
|
||||||
|
dataset.add_microphones_recordings(robot.microphones)
|
||||||
|
else:
|
||||||
|
async_microphones_start_recording(robot.microphones)
|
||||||
|
|
||||||
|
# Fill audio buffers if needed
|
||||||
|
if (
|
||||||
|
robot.microphones
|
||||||
|
and (policy is not None or dataset is not None)
|
||||||
|
and DEFAULT_INITIAL_AUDIO_BUFFER_DURATION > 0.0
|
||||||
|
):
|
||||||
|
# This initial wait might be longer than the audio chunk duration to
|
||||||
|
# (1) ensure that the audio buffers are filled with enough data
|
||||||
|
# (2) add additional initial samples to the dataset in case of variable audio chunk duration during training
|
||||||
|
precise_sleep(DEFAULT_INITIAL_AUDIO_BUFFER_DURATION)
|
||||||
|
for microphone_name, microphone in robot.microphones.items():
|
||||||
|
audio_chunk = microphone.read()
|
||||||
|
audio_buffer[microphone_name] = rolling_vstack(audio_buffer[microphone_name], audio_chunk)
|
||||||
|
|
||||||
no_action_count = 0
|
no_action_count = 0
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
@@ -355,8 +409,14 @@ def record_loop(
|
|||||||
|
|
||||||
# Get action from either policy or teleop
|
# Get action from either policy or teleop
|
||||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||||
|
# Transform instantaneous audio samples into a buffer of fixed size
|
||||||
|
buffered_observation_frame = copy(observation_frame)
|
||||||
|
for name in audio_buffer:
|
||||||
|
# Add the audio buffer to the observation
|
||||||
|
buffered_observation_frame[name] = rolling_vstack(audio_buffer[name], observation_frame[name])
|
||||||
|
|
||||||
action_values = predict_action(
|
action_values = predict_action(
|
||||||
observation=observation_frame,
|
observation=buffered_observation_frame,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
device=get_safe_torch_device(policy.config.device),
|
device=get_safe_torch_device(policy.config.device),
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
@@ -415,7 +475,10 @@ def record_loop(
|
|||||||
|
|
||||||
if display_data:
|
if display_data:
|
||||||
log_rerun_data(
|
log_rerun_data(
|
||||||
observation=obs_processed, action=action_values, compress_images=display_compressed_images
|
observation=obs_processed,
|
||||||
|
action=action_values,
|
||||||
|
compress_images=display_compressed_images,
|
||||||
|
log_time=time.perf_counter() - start_episode_t,
|
||||||
)
|
)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_loop_t
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
@@ -430,6 +493,8 @@ def record_loop(
|
|||||||
|
|
||||||
timestamp = time.perf_counter() - start_episode_t
|
timestamp = time.perf_counter() - start_episode_t
|
||||||
|
|
||||||
|
async_microphones_stop_recording(robot.microphones)
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def record(cfg: RecordConfig) -> LeRobotDataset:
|
def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # no
|
|||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
|
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||||
|
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
RobotAction,
|
RobotAction,
|
||||||
RobotObservation,
|
RobotObservation,
|
||||||
@@ -151,8 +153,18 @@ def teleop_loop(
|
|||||||
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
|
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
|
||||||
robot_observation_processor: An optional pipeline to process raw observations from the robot.
|
robot_observation_processor: An optional pipeline to process raw observations from the robot.
|
||||||
"""
|
"""
|
||||||
|
if display_data:
|
||||||
|
init_rerun(
|
||||||
|
session_name="teleoperation",
|
||||||
|
robot=robot,
|
||||||
|
reset_time=True,
|
||||||
|
)
|
||||||
|
|
||||||
display_len = max(len(key) for key in robot.action_features)
|
display_len = max(len(key) for key in robot.action_features)
|
||||||
|
|
||||||
|
for _, microphone in robot.microphones.items():
|
||||||
|
microphone.start_recording()
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
while True:
|
while True:
|
||||||
loop_start = time.perf_counter()
|
loop_start = time.perf_counter()
|
||||||
@@ -186,6 +198,7 @@ def teleop_loop(
|
|||||||
observation=obs_transition,
|
observation=obs_transition,
|
||||||
action=teleop_action,
|
action=teleop_action,
|
||||||
compress_images=display_compressed_images,
|
compress_images=display_compressed_images,
|
||||||
|
log_time=time.perf_counter() - start,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\n" + "-" * (display_len + 10))
|
print("\n" + "-" * (display_len + 10))
|
||||||
@@ -202,7 +215,10 @@ def teleop_loop(
|
|||||||
move_cursor_up(1)
|
move_cursor_up(1)
|
||||||
|
|
||||||
if duration is not None and time.perf_counter() - start >= duration:
|
if duration is not None and time.perf_counter() - start >= duration:
|
||||||
return
|
break
|
||||||
|
|
||||||
|
for _, microphone in robot.microphones.items():
|
||||||
|
microphone.stop_recording()
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def rolling_vstack(buffer: np.ndarray, new_data: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Rolling implementation of numpy.vstack to add new data in at the end of a fixed shape buffer in a rolling fashion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
buffer: The *fixed* shape buffer to update.
|
||||||
|
new_data: The new data to add to the buffer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated buffer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
buffer_size = buffer.shape[0]
|
||||||
|
# Remove as many old audio samples as needed
|
||||||
|
buffer[: -len(new_data)] = buffer[len(new_data) :]
|
||||||
|
# Add new audio samples, only the newest if the buffer is already full
|
||||||
|
buffer[-len(new_data) :] = new_data[-buffer_size:]
|
||||||
|
return buffer
|
||||||
@@ -23,6 +23,7 @@ OBS_ENV_STATE = OBS_STR + ".environment_state"
|
|||||||
OBS_STATE = OBS_STR + ".state"
|
OBS_STATE = OBS_STR + ".state"
|
||||||
OBS_IMAGE = OBS_STR + ".image"
|
OBS_IMAGE = OBS_STR + ".image"
|
||||||
OBS_IMAGES = OBS_IMAGE + "s"
|
OBS_IMAGES = OBS_IMAGE + "s"
|
||||||
|
OBS_AUDIO = OBS_STR + ".audio"
|
||||||
OBS_LANGUAGE = OBS_STR + ".language"
|
OBS_LANGUAGE = OBS_STR + ".language"
|
||||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ def predict_action(
|
|||||||
torch.inference_mode(),
|
torch.inference_mode(),
|
||||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||||
):
|
):
|
||||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
# Convert to pytorch format: normalizing and permuting (channel first)
|
||||||
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
||||||
observation = preprocessor(observation)
|
observation = preprocessor(observation)
|
||||||
|
|
||||||
|
|||||||
@@ -30,3 +30,22 @@ class DeviceAlreadyConnectedError(ConnectionError):
|
|||||||
):
|
):
|
||||||
self.message = message
|
self.message = message
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceNotRecordingError(Exception):
|
||||||
|
"""Exception raised when the robot device is not recording."""
|
||||||
|
|
||||||
|
def __init__(self, message="This robot device is not recording. Try calling `start_recording()` first."):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceAlreadyRecordingError(Exception):
|
||||||
|
"""Exception raised when the robot device is already recording."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message="This robot device is already recording. Try not calling `start_recording()` twice.",
|
||||||
|
):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from multiprocessing import Lock, Value, shared_memory
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class SharedArray:
|
||||||
|
"""
|
||||||
|
A SharedArray is a numpy array shared between multiple processes in a shared_memory object.
|
||||||
|
- Data is written to the array using the `write` method, which appends data to the array.
|
||||||
|
- Data is read from the array (and eventually flushed) using the `read` method, which copies the _whole_ array.
|
||||||
|
|
||||||
|
SharedArray offers quasi-instantaneous array-wide read and flush capabilities in comparison to Queues, but has a limited size defined at initialization.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
_Main_process_
|
||||||
|
shared_array = SharedArray(shape=(10, 10), dtype=np.dtype("float32"))
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
shared_array.write(local_array, np.array([[1, 2, 3], [4, 5, 6]]))
|
||||||
|
|
||||||
|
_Child_process_
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
data = shared_array.read(local_array, flush=True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shape: tuple[int], dtype: np.dtype | str):
|
||||||
|
"""
|
||||||
|
Initialize a SharedArray.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: The shape of the shared array.
|
||||||
|
dtype: The dtype of the shared array.
|
||||||
|
"""
|
||||||
|
self.shape = shape
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
self.shared_memory = shared_memory.SharedMemory(
|
||||||
|
create=True, size=np.prod(shape) * np.dtype(dtype).itemsize
|
||||||
|
)
|
||||||
|
self.read_index = Value("i", 0)
|
||||||
|
self.lock = Lock()
|
||||||
|
|
||||||
|
def get_local_array(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get a process-local instance of the shared array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A process-local instance of the shared array.
|
||||||
|
"""
|
||||||
|
return np.ndarray(self.shape, dtype=np.dtype(self.dtype), buffer=self.shared_memory.buf)
|
||||||
|
|
||||||
|
def delete(self):
|
||||||
|
"""
|
||||||
|
Delete the shared array.
|
||||||
|
"""
|
||||||
|
self.shared_memory.close()
|
||||||
|
self.shared_memory.unlink()
|
||||||
|
|
||||||
|
def write(self, local_array: np.ndarray, data: np.ndarray):
|
||||||
|
"""
|
||||||
|
Write data to the shared array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_array: The process-local instance of the shared array to write to.
|
||||||
|
data: The data to write to the shared array.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
local_array[self.read_index.value : self.read_index.value + len(data)] = data
|
||||||
|
self.read_index.value += len(data)
|
||||||
|
|
||||||
|
def read(self, local_array: np.ndarray, flush: bool = True) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Read data from the shared array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_array: The process-local instance of the shared array to read from.
|
||||||
|
flush: Whether to flush the shared array after reading.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
data = np.copy(local_array[: self.read_index.value])
|
||||||
|
if flush:
|
||||||
|
self.read_index.value = 0
|
||||||
|
return data
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Reset the read index to 0.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
self.read_index.value = 0
|
||||||
@@ -14,17 +14,25 @@
|
|||||||
|
|
||||||
import numbers
|
import numbers
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import rerun as rr
|
import rerun as rr
|
||||||
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||||
|
from lerobot.processor import RobotAction, RobotObservation
|
||||||
|
from lerobot.robots import Robot
|
||||||
|
|
||||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||||
|
|
||||||
|
|
||||||
def init_rerun(
|
def init_rerun(
|
||||||
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
|
session_name: str = "lerobot_control_loop",
|
||||||
|
ip: str | None = None,
|
||||||
|
port: int | None = None,
|
||||||
|
robot: Robot | None = None,
|
||||||
|
reset_time: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes the Rerun SDK for visualizing the control loop.
|
Initializes the Rerun SDK for visualizing the control loop.
|
||||||
@@ -33,16 +41,26 @@ def init_rerun(
|
|||||||
session_name: Name of the Rerun session.
|
session_name: Name of the Rerun session.
|
||||||
ip: Optional IP for connecting to a Rerun server.
|
ip: Optional IP for connecting to a Rerun server.
|
||||||
port: Optional port for connecting to a Rerun server.
|
port: Optional port for connecting to a Rerun server.
|
||||||
|
robot: A Robot object. If provided, Rerun will be initialized with a blueprint that includes the object's cameras and microphones.
|
||||||
|
reset_time: Whether to reset the timer "episode_time" to 0.
|
||||||
"""
|
"""
|
||||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||||
rr.init(session_name)
|
rr.init(
|
||||||
|
application_id=session_name,
|
||||||
|
recording_id=uuid4(),
|
||||||
|
)
|
||||||
|
if robot is not None:
|
||||||
|
rr.send_blueprint(build_rerun_blueprint(robot))
|
||||||
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
||||||
if ip and port:
|
if ip and port:
|
||||||
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
|
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
|
||||||
else:
|
else:
|
||||||
rr.spawn(memory_limit=memory_limit)
|
rr.spawn(memory_limit=memory_limit)
|
||||||
|
|
||||||
|
if reset_time:
|
||||||
|
rr.set_time("episode_time", timestamp=0.0)
|
||||||
|
|
||||||
|
|
||||||
def _is_scalar(x):
|
def _is_scalar(x):
|
||||||
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
|
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
|
||||||
@@ -50,10 +68,47 @@ def _is_scalar(x):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_rerun_blueprint(robot: Robot) -> rr.blueprint.Grid:
|
||||||
|
""" "
|
||||||
|
Builds a Rerun blueprint for optimized visualization of the robot's observations and actions :
|
||||||
|
- Time series views for all scalar observations and actions (e.g. position, velocity, torque, etc.).
|
||||||
|
- Spatial 2D views for all camera observations.
|
||||||
|
- Time series views for all microphone observations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
robot: A Robot object.
|
||||||
|
Returns:
|
||||||
|
A Rerun blueprint.
|
||||||
|
"""
|
||||||
|
contents = [
|
||||||
|
rr.blueprint.TimeSeriesView(
|
||||||
|
origin="data",
|
||||||
|
plot_legend=rr.blueprint.PlotLegend(visible=True),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if robot.microphones:
|
||||||
|
contents += [
|
||||||
|
rr.blueprint.TimeSeriesView(
|
||||||
|
origin="audio",
|
||||||
|
plot_legend=rr.blueprint.PlotLegend(visible=True),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if robot.cameras:
|
||||||
|
contents += [
|
||||||
|
rr.blueprint.Spatial2DView(
|
||||||
|
origin=OBS_PREFIX + camera_name,
|
||||||
|
)
|
||||||
|
for camera_name in robot.cameras
|
||||||
|
]
|
||||||
|
|
||||||
|
return rr.blueprint.Grid(*contents)
|
||||||
|
|
||||||
|
|
||||||
def log_rerun_data(
|
def log_rerun_data(
|
||||||
observation: RobotObservation | None = None,
|
observation: RobotObservation | None = None,
|
||||||
action: RobotAction | None = None,
|
action: RobotAction | None = None,
|
||||||
compress_images: bool = False,
|
compress_images: bool = False,
|
||||||
|
log_time: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Logs observation and action data to Rerun for real-time visualization.
|
Logs observation and action data to Rerun for real-time visualization.
|
||||||
@@ -72,7 +127,13 @@ def log_rerun_data(
|
|||||||
observation: An optional dictionary containing observation data to log.
|
observation: An optional dictionary containing observation data to log.
|
||||||
action: An optional dictionary containing action data to log.
|
action: An optional dictionary containing action data to log.
|
||||||
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
||||||
|
log_time: The time to log the data in the "episode_time" timeline.
|
||||||
|
If None, the current time is used in Rerun's default timeline.
|
||||||
"""
|
"""
|
||||||
|
if log_time is None:
|
||||||
|
log_time = time.perf_counter()
|
||||||
|
rr.set_time("episode_time", timestamp=log_time)
|
||||||
|
|
||||||
if observation:
|
if observation:
|
||||||
for k, v in observation.items():
|
for k, v in observation.items():
|
||||||
if v is None:
|
if v is None:
|
||||||
@@ -80,15 +141,41 @@ def log_rerun_data(
|
|||||||
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
|
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
|
||||||
|
|
||||||
if _is_scalar(v):
|
if _is_scalar(v):
|
||||||
rr.log(key, rr.Scalars(float(v)))
|
rr.log("data/" + key, rr.Scalars(float(v)))
|
||||||
elif isinstance(v, np.ndarray):
|
elif isinstance(v, np.ndarray):
|
||||||
arr = v
|
arr = v
|
||||||
# Convert CHW -> HWC when needed
|
# Convert CHW -> HWC when needed
|
||||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||||
arr = np.transpose(arr, (1, 2, 0))
|
arr = np.transpose(arr, (1, 2, 0))
|
||||||
|
# Convert samples x channels -> channels x samples when needed
|
||||||
|
elif arr.ndim == 2 and arr.shape[1] < arr.shape[0]:
|
||||||
|
arr = np.transpose(arr, (1, 0))
|
||||||
|
|
||||||
if arr.ndim == 1:
|
if arr.ndim == 1:
|
||||||
for i, vi in enumerate(arr):
|
for i, vi in enumerate(arr):
|
||||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||||
|
elif arr.ndim == 2:
|
||||||
|
for i, channel_arr in enumerate(arr):
|
||||||
|
rr.send_columns(
|
||||||
|
"audio/"
|
||||||
|
+ key
|
||||||
|
+ f"_channel_{i}", # TODO(CarolinePascal): Get actual channel number/name
|
||||||
|
indexes=[
|
||||||
|
rr.TimeColumn(
|
||||||
|
"episode_time",
|
||||||
|
timestamp=log_time
|
||||||
|
+ np.linspace(
|
||||||
|
-DEFAULT_AUDIO_CHUNK_DURATION,
|
||||||
|
0,
|
||||||
|
len(channel_arr),
|
||||||
|
endpoint=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
columns=rr.Scalars.columns(scalars=channel_arr),
|
||||||
|
)
|
||||||
|
elif arr.ndim == 3:
|
||||||
|
rr.log(key, rr.Image(arr), static=True)
|
||||||
else:
|
else:
|
||||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||||
rr.log(key, entity=img_entity, static=True)
|
rr.log(key, entity=img_entity, static=True)
|
||||||
@@ -100,13 +187,13 @@ def log_rerun_data(
|
|||||||
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
|
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
|
||||||
|
|
||||||
if _is_scalar(v):
|
if _is_scalar(v):
|
||||||
rr.log(key, rr.Scalars(float(v)))
|
rr.log("data/" + key, rr.Scalars(float(v)))
|
||||||
elif isinstance(v, np.ndarray):
|
elif isinstance(v, np.ndarray):
|
||||||
if v.ndim == 1:
|
if v.ndim == 1:
|
||||||
for i, vi in enumerate(v):
|
for i, vi in enumerate(v):
|
||||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||||
else:
|
else:
|
||||||
# Fall back to flattening higher-dimensional arrays
|
# Fall back to flattening higher-dimensional arrays
|
||||||
flat = v.flatten()
|
flat = v.flatten()
|
||||||
for i, vi in enumerate(flat):
|
for i, vi in enumerate(flat):
|
||||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ def _check_component_availability(component_type, available_components, make_com
|
|||||||
print("\nNo physical device detected.")
|
print("\nNo physical device detected.")
|
||||||
elif isinstance(e, ValueError) and "camera_index" in str(e):
|
elif isinstance(e, ValueError) and "camera_index" in str(e):
|
||||||
print("\nNo physical camera detected.")
|
print("\nNo physical camera detected.")
|
||||||
|
elif isinstance(e, ValueError) and "microphone_index" in str(e):
|
||||||
|
print("\nNo physical microphone detected.")
|
||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|||||||
@@ -26,16 +26,22 @@ from lerobot.datasets.compute_stats import (
|
|||||||
compute_episode_stats,
|
compute_episode_stats,
|
||||||
estimate_num_samples,
|
estimate_num_samples,
|
||||||
get_feature_stats,
|
get_feature_stats,
|
||||||
|
sample_audio_from_data,
|
||||||
|
sample_audio_from_path,
|
||||||
sample_images,
|
sample_images,
|
||||||
sample_indices,
|
sample_indices,
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
from lerobot.utils.constants import OBS_AUDIO, OBS_IMAGE, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def mock_load_audio(path):
|
||||||
|
return np.ones((16000, 2), dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_array():
|
def sample_array():
|
||||||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
@@ -73,6 +79,25 @@ def test_sample_images(mock_load):
|
|||||||
assert len(images) == estimate_num_samples(100)
|
assert len(images) == estimate_num_samples(100)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
|
||||||
|
def test_sample_audio_from_path(mock_load):
|
||||||
|
audio_path = "audio.wav"
|
||||||
|
audio_samples = sample_audio_from_path(audio_path)
|
||||||
|
assert isinstance(audio_samples, np.ndarray)
|
||||||
|
assert audio_samples.shape[1] == 2
|
||||||
|
assert audio_samples.dtype == np.float32
|
||||||
|
assert len(audio_samples) == estimate_num_samples(16000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_audio_from_data():
|
||||||
|
audio_data = np.ones((16000, 2), dtype=np.float32)
|
||||||
|
audio_samples = sample_audio_from_data(audio_data)
|
||||||
|
assert isinstance(audio_samples, np.ndarray)
|
||||||
|
assert audio_samples.shape[1] == 2
|
||||||
|
assert audio_samples.dtype == np.float32
|
||||||
|
assert len(audio_samples) == estimate_num_samples(16000)
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_images():
|
def test_get_feature_stats_images():
|
||||||
data = np.random.rand(100, 3, 32, 32)
|
data = np.random.rand(100, 3, 32, 32)
|
||||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||||
@@ -81,6 +106,14 @@ def test_get_feature_stats_images():
|
|||||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_feature_stats_audio():
|
||||||
|
data = np.random.uniform(-1, 1, (16000, 2))
|
||||||
|
stats = get_feature_stats(data, axis=0, keepdims=True)
|
||||||
|
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
||||||
|
np.testing.assert_equal(stats["count"], np.array([16000]))
|
||||||
|
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||||
expected = {
|
expected = {
|
||||||
"min": np.array([[1, 2, 3]]),
|
"min": np.array([[1, 2, 3]]),
|
||||||
@@ -145,20 +178,27 @@ def test_get_feature_stats_single_value():
|
|||||||
def test_compute_episode_stats():
|
def test_compute_episode_stats():
|
||||||
episode_data = {
|
episode_data = {
|
||||||
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
|
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
|
||||||
|
OBS_AUDIO: "audio.wav",
|
||||||
OBS_STATE: np.random.rand(100, 10),
|
OBS_STATE: np.random.rand(100, 10),
|
||||||
}
|
}
|
||||||
features = {
|
features = {
|
||||||
OBS_IMAGE: {"dtype": "image"},
|
OBS_IMAGE: {"dtype": "image"},
|
||||||
|
OBS_AUDIO: {"dtype": "audio"},
|
||||||
OBS_STATE: {"dtype": "numeric"},
|
OBS_STATE: {"dtype": "numeric"},
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
with (
|
||||||
|
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
|
||||||
|
patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
|
||||||
|
):
|
||||||
stats = compute_episode_stats(episode_data, features)
|
stats = compute_episode_stats(episode_data, features)
|
||||||
|
|
||||||
assert OBS_IMAGE in stats and OBS_STATE in stats
|
assert OBS_IMAGE in stats and OBS_AUDIO in stats and OBS_STATE in stats
|
||||||
assert stats[OBS_IMAGE]["count"].item() == 100
|
assert stats[OBS_IMAGE]["count"].item() == estimate_num_samples(100)
|
||||||
assert stats[OBS_STATE]["count"].item() == 100
|
assert stats[OBS_AUDIO]["count"].item() == estimate_num_samples(16000)
|
||||||
|
assert stats[OBS_STATE]["count"].item() == estimate_num_samples(100)
|
||||||
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
||||||
|
assert stats[OBS_AUDIO]["mean"].shape == (1, 2)
|
||||||
|
|
||||||
|
|
||||||
def test_assert_type_and_shape_valid():
|
def test_assert_type_and_shape_valid():
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
from soundfile import write
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
@@ -35,6 +36,7 @@ from lerobot.datasets.io_utils import hf_transform_to_torch
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
|
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
@@ -45,7 +47,13 @@ from lerobot.envs.factory import make_env_config
|
|||||||
from lerobot.policies.factory import make_policy_config
|
from lerobot.policies.factory import make_policy_config
|
||||||
from lerobot.robots import make_robot_from_config
|
from lerobot.robots import make_robot_from_config
|
||||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
from tests.fixtures.constants import (
|
||||||
|
DEFAULT_SAMPLE_RATE,
|
||||||
|
DUMMY_AUDIO_CHANNELS,
|
||||||
|
DUMMY_CHW,
|
||||||
|
DUMMY_HWC,
|
||||||
|
DUMMY_REPO_ID,
|
||||||
|
)
|
||||||
from tests.mocks.mock_robot import MockRobotConfig
|
from tests.mocks.mock_robot import MockRobotConfig
|
||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|
||||||
@@ -66,6 +74,36 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_dataset_le_kiwi(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
features = {
|
||||||
|
"audio": {
|
||||||
|
"dtype": "audio",
|
||||||
|
"shape": (1, DUMMY_AUDIO_CHANNELS),
|
||||||
|
"names": [
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, robot_type="lekiwi")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
features = {
|
||||||
|
"audio": {
|
||||||
|
"dtype": "audio",
|
||||||
|
"shape": (1, DUMMY_AUDIO_CHANNELS),
|
||||||
|
"names": [
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
|
||||||
|
|
||||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||||
"""
|
"""
|
||||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||||
@@ -420,6 +458,78 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_frame_audio_array(audio_dataset_le_kiwi):
|
||||||
|
dataset = audio_dataset_le_kiwi
|
||||||
|
dataset.add_frame(
|
||||||
|
{
|
||||||
|
"audio": np.random.rand(
|
||||||
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
|
||||||
|
)
|
||||||
|
},
|
||||||
|
task="Dummy task",
|
||||||
|
)
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
assert dataset[0]["audio"].shape == torch.Size(
|
||||||
|
(
|
||||||
|
DUMMY_AUDIO_CHANNELS,
|
||||||
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_frame_audio_array_wrong_shape(audio_dataset_le_kiwi):
|
||||||
|
dataset = audio_dataset_le_kiwi
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
dataset.add_frame(
|
||||||
|
{
|
||||||
|
"audio": np.random.rand(
|
||||||
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS, 99
|
||||||
|
)
|
||||||
|
},
|
||||||
|
task="Dummy task",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_frame_audio_array_wrong_channels_number(audio_dataset_le_kiwi):
|
||||||
|
dataset = audio_dataset_le_kiwi
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
dataset.add_frame(
|
||||||
|
{"audio": np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), 99)},
|
||||||
|
task="Dummy task",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_frame_audio_file(audio_dataset):
|
||||||
|
dataset = audio_dataset
|
||||||
|
dataset.add_frame(
|
||||||
|
{
|
||||||
|
"audio": np.random.rand(
|
||||||
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
|
||||||
|
)
|
||||||
|
},
|
||||||
|
task="Dummy task",
|
||||||
|
)
|
||||||
|
# Create the audio file that should be created in the background by the Microphone class
|
||||||
|
for audio_key in dataset.meta.audio_keys:
|
||||||
|
fpath = dataset.writer._get_raw_audio_file_path(0, audio_key)
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
write(
|
||||||
|
fpath,
|
||||||
|
np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS),
|
||||||
|
DEFAULT_SAMPLE_RATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
assert dataset[0]["audio"].shape == torch.Size(
|
||||||
|
(
|
||||||
|
DUMMY_AUDIO_CHANNELS,
|
||||||
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts):
|
# TODO(aliberts):
|
||||||
# - [ ] test various attributes & state from init and create
|
# - [ ] test various attributes & state from init and create
|
||||||
# - [ ] test init with episodes and check num_frames
|
# - [ ] test init with episodes and check num_frames
|
||||||
@@ -459,6 +569,7 @@ def test_factory(env_name, repo_id, policy_name):
|
|||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
delta_timestamps = dataset.delta_timestamps
|
delta_timestamps = dataset.delta_timestamps
|
||||||
camera_keys = dataset.meta.camera_keys
|
camera_keys = dataset.meta.camera_keys
|
||||||
|
audio_keys = dataset.meta.audio_keys
|
||||||
|
|
||||||
item = dataset[0]
|
item = dataset[0]
|
||||||
|
|
||||||
@@ -501,6 +612,11 @@ def test_factory(env_name, repo_id, policy_name):
|
|||||||
# test c,h,w
|
# test c,h,w
|
||||||
assert item[key].shape[0] == 3, f"{key}"
|
assert item[key].shape[0] == 3, f"{key}"
|
||||||
|
|
||||||
|
for key in audio_keys:
|
||||||
|
assert item[key].dtype == torch.float32, f"{key}"
|
||||||
|
assert item[key].max() <= 1.0, f"{key}"
|
||||||
|
assert item[key].min() >= -1.0, f"{key}"
|
||||||
|
|
||||||
if delta_timestamps is not None:
|
if delta_timestamps is not None:
|
||||||
# test missing keys in delta_timestamps
|
# test missing keys in delta_timestamps
|
||||||
for key in delta_timestamps:
|
for key in delta_timestamps:
|
||||||
|
|||||||
Vendored
+13
@@ -40,5 +40,18 @@ DUMMY_VIDEO_INFO = {
|
|||||||
"video.is_depth_map": False,
|
"video.is_depth_map": False,
|
||||||
"has_audio": False,
|
"has_audio": False,
|
||||||
}
|
}
|
||||||
|
DUMMY_MICROPHONE_FEATURES = {
|
||||||
|
"laptop": {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None},
|
||||||
|
"phone": {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None},
|
||||||
|
}
|
||||||
|
DEFAULT_SAMPLE_RATE = 48000
|
||||||
|
DUMMY_AUDIO_CHANNELS = 2
|
||||||
|
DUMMY_AUDIO_INFO = {
|
||||||
|
"has_audio": True,
|
||||||
|
"audio.sample_rate": DEFAULT_SAMPLE_RATE,
|
||||||
|
"audio.codec": "aac",
|
||||||
|
"audio.channels": DUMMY_AUDIO_CHANNELS,
|
||||||
|
"audio.channel_layout": "stereo",
|
||||||
|
}
|
||||||
DUMMY_CHW = (3, 96, 128)
|
DUMMY_CHW = (3, 96, 128)
|
||||||
DUMMY_HWC = (96, 128, 3)
|
DUMMY_HWC = (96, 128, 3)
|
||||||
|
|||||||
Vendored
+18
-1
@@ -31,6 +31,7 @@ from lerobot.datasets.feature_utils import get_hf_features_from_features
|
|||||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
|
DEFAULT_AUDIO_PATH,
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
@@ -43,6 +44,7 @@ from lerobot.datasets.video_utils import encode_video_frames
|
|||||||
from tests.fixtures.constants import (
|
from tests.fixtures.constants import (
|
||||||
DEFAULT_FPS,
|
DEFAULT_FPS,
|
||||||
DUMMY_CAMERA_FEATURES,
|
DUMMY_CAMERA_FEATURES,
|
||||||
|
DUMMY_MICROPHONE_FEATURES,
|
||||||
DUMMY_MOTOR_FEATURES,
|
DUMMY_MOTOR_FEATURES,
|
||||||
DUMMY_REPO_ID,
|
DUMMY_REPO_ID,
|
||||||
DUMMY_ROBOT_TYPE,
|
DUMMY_ROBOT_TYPE,
|
||||||
@@ -131,6 +133,7 @@ def features_factory():
|
|||||||
def _create_features(
|
def _create_features(
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||||
|
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if use_videos:
|
if use_videos:
|
||||||
@@ -142,6 +145,7 @@ def features_factory():
|
|||||||
return {
|
return {
|
||||||
**motor_features,
|
**motor_features,
|
||||||
**camera_ft,
|
**camera_ft,
|
||||||
|
**audio_features,
|
||||||
**DEFAULT_FEATURES,
|
**DEFAULT_FEATURES,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,16 +162,19 @@ def info_factory(features_factory):
|
|||||||
total_frames: int = 0,
|
total_frames: int = 0,
|
||||||
total_tasks: int = 0,
|
total_tasks: int = 0,
|
||||||
total_videos: int = 0,
|
total_videos: int = 0,
|
||||||
|
total_audio: int = 0,
|
||||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
data_path: str = DEFAULT_DATA_PATH,
|
data_path: str = DEFAULT_DATA_PATH,
|
||||||
video_path: str = DEFAULT_VIDEO_PATH,
|
video_path: str = DEFAULT_VIDEO_PATH,
|
||||||
|
audio_path: str = DEFAULT_AUDIO_PATH,
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||||
|
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
features = features_factory(motor_features, camera_features, use_videos)
|
features = features_factory(motor_features, camera_features, audio_features, use_videos)
|
||||||
return {
|
return {
|
||||||
"codebase_version": codebase_version,
|
"codebase_version": codebase_version,
|
||||||
"robot_type": robot_type,
|
"robot_type": robot_type,
|
||||||
@@ -175,6 +182,7 @@ def info_factory(features_factory):
|
|||||||
"total_frames": total_frames,
|
"total_frames": total_frames,
|
||||||
"total_tasks": total_tasks,
|
"total_tasks": total_tasks,
|
||||||
"total_videos": total_videos,
|
"total_videos": total_videos,
|
||||||
|
"total_audio": total_audio,
|
||||||
"chunks_size": chunks_size,
|
"chunks_size": chunks_size,
|
||||||
"data_files_size_in_mb": data_files_size_in_mb,
|
"data_files_size_in_mb": data_files_size_in_mb,
|
||||||
"video_files_size_in_mb": video_files_size_in_mb,
|
"video_files_size_in_mb": video_files_size_in_mb,
|
||||||
@@ -182,6 +190,7 @@ def info_factory(features_factory):
|
|||||||
"splits": {},
|
"splits": {},
|
||||||
"data_path": data_path,
|
"data_path": data_path,
|
||||||
"video_path": video_path if use_videos else None,
|
"video_path": video_path if use_videos else None,
|
||||||
|
"audio_path": audio_path,
|
||||||
"features": features,
|
"features": features,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,6 +214,14 @@ def stats_factory():
|
|||||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
||||||
"count": [10],
|
"count": [10],
|
||||||
}
|
}
|
||||||
|
elif dtype == "audio":
|
||||||
|
stats[key] = {
|
||||||
|
"mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
|
||||||
|
"max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
|
||||||
|
"min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
|
||||||
|
"std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
|
||||||
|
"count": [10],
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
stats[key] = {
|
stats[key] = {
|
||||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
||||||
|
|||||||
@@ -0,0 +1,532 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from soundfile import read
|
||||||
|
|
||||||
|
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig
|
||||||
|
from lerobot.microphones.portaudio.interface_sounddevice_sdk import (
|
||||||
|
FakeSounddeviceSDKAdapter,
|
||||||
|
SounddeviceSDKAdapter,
|
||||||
|
)
|
||||||
|
from lerobot.microphones.portaudio.microphone_portaudio import PortAudioMicrophone
|
||||||
|
from lerobot.microphones.utils import async_microphones_start_recording, async_microphones_stop_recording
|
||||||
|
from lerobot.utils.errors import (
|
||||||
|
DeviceAlreadyConnectedError,
|
||||||
|
DeviceAlreadyRecordingError,
|
||||||
|
DeviceNotConnectedError,
|
||||||
|
DeviceNotRecordingError,
|
||||||
|
)
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
|
||||||
|
MODULE_PATH = "lerobot.microphones.portaudio.microphone_portaudio"
|
||||||
|
RECORDING_DURATION = 1.0
|
||||||
|
|
||||||
|
LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS = (
|
||||||
|
os.getenv("LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS", "False").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_sdk():
|
||||||
|
"""Fixture to provide either real or fake SDK based on environment variable."""
|
||||||
|
if LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS:
|
||||||
|
return SounddeviceSDKAdapter()
|
||||||
|
else:
|
||||||
|
return FakeSounddeviceSDKAdapter()
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration Tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_creation():
|
||||||
|
"""Test creating a valid configuration."""
|
||||||
|
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=48000, channels=[1, 2])
|
||||||
|
assert config.microphone_index == 0
|
||||||
|
assert config.sample_rate == 48000
|
||||||
|
assert config.channels == [1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_creation_missing_microphone_index():
|
||||||
|
"""Test creating a configuration with missing microphone index."""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
PortAudioMicrophoneConfig(sample_rate=48000, channels=[1, 2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_creation_missing_sample_rate():
|
||||||
|
"""Test creating a configuration with missing sample rate."""
|
||||||
|
config = PortAudioMicrophoneConfig(microphone_index=0, channels=[1, 2])
|
||||||
|
assert config.sample_rate is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_creation_missing_channels():
|
||||||
|
"""Test creating a configuration with missing channels."""
|
||||||
|
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=48000)
|
||||||
|
assert config.channels is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def default_config(test_sdk):
|
||||||
|
"""Fixture to provide a default configuration for input devices."""
|
||||||
|
device_info = test_sdk.query_devices(kind="input")
|
||||||
|
return PortAudioMicrophoneConfig(
|
||||||
|
microphone_index=device_info["index"],
|
||||||
|
sample_rate=device_info["default_samplerate"],
|
||||||
|
channels=np.arange(device_info["max_input_channels"]) + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Microphone Tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_microphones(test_sdk):
|
||||||
|
"""Test finding microphones."""
|
||||||
|
microphones = PortAudioMicrophone.find_microphones(sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
for microphone in microphones:
|
||||||
|
assert isinstance(microphone["index"], int)
|
||||||
|
assert isinstance(microphone["name"], str)
|
||||||
|
assert isinstance(microphone["sample_rate"], int)
|
||||||
|
assert isinstance(microphone["channels"], np.ndarray)
|
||||||
|
assert len(microphone["channels"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_defaults(default_config, test_sdk):
|
||||||
|
"""Test microphone initialization with defaults."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
device_info = test_sdk.query_devices(kind="input")
|
||||||
|
assert microphone is not None
|
||||||
|
assert microphone.microphone_index == device_info["index"]
|
||||||
|
assert microphone.sample_rate == device_info["default_samplerate"]
|
||||||
|
np.testing.assert_array_equal(microphone.channels, np.arange(device_info["max_input_channels"]) + 1)
|
||||||
|
assert not microphone.is_connected
|
||||||
|
assert not microphone.is_recording
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_success(default_config, test_sdk):
|
||||||
|
"""Test successful connection."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert not microphone.is_recording
|
||||||
|
assert not microphone.is_writing
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_empty_config(default_config, test_sdk):
|
||||||
|
"""Test connection with empty config values."""
|
||||||
|
config = deepcopy(default_config)
|
||||||
|
config.sample_rate = None
|
||||||
|
config.channels = None
|
||||||
|
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
device_info = test_sdk.query_devices(kind="input")
|
||||||
|
assert microphone.sample_rate == device_info["default_samplerate"]
|
||||||
|
np.testing.assert_array_equal(microphone.channels, np.arange(device_info["max_input_channels"]) + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_already_connected(default_config, test_sdk):
|
||||||
|
"""Test connecting when already connected."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
with pytest.raises(DeviceAlreadyConnectedError):
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_invalid_device(test_sdk):
|
||||||
|
"""Test connecting with invalid device (output device)."""
|
||||||
|
device_info = test_sdk.query_devices(kind="output")
|
||||||
|
config = PortAudioMicrophoneConfig(
|
||||||
|
microphone_index=device_info["index"],
|
||||||
|
sample_rate=device_info["default_samplerate"],
|
||||||
|
channels=np.arange(device_info["max_input_channels"]) + 1,
|
||||||
|
)
|
||||||
|
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_invalid_index(default_config, test_sdk):
|
||||||
|
"""Test connecting with invalid device index."""
|
||||||
|
config = deepcopy(default_config)
|
||||||
|
config.microphone_index = -1
|
||||||
|
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_invalid_sample_rate(default_config, test_sdk):
|
||||||
|
"""Test connecting with invalid sample rate."""
|
||||||
|
config = deepcopy(default_config)
|
||||||
|
config.sample_rate = -1
|
||||||
|
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_float_sample_rate(default_config, test_sdk):
|
||||||
|
"""Test connecting with float sample rate."""
|
||||||
|
config = deepcopy(default_config)
|
||||||
|
config.sample_rate = int(config.sample_rate) - 0.5
|
||||||
|
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
assert isinstance(microphone.sample_rate, int)
|
||||||
|
assert microphone.sample_rate == int(config.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_lower_sample_rate(default_config, test_sdk):
|
||||||
|
"""Test connecting with lower sample rate."""
|
||||||
|
config = deepcopy(default_config)
|
||||||
|
config.sample_rate = 1000 # Lowest possible sample rate
|
||||||
|
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
microphone.connect()
|
||||||
|
assert microphone.sample_rate == 1000
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_invalid_channels(default_config, test_sdk):
|
||||||
|
"""Test connecting with invalid channels."""
|
||||||
|
config = deepcopy(default_config)
|
||||||
|
config.channels = np.append(default_config.channels, -1)
|
||||||
|
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_success(default_config, test_sdk):
|
||||||
|
"""Test successful disconnection."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.disconnect()
|
||||||
|
|
||||||
|
assert not microphone.is_connected
|
||||||
|
assert not microphone.is_recording
|
||||||
|
assert not microphone.is_writing
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_not_connected(default_config, test_sdk):
|
||||||
|
"""Test disconnecting when not connected."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
with pytest.raises(DeviceNotConnectedError):
|
||||||
|
microphone.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_start_recording_success(default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test successful recording start."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(multiprocessing=multiprocessing)
|
||||||
|
|
||||||
|
assert microphone.is_recording
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert not microphone.is_writing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_recording_not_connected(default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test starting recording when not connected."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
with pytest.raises(DeviceNotConnectedError):
|
||||||
|
microphone.start_recording(multiprocessing=multiprocessing)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_start_recording_already_recording(default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test starting recording when already recording."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(multiprocessing=multiprocessing)
|
||||||
|
|
||||||
|
with pytest.raises(DeviceAlreadyRecordingError):
|
||||||
|
microphone.start_recording(multiprocessing=multiprocessing)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_start_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test successful writing start."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||||
|
|
||||||
|
assert microphone.is_recording
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert microphone.is_writing
|
||||||
|
assert (tmp_path / "test.wav").exists()
|
||||||
|
|
||||||
|
(tmp_path / "test.wav").unlink()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_start_writing_file_already_exists_no_overwrite(tmp_path, default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test writing with file that already exists."""
|
||||||
|
(tmp_path / "test.wav").touch()
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
with pytest.raises(FileExistsError):
|
||||||
|
microphone.start_recording(
|
||||||
|
output_file=tmp_path / "test.wav", multiprocessing=multiprocessing, overwrite=False
|
||||||
|
)
|
||||||
|
|
||||||
|
(tmp_path / "test.wav").unlink()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_stop_recording_success(default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test successful recording stop."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(multiprocessing=multiprocessing)
|
||||||
|
precise_sleep(RECORDING_DURATION)
|
||||||
|
microphone.stop_recording()
|
||||||
|
|
||||||
|
assert not microphone.is_recording
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert not microphone.is_writing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_stop_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test successful writing stop."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||||
|
precise_sleep(RECORDING_DURATION)
|
||||||
|
microphone.stop_recording()
|
||||||
|
|
||||||
|
assert not microphone.is_recording
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert not microphone.is_writing
|
||||||
|
assert (tmp_path / "test.wav").exists()
|
||||||
|
|
||||||
|
(tmp_path / "test.wav").unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_recording_not_connected(default_config, test_sdk):
|
||||||
|
"""Test stopping recording when not connected."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
|
||||||
|
with pytest.raises(DeviceNotConnectedError):
|
||||||
|
microphone.stop_recording()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_recording_not_recording(default_config, test_sdk):
|
||||||
|
"""Test stopping recording when not recording."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
with pytest.raises(DeviceNotRecordingError):
|
||||||
|
microphone.stop_recording()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_disconnect_while_recording(default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test disconnecting while recording."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(multiprocessing=multiprocessing)
|
||||||
|
precise_sleep(RECORDING_DURATION)
|
||||||
|
microphone.disconnect()
|
||||||
|
|
||||||
|
assert not microphone.is_connected
|
||||||
|
assert not microphone.is_recording
|
||||||
|
assert not microphone.is_writing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_disconnect_while_writing(tmp_path, default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test disconnecting while writing."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||||
|
precise_sleep(RECORDING_DURATION)
|
||||||
|
microphone.disconnect()
|
||||||
|
|
||||||
|
assert not microphone.is_connected
|
||||||
|
assert not microphone.is_recording
|
||||||
|
assert not microphone.is_writing
|
||||||
|
assert Path(tmp_path / "test.wav").exists()
|
||||||
|
|
||||||
|
(tmp_path / "test.wav").unlink()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_read_success(default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test successful reading of audio data."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(multiprocessing=multiprocessing)
|
||||||
|
|
||||||
|
precise_sleep(RECORDING_DURATION)
|
||||||
|
|
||||||
|
data = microphone.read()
|
||||||
|
|
||||||
|
device_info = test_sdk.query_devices(kind="input")
|
||||||
|
assert data is not None
|
||||||
|
assert data.shape[1] == len(default_config.channels)
|
||||||
|
assert (
|
||||||
|
abs(data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
|
||||||
|
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test successful writing to file."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||||
|
|
||||||
|
precise_sleep(RECORDING_DURATION)
|
||||||
|
|
||||||
|
microphone.stop_recording()
|
||||||
|
|
||||||
|
data, samplerate = read(tmp_path / "test.wav")
|
||||||
|
|
||||||
|
device_info = test_sdk.query_devices(kind="input")
|
||||||
|
assert samplerate == default_config.sample_rate
|
||||||
|
assert data.shape[1] == len(default_config.channels)
|
||||||
|
assert (
|
||||||
|
abs(data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
|
||||||
|
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
|
||||||
|
)
|
||||||
|
|
||||||
|
(tmp_path / "test.wav").unlink()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||||
|
def test_read_while_writing(tmp_path, default_config, test_sdk, multiprocessing):
|
||||||
|
"""Test reading while writing."""
|
||||||
|
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||||
|
|
||||||
|
precise_sleep(RECORDING_DURATION)
|
||||||
|
|
||||||
|
read_data = microphone.read()
|
||||||
|
microphone.stop_recording()
|
||||||
|
|
||||||
|
writing_data, _ = read(tmp_path / "test.wav")
|
||||||
|
|
||||||
|
device_info = test_sdk.query_devices(kind="input")
|
||||||
|
assert (
|
||||||
|
abs(writing_data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
|
||||||
|
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
abs(read_data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
|
||||||
|
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
|
||||||
|
)
|
||||||
|
|
||||||
|
(tmp_path / "test.wav").unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_start_recording(default_config, test_sdk):
|
||||||
|
"""Test async recording start."""
|
||||||
|
microphones = {
|
||||||
|
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||||
|
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||||
|
}
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
async_microphones_start_recording(microphones)
|
||||||
|
|
||||||
|
for microphone in microphones.values():
|
||||||
|
assert microphone.is_recording
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert not microphone.is_writing
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_start_writing(tmp_path, default_config, test_sdk):
|
||||||
|
"""Test async writing start."""
|
||||||
|
microphones = {
|
||||||
|
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||||
|
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||||
|
}
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
async_microphones_start_recording(
|
||||||
|
microphones, output_files=[tmp_path / "test_1.wav", tmp_path / "test_2.wav"]
|
||||||
|
)
|
||||||
|
|
||||||
|
for microphone in microphones.values():
|
||||||
|
assert microphone.is_recording
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert microphone.is_writing
|
||||||
|
assert Path(tmp_path / "test_1.wav").exists()
|
||||||
|
assert Path(tmp_path / "test_2.wav").exists()
|
||||||
|
|
||||||
|
(tmp_path / "test_1.wav").unlink()
|
||||||
|
(tmp_path / "test_2.wav").unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_stop_recording(default_config, test_sdk):
|
||||||
|
"""Test async recording stop."""
|
||||||
|
microphones = {
|
||||||
|
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||||
|
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||||
|
}
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
async_microphones_start_recording(microphones)
|
||||||
|
async_microphones_stop_recording(microphones)
|
||||||
|
|
||||||
|
for microphone in microphones.values():
|
||||||
|
assert not microphone.is_recording
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert not microphone.is_writing
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_stop_writing(tmp_path, default_config, test_sdk):
|
||||||
|
"""Test async writing stop."""
|
||||||
|
microphones = {
|
||||||
|
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||||
|
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||||
|
}
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
async_microphones_start_recording(
|
||||||
|
microphones, output_files=[tmp_path / "test_1.wav", tmp_path / "test_2.wav"]
|
||||||
|
)
|
||||||
|
async_microphones_stop_recording(microphones)
|
||||||
|
|
||||||
|
for microphone in microphones.values():
|
||||||
|
assert not microphone.is_recording
|
||||||
|
assert microphone.is_connected
|
||||||
|
assert not microphone.is_writing
|
||||||
|
assert Path(tmp_path / "test_1.wav").exists()
|
||||||
|
assert Path(tmp_path / "test_2.wav").exists()
|
||||||
|
|
||||||
|
(tmp_path / "test_1.wav").unlink()
|
||||||
|
(tmp_path / "test_2.wav").unlink()
|
||||||
@@ -0,0 +1,508 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
import time
|
||||||
|
from multiprocessing import Event, Process, Queue
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.utils.shared_array import SharedArray
|
||||||
|
|
||||||
|
|
||||||
|
def writer_process(shared_array, data_queue, stop_event, barrier, process_id):
|
||||||
|
"""Writer process that continuously writes data to shared array."""
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
# Wait for all processes to be ready
|
||||||
|
barrier.wait()
|
||||||
|
|
||||||
|
write_count = 0
|
||||||
|
while not stop_event.is_set() and write_count < 10:
|
||||||
|
# Generate unique data for this process and write iteration
|
||||||
|
data = np.full((5, 2), process_id * 100 + write_count, dtype=np.float32)
|
||||||
|
|
||||||
|
try:
|
||||||
|
shared_array.write(local_array, data)
|
||||||
|
data_queue.put(f"writer_{process_id}_wrote_{write_count}")
|
||||||
|
write_count += 1
|
||||||
|
time.sleep(0.01) # Small delay to allow race conditions
|
||||||
|
except IndexError:
|
||||||
|
# Array is full, stop writing
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def reader_process(shared_array, data_queue, stop_event, barrier, process_id):
|
||||||
|
"""Reader process that continuously reads data from shared array."""
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
# Wait for all processes to be ready
|
||||||
|
barrier.wait()
|
||||||
|
|
||||||
|
read_count = 0
|
||||||
|
while not stop_event.is_set() and read_count < 5:
|
||||||
|
time.sleep(0.02) # Allow some writes to accumulate
|
||||||
|
|
||||||
|
data = shared_array.read(local_array, flush=True)
|
||||||
|
data_queue.put(f"reader_{process_id}_read_{len(data)}_items")
|
||||||
|
read_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
def stress_writer_process(shared_array, data_queue, stop_event, barrier, process_id):
|
||||||
|
"""High-frequency writer process for stress testing."""
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
barrier.wait()
|
||||||
|
|
||||||
|
write_count = 0
|
||||||
|
while not stop_event.is_set() and write_count < 50:
|
||||||
|
# Write single row at a time for more frequent operations
|
||||||
|
data = np.array([[process_id, write_count]], dtype=np.float32)
|
||||||
|
|
||||||
|
try:
|
||||||
|
shared_array.write(local_array, data)
|
||||||
|
write_count += 1
|
||||||
|
# No sleep - stress test
|
||||||
|
except IndexError:
|
||||||
|
break
|
||||||
|
|
||||||
|
data_queue.put(f"stress_writer_{process_id}_completed_{write_count}")
|
||||||
|
|
||||||
|
|
||||||
|
# Basic functionality tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_shared_array_creation():
|
||||||
|
"""Test basic SharedArray creation and properties."""
|
||||||
|
shape = (100, 4)
|
||||||
|
dtype = np.float32
|
||||||
|
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=dtype)
|
||||||
|
|
||||||
|
assert shared_array.shape == shape
|
||||||
|
assert shared_array.dtype == dtype
|
||||||
|
assert shared_array.read_index.value == 0
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_array_access():
|
||||||
|
"""Test getting local array instances."""
|
||||||
|
shape = (50, 2)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
assert local_array.shape == shape
|
||||||
|
assert local_array.dtype == np.float32
|
||||||
|
assert isinstance(local_array, np.ndarray)
|
||||||
|
|
||||||
|
# Test that we can get multiple local array instances
|
||||||
|
local_array2 = shared_array.get_local_array()
|
||||||
|
assert local_array2.shape == shape
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_and_read_single_process():
|
||||||
|
"""Test basic write and read operations in single process."""
|
||||||
|
shape = (20, 3)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
# Write some data
|
||||||
|
data1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
|
||||||
|
shared_array.write(local_array, data1)
|
||||||
|
|
||||||
|
assert shared_array.read_index.value == 2
|
||||||
|
|
||||||
|
# Write more data
|
||||||
|
data2 = np.array([[7, 8, 9]], dtype=np.float32)
|
||||||
|
shared_array.write(local_array, data2)
|
||||||
|
|
||||||
|
assert shared_array.read_index.value == 3
|
||||||
|
|
||||||
|
# Read all data
|
||||||
|
read_data = shared_array.read(local_array, flush=False)
|
||||||
|
expected = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
|
||||||
|
np.testing.assert_array_equal(read_data, expected)
|
||||||
|
|
||||||
|
# Read with flush
|
||||||
|
read_data_flush = shared_array.read(local_array, flush=True)
|
||||||
|
np.testing.assert_array_equal(read_data_flush, expected)
|
||||||
|
assert shared_array.read_index.value == 0
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_array_overflow():
|
||||||
|
"""Test behavior when writing more data than array capacity."""
|
||||||
|
shape = (5, 2) # Small array
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
# Fill the array
|
||||||
|
data = np.ones((5, 2), dtype=np.float32)
|
||||||
|
shared_array.write(local_array, data)
|
||||||
|
|
||||||
|
# Try to write more data - should raise IndexError
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
extra_data = np.ones((2, 2), dtype=np.float32)
|
||||||
|
shared_array.write(local_array, extra_data)
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_functionality():
|
||||||
|
"""Test the reset method."""
|
||||||
|
shape = (10, 2)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
# Write some data
|
||||||
|
data = np.ones((3, 2), dtype=np.float32)
|
||||||
|
shared_array.write(local_array, data)
|
||||||
|
assert shared_array.read_index.value == 3
|
||||||
|
|
||||||
|
# Reset
|
||||||
|
shared_array.reset()
|
||||||
|
assert shared_array.read_index.value == 0
|
||||||
|
|
||||||
|
# Read should return empty array
|
||||||
|
read_data = shared_array.read(local_array, flush=False)
|
||||||
|
assert len(read_data) == 0
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
# Multi-process tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_writer_single_reader():
|
||||||
|
"""Test basic writer-reader scenario with one process each."""
|
||||||
|
shape = (100, 2)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
|
||||||
|
data_queue = Queue()
|
||||||
|
stop_event = Event()
|
||||||
|
barrier = multiprocessing.Barrier(2) # Writer + reader
|
||||||
|
|
||||||
|
# Start writer process
|
||||||
|
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||||
|
|
||||||
|
# Start reader process
|
||||||
|
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||||
|
|
||||||
|
writer.start()
|
||||||
|
reader.start()
|
||||||
|
|
||||||
|
# Let them run for a bit
|
||||||
|
time.sleep(0.5)
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
writer.join(timeout=2.0)
|
||||||
|
reader.join(timeout=2.0)
|
||||||
|
|
||||||
|
# Verify both processes completed
|
||||||
|
assert not writer.is_alive()
|
||||||
|
assert not reader.is_alive()
|
||||||
|
|
||||||
|
# Check that we got messages from both processes
|
||||||
|
messages = []
|
||||||
|
while not data_queue.empty():
|
||||||
|
messages.append(data_queue.get())
|
||||||
|
|
||||||
|
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
|
||||||
|
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
|
||||||
|
|
||||||
|
assert len(writer_messages) > 0
|
||||||
|
assert len(reader_messages) > 0
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_writers_single_reader():
|
||||||
|
"""Test multiple writers with single reader - check for race conditions."""
|
||||||
|
shape = (200, 2)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
|
||||||
|
data_queue = Queue()
|
||||||
|
stop_event = Event()
|
||||||
|
num_writers = 3
|
||||||
|
barrier = multiprocessing.Barrier(num_writers + 1) # Writers + reader
|
||||||
|
|
||||||
|
processes = []
|
||||||
|
|
||||||
|
# Start multiple writer processes
|
||||||
|
for i in range(num_writers):
|
||||||
|
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
|
||||||
|
processes.append(writer)
|
||||||
|
writer.start()
|
||||||
|
|
||||||
|
# Start reader process
|
||||||
|
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||||
|
processes.append(reader)
|
||||||
|
reader.start()
|
||||||
|
|
||||||
|
# Let them run
|
||||||
|
time.sleep(1.0)
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
# Wait for all processes
|
||||||
|
for process in processes:
|
||||||
|
process.join(timeout=3.0)
|
||||||
|
assert not process.is_alive()
|
||||||
|
|
||||||
|
# Verify we got messages from all processes
|
||||||
|
messages = []
|
||||||
|
while not data_queue.empty():
|
||||||
|
messages.append(data_queue.get())
|
||||||
|
|
||||||
|
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
|
||||||
|
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
|
||||||
|
|
||||||
|
# Should have messages from all writers
|
||||||
|
assert len(writer_messages) >= num_writers
|
||||||
|
assert len(reader_messages) > 0
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_integrity_with_concurrent_access():
|
||||||
|
"""Test that data integrity is maintained under concurrent access using standard reader/writer processes."""
|
||||||
|
shape = (500, 2) # Use standard 2-column format
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
|
||||||
|
data_queue = Queue()
|
||||||
|
stop_event = Event()
|
||||||
|
barrier = multiprocessing.Barrier(3) # 2 writers + 1 reader
|
||||||
|
|
||||||
|
# Start two writer processes
|
||||||
|
writer1 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||||
|
writer2 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 2))
|
||||||
|
|
||||||
|
# Start one reader process
|
||||||
|
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||||
|
|
||||||
|
writer1.start()
|
||||||
|
writer2.start()
|
||||||
|
reader.start()
|
||||||
|
|
||||||
|
# Let them run for integrity test duration
|
||||||
|
time.sleep(1.0)
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
writer1.join(timeout=3.0)
|
||||||
|
writer2.join(timeout=3.0)
|
||||||
|
reader.join(timeout=3.0)
|
||||||
|
|
||||||
|
# Verify all processes completed successfully
|
||||||
|
assert not writer1.is_alive()
|
||||||
|
assert not writer2.is_alive()
|
||||||
|
assert not reader.is_alive()
|
||||||
|
|
||||||
|
# Verify data integrity by checking messages
|
||||||
|
messages = []
|
||||||
|
while not data_queue.empty():
|
||||||
|
messages.append(data_queue.get())
|
||||||
|
|
||||||
|
writer1_messages = [msg for msg in messages if "writer_1_wrote" in msg]
|
||||||
|
writer2_messages = [msg for msg in messages if "writer_2_wrote" in msg]
|
||||||
|
reader_messages = [msg for msg in messages if "reader_1_read" in msg]
|
||||||
|
|
||||||
|
# Verify both writers wrote data
|
||||||
|
assert len(writer1_messages) > 0
|
||||||
|
assert len(writer2_messages) > 0
|
||||||
|
# Verify reader read data
|
||||||
|
assert len(reader_messages) > 0
|
||||||
|
|
||||||
|
# Verify the shared array is in a consistent state
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
final_data = shared_array.read(local_array, flush=False)
|
||||||
|
|
||||||
|
# Should have some data written by the writers
|
||||||
|
assert len(final_data) >= 0 # Could be empty if reader flushed everything
|
||||||
|
# Should not exceed array capacity
|
||||||
|
assert len(final_data) <= shape[0]
|
||||||
|
|
||||||
|
# If there's data, verify it contains the expected writer signatures
|
||||||
|
if len(final_data) > 0:
|
||||||
|
# Data should contain values like 100, 101, 102... (writer 1) or 200, 201, 202... (writer 2)
|
||||||
|
unique_values = np.unique(final_data.flatten())
|
||||||
|
writer1_values = unique_values[(unique_values >= 100) & (unique_values < 200)]
|
||||||
|
writer2_values = unique_values[(unique_values >= 200) & (unique_values < 300)]
|
||||||
|
|
||||||
|
# Should have data from at least one writer
|
||||||
|
assert len(writer1_values) > 0 or len(writer2_values) > 0
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stress_test_high_frequency_operations():
|
||||||
|
"""Stress test with high frequency read/write operations."""
|
||||||
|
shape = (1000, 2)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
|
||||||
|
data_queue = Queue()
|
||||||
|
stop_event = Event()
|
||||||
|
num_writers = 4
|
||||||
|
barrier = multiprocessing.Barrier(num_writers)
|
||||||
|
|
||||||
|
processes = []
|
||||||
|
|
||||||
|
# Start multiple high-frequency writers
|
||||||
|
for i in range(num_writers):
|
||||||
|
writer = Process(
|
||||||
|
target=stress_writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
|
||||||
|
)
|
||||||
|
processes.append(writer)
|
||||||
|
writer.start()
|
||||||
|
|
||||||
|
# Let them run for stress test duration
|
||||||
|
time.sleep(0.5)
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
for process in processes:
|
||||||
|
process.join(timeout=3.0)
|
||||||
|
assert not process.is_alive()
|
||||||
|
|
||||||
|
# Verify all writers completed successfully
|
||||||
|
messages = []
|
||||||
|
while not data_queue.empty():
|
||||||
|
messages.append(data_queue.get())
|
||||||
|
|
||||||
|
completed_messages = [msg for msg in messages if "completed" in msg]
|
||||||
|
assert len(completed_messages) == num_writers
|
||||||
|
|
||||||
|
# Verify the shared array is in a consistent state
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
final_data = shared_array.read(local_array, flush=False)
|
||||||
|
|
||||||
|
# Should have some data written
|
||||||
|
assert len(final_data) > 0
|
||||||
|
# Should not exceed array capacity
|
||||||
|
assert len(final_data) <= shape[0]
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_concurrent_readers():
|
||||||
|
"""Test multiple concurrent readers with writers to ensure thread safety."""
|
||||||
|
shape = (200, 2)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
|
||||||
|
data_queue = Queue()
|
||||||
|
stop_event = Event()
|
||||||
|
num_readers = 3
|
||||||
|
num_writers = 2
|
||||||
|
barrier = multiprocessing.Barrier(num_readers + num_writers)
|
||||||
|
|
||||||
|
processes = []
|
||||||
|
|
||||||
|
# Start multiple writer processes to generate data
|
||||||
|
for i in range(num_writers):
|
||||||
|
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
|
||||||
|
processes.append(writer)
|
||||||
|
writer.start()
|
||||||
|
|
||||||
|
# Start multiple reader processes
|
||||||
|
for i in range(num_readers):
|
||||||
|
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
|
||||||
|
processes.append(reader)
|
||||||
|
reader.start()
|
||||||
|
|
||||||
|
# Let them run to test concurrent access
|
||||||
|
time.sleep(1.0)
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
# Wait for all processes to complete
|
||||||
|
for process in processes:
|
||||||
|
process.join(timeout=3.0)
|
||||||
|
assert not process.is_alive()
|
||||||
|
|
||||||
|
# Verify all readers and writers completed
|
||||||
|
messages = []
|
||||||
|
while not data_queue.empty():
|
||||||
|
messages.append(data_queue.get())
|
||||||
|
|
||||||
|
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
|
||||||
|
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
|
||||||
|
|
||||||
|
# Should have messages from all readers and writers
|
||||||
|
assert len(reader_messages) >= num_readers
|
||||||
|
assert len(writer_messages) >= num_writers
|
||||||
|
|
||||||
|
# Verify different readers generated different messages (proving they ran concurrently)
|
||||||
|
reader_ids = set()
|
||||||
|
for msg in reader_messages:
|
||||||
|
# Extract reader ID from message like "reader_1_read_5_items"
|
||||||
|
parts = msg.split("_")
|
||||||
|
if len(parts) >= 2:
|
||||||
|
reader_ids.add(parts[1])
|
||||||
|
|
||||||
|
assert len(reader_ids) == num_readers # All readers should have participated
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_edge_case_empty_reads():
|
||||||
|
"""Test reading from empty array and after flushes."""
|
||||||
|
shape = (10, 2)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
# Read from empty array
|
||||||
|
empty_data = shared_array.read(local_array, flush=False)
|
||||||
|
assert len(empty_data) == 0
|
||||||
|
|
||||||
|
# Write some data
|
||||||
|
data = np.ones((3, 2), dtype=np.float32)
|
||||||
|
shared_array.write(local_array, data)
|
||||||
|
|
||||||
|
# Read with flush
|
||||||
|
read_data = shared_array.read(local_array, flush=True)
|
||||||
|
assert len(read_data) == 3
|
||||||
|
|
||||||
|
# Read again after flush - should be empty
|
||||||
|
empty_again = shared_array.read(local_array, flush=False)
|
||||||
|
assert len(empty_again) == 0
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_dtypes():
|
||||||
|
"""Test SharedArray with different numpy dtypes."""
|
||||||
|
dtypes_to_test = [np.float32, np.float64, np.int32, np.int16]
|
||||||
|
|
||||||
|
for dtype in dtypes_to_test:
|
||||||
|
shape = (20, 2)
|
||||||
|
shared_array = SharedArray(shape=shape, dtype=dtype)
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
|
||||||
|
assert local_array.dtype == dtype
|
||||||
|
|
||||||
|
# Write and read data of this dtype
|
||||||
|
data = np.ones((5, 2), dtype=dtype)
|
||||||
|
shared_array.write(local_array, data)
|
||||||
|
|
||||||
|
read_data = shared_array.read(local_array, flush=True)
|
||||||
|
assert read_data.dtype == dtype
|
||||||
|
assert len(read_data) == 5
|
||||||
|
|
||||||
|
shared_array.delete()
|
||||||
+8
-1
@@ -20,7 +20,7 @@ from functools import wraps
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot import available_cameras, available_motors, available_robots
|
from lerobot import available_cameras, available_microphones, available_motors, available_robots
|
||||||
from lerobot.utils.device_utils import auto_select_torch_device
|
from lerobot.utils.device_utils import auto_select_torch_device
|
||||||
from lerobot.utils.import_utils import is_package_available
|
from lerobot.utils.import_utils import is_package_available
|
||||||
|
|
||||||
@@ -34,6 +34,10 @@ TEST_CAMERA_TYPES = []
|
|||||||
for camera_type in available_cameras:
|
for camera_type in available_cameras:
|
||||||
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
|
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
|
||||||
|
|
||||||
|
TEST_MICROPHONE_TYPES = []
|
||||||
|
for microphone_type in available_microphones:
|
||||||
|
TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)]
|
||||||
|
|
||||||
TEST_MOTOR_TYPES = []
|
TEST_MOTOR_TYPES = []
|
||||||
for motor_type in available_motors:
|
for motor_type in available_motors:
|
||||||
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
|
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
|
||||||
@@ -42,6 +46,9 @@ for motor_type in available_motors:
|
|||||||
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
||||||
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
|
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
|
||||||
|
|
||||||
|
# Microphone indices used for connecting physical microphones
|
||||||
|
MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0))
|
||||||
|
|
||||||
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
|
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
|
||||||
DYNAMIXEL_MOTORS = {
|
DYNAMIXEL_MOTORS = {
|
||||||
"shoulder_pan": [1, "xl430-w250"],
|
"shoulder_pan": [1, "xl430-w250"],
|
||||||
|
|||||||
@@ -37,6 +37,14 @@ def mock_rerun(monkeypatch):
|
|||||||
def __init__(self, value):
|
def __init__(self, value):
|
||||||
self.value = float(value)
|
self.value = float(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def columns(scalars):
|
||||||
|
return DummyScalarsColumn(scalars)
|
||||||
|
|
||||||
|
class DummyScalarsColumn:
|
||||||
|
def __init__(self, values):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
class DummyImage:
|
class DummyImage:
|
||||||
def __init__(self, arr):
|
def __init__(self, arr):
|
||||||
self.arr = arr
|
self.arr = arr
|
||||||
@@ -47,12 +55,46 @@ def mock_rerun(monkeypatch):
|
|||||||
obj = kwargs.pop("entity")
|
obj = kwargs.pop("entity")
|
||||||
calls.append((key, obj, kwargs))
|
calls.append((key, obj, kwargs))
|
||||||
|
|
||||||
|
def dummy_send_columns(key, indexes, columns, **kwargs):
|
||||||
|
calls.append((key, columns, kwargs))
|
||||||
|
|
||||||
|
def dummy_time_column(timeline, timestamp):
|
||||||
|
return timestamp
|
||||||
|
|
||||||
|
def dummy_set_time(timeline, timestamp):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummyTimeSeriesView:
|
||||||
|
def __call__(self, origin, plot_legend=None):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummySpatial2DView:
|
||||||
|
def __call__(self, origin):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummyGrid:
|
||||||
|
def __call__(self, *args):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummyPlotLegend:
|
||||||
|
def __call__(self, visible=True):
|
||||||
|
return None
|
||||||
|
|
||||||
dummy_rr = SimpleNamespace(
|
dummy_rr = SimpleNamespace(
|
||||||
Scalars=DummyScalar,
|
Scalars=DummyScalar,
|
||||||
Image=DummyImage,
|
Image=DummyImage,
|
||||||
log=dummy_log,
|
log=dummy_log,
|
||||||
|
TimeColumn=dummy_time_column,
|
||||||
|
send_columns=dummy_send_columns,
|
||||||
|
set_time=dummy_set_time,
|
||||||
init=lambda *a, **k: None,
|
init=lambda *a, **k: None,
|
||||||
spawn=lambda *a, **k: None,
|
spawn=lambda *a, **k: None,
|
||||||
|
blueprint=SimpleNamespace(
|
||||||
|
TimeSeriesView=DummyTimeSeriesView,
|
||||||
|
Spatial2DView=DummySpatial2DView,
|
||||||
|
Grid=DummyGrid,
|
||||||
|
PlotLegend=DummyPlotLegend,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Inject fake module into sys.modules
|
# Inject fake module into sys.modules
|
||||||
@@ -87,7 +129,7 @@ def _kwargs_for(calls, key):
|
|||||||
raise KeyError(f"Key {key} not found in calls: {calls}")
|
raise KeyError(f"Key {key} not found in calls: {calls}")
|
||||||
|
|
||||||
|
|
||||||
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun):
|
||||||
vu, calls = mock_rerun
|
vu, calls = mock_rerun
|
||||||
|
|
||||||
# Build EnvTransition dict
|
# Build EnvTransition dict
|
||||||
@@ -95,6 +137,8 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
|||||||
f"{OBS_STATE}.temperature": np.float32(25.0),
|
f"{OBS_STATE}.temperature": np.float32(25.0),
|
||||||
# CHW image should be converted to HWC for rr.Image
|
# CHW image should be converted to HWC for rr.Image
|
||||||
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
|
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
|
||||||
|
# Multiple channels audio data should be split into separate channels and logged as rr.Scalars.columns
|
||||||
|
"observation.audio": np.zeros((100, 2), dtype=np.float32),
|
||||||
}
|
}
|
||||||
act = {
|
act = {
|
||||||
"action.throttle": 0.7,
|
"action.throttle": 0.7,
|
||||||
@@ -117,25 +161,27 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
|||||||
# - action.throttle -> Scalars
|
# - action.throttle -> Scalars
|
||||||
# - action.vector_0, action.vector_1 -> Scalars
|
# - action.vector_0, action.vector_1 -> Scalars
|
||||||
expected_keys = {
|
expected_keys = {
|
||||||
f"{OBS_STATE}.temperature",
|
"data/" + f"{OBS_STATE}.temperature",
|
||||||
"observation.camera",
|
"observation.camera",
|
||||||
"action.throttle",
|
"data/action.throttle",
|
||||||
"action.vector_0",
|
"data/action.vector_0",
|
||||||
"action.vector_1",
|
"data/action.vector_1",
|
||||||
|
"audio/observation.audio_channel_0",
|
||||||
|
"audio/observation.audio_channel_1",
|
||||||
}
|
}
|
||||||
assert set(_keys(calls)) == expected_keys
|
assert set(_keys(calls)) == expected_keys
|
||||||
|
|
||||||
# Check scalar types and values
|
# Check scalar types and values
|
||||||
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
|
temp_obj = _obj_for(calls, f"data/{OBS_STATE}.temperature")
|
||||||
assert type(temp_obj).__name__ == "DummyScalar"
|
assert type(temp_obj).__name__ == "DummyScalar"
|
||||||
assert temp_obj.value == pytest.approx(25.0)
|
assert temp_obj.value == pytest.approx(25.0)
|
||||||
|
|
||||||
throttle_obj = _obj_for(calls, "action.throttle")
|
throttle_obj = _obj_for(calls, "data/action.throttle")
|
||||||
assert type(throttle_obj).__name__ == "DummyScalar"
|
assert type(throttle_obj).__name__ == "DummyScalar"
|
||||||
assert throttle_obj.value == pytest.approx(0.7)
|
assert throttle_obj.value == pytest.approx(0.7)
|
||||||
|
|
||||||
v0 = _obj_for(calls, "action.vector_0")
|
v0 = _obj_for(calls, "data/action.vector_0")
|
||||||
v1 = _obj_for(calls, "action.vector_1")
|
v1 = _obj_for(calls, "data/action.vector_1")
|
||||||
assert type(v0).__name__ == "DummyScalar"
|
assert type(v0).__name__ == "DummyScalar"
|
||||||
assert type(v1).__name__ == "DummyScalar"
|
assert type(v1).__name__ == "DummyScalar"
|
||||||
assert v0.value == pytest.approx(1.0)
|
assert v0.value == pytest.approx(1.0)
|
||||||
@@ -147,6 +193,14 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
|||||||
assert img_obj.arr.shape == (10, 20, 3) # transposed
|
assert img_obj.arr.shape == (10, 20, 3) # transposed
|
||||||
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
|
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
|
||||||
|
|
||||||
|
# Check audio handling: split channels + rr.Scalars.columns
|
||||||
|
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||||
|
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||||
|
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||||
|
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||||
|
assert audio_obj_0.values.shape == (100,)
|
||||||
|
assert audio_obj_1.values.shape == (100,)
|
||||||
|
|
||||||
|
|
||||||
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||||
vu, calls = mock_rerun
|
vu, calls = mock_rerun
|
||||||
@@ -157,6 +211,8 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
|||||||
"temp": 1.5,
|
"temp": 1.5,
|
||||||
# Already HWC image => should stay as-is
|
# Already HWC image => should stay as-is
|
||||||
"img": np.zeros((5, 6, 3), dtype=np.uint8),
|
"img": np.zeros((5, 6, 3), dtype=np.uint8),
|
||||||
|
# Multiple channels audio data should be split into separate channels
|
||||||
|
"audio": np.zeros((100, 2), dtype=np.float32),
|
||||||
"none": None, # should be skipped
|
"none": None, # should be skipped
|
||||||
}
|
}
|
||||||
act_plain = {
|
act_plain = {
|
||||||
@@ -170,22 +226,24 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
|||||||
|
|
||||||
# Expected keys with auto-prefixes
|
# Expected keys with auto-prefixes
|
||||||
expected = {
|
expected = {
|
||||||
"observation.temp",
|
"data/observation.temp",
|
||||||
"observation.img",
|
"observation.img",
|
||||||
"action.throttle",
|
"data/action.throttle",
|
||||||
"action.vec_0",
|
"data/action.vec_0",
|
||||||
"action.vec_1",
|
"data/action.vec_1",
|
||||||
"action.vec_2",
|
"data/action.vec_2",
|
||||||
|
"audio/observation.audio_channel_0",
|
||||||
|
"audio/observation.audio_channel_1",
|
||||||
}
|
}
|
||||||
logged = set(_keys(calls))
|
logged = set(_keys(calls))
|
||||||
assert logged == expected
|
assert logged == expected
|
||||||
|
|
||||||
# Scalars
|
# Scalars
|
||||||
t = _obj_for(calls, "observation.temp")
|
t = _obj_for(calls, "data/observation.temp")
|
||||||
assert type(t).__name__ == "DummyScalar"
|
assert type(t).__name__ == "DummyScalar"
|
||||||
assert t.value == pytest.approx(1.5)
|
assert t.value == pytest.approx(1.5)
|
||||||
|
|
||||||
throttle = _obj_for(calls, "action.throttle")
|
throttle = _obj_for(calls, "data/action.throttle")
|
||||||
assert type(throttle).__name__ == "DummyScalar"
|
assert type(throttle).__name__ == "DummyScalar"
|
||||||
assert throttle.value == pytest.approx(0.3)
|
assert throttle.value == pytest.approx(0.3)
|
||||||
|
|
||||||
@@ -197,25 +255,39 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
|||||||
|
|
||||||
# Vectors
|
# Vectors
|
||||||
for i, val in enumerate([9, 8, 7]):
|
for i, val in enumerate([9, 8, 7]):
|
||||||
o = _obj_for(calls, f"action.vec_{i}")
|
o = _obj_for(calls, f"data/action.vec_{i}")
|
||||||
assert type(o).__name__ == "DummyScalar"
|
assert type(o).__name__ == "DummyScalar"
|
||||||
assert o.value == pytest.approx(val)
|
assert o.value == pytest.approx(val)
|
||||||
|
|
||||||
|
# Audio
|
||||||
|
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||||
|
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||||
|
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||||
|
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||||
|
assert audio_obj_0.values.shape == (100,)
|
||||||
|
assert audio_obj_1.values.shape == (100,)
|
||||||
|
|
||||||
|
|
||||||
def test_log_rerun_data_kwargs_only(mock_rerun):
|
def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||||
vu, calls = mock_rerun
|
vu, calls = mock_rerun
|
||||||
|
|
||||||
vu.log_rerun_data(
|
vu.log_rerun_data(
|
||||||
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
|
observation={
|
||||||
|
"observation.temp": 10.0,
|
||||||
|
"observation.gray": np.zeros((8, 8, 1), dtype=np.uint8),
|
||||||
|
"observation.audio": np.zeros((100, 2), dtype=np.float32),
|
||||||
|
},
|
||||||
action={"action.a": 1.0},
|
action={"action.a": 1.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
keys = set(_keys(calls))
|
keys = set(_keys(calls))
|
||||||
assert "observation.temp" in keys
|
assert "data/observation.temp" in keys
|
||||||
assert "observation.gray" in keys
|
assert "observation.gray" in keys
|
||||||
assert "action.a" in keys
|
assert "data/action.a" in keys
|
||||||
|
assert "audio/observation.audio_channel_0" in keys
|
||||||
|
assert "audio/observation.audio_channel_1" in keys
|
||||||
|
|
||||||
temp = _obj_for(calls, "observation.temp")
|
temp = _obj_for(calls, "data/observation.temp")
|
||||||
assert type(temp).__name__ == "DummyScalar"
|
assert type(temp).__name__ == "DummyScalar"
|
||||||
assert temp.value == pytest.approx(10.0)
|
assert temp.value == pytest.approx(10.0)
|
||||||
|
|
||||||
@@ -224,6 +296,13 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
|||||||
assert img.arr.shape == (8, 8, 1) # remains HWC
|
assert img.arr.shape == (8, 8, 1) # remains HWC
|
||||||
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
|
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
|
||||||
|
|
||||||
a = _obj_for(calls, "action.a")
|
a = _obj_for(calls, "data/action.a")
|
||||||
assert type(a).__name__ == "DummyScalar"
|
assert type(a).__name__ == "DummyScalar"
|
||||||
assert a.value == pytest.approx(1.0)
|
assert a.value == pytest.approx(1.0)
|
||||||
|
|
||||||
|
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||||
|
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||||
|
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||||
|
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||||
|
assert audio_obj_0.values.shape == (100,)
|
||||||
|
assert audio_obj_1.values.shape == (100,)
|
||||||
|
|||||||
Reference in New Issue
Block a user