mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Compare commits
105 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 764404a27e | |||
| 8b9451b585 | |||
| ab4903e752 | |||
| 538cea6dbc | |||
| 5cd3572713 | |||
| 3399513e5e | |||
| 32fc4015ee | |||
| cc72c813bf | |||
| 606f31a86e | |||
| 4933c9dcc7 | |||
| 7e25385024 | |||
| cc70bff74d | |||
| 9f50913b9c | |||
| 4eb7694d47 | |||
| edb5559b5b | |||
| 552ec76195 | |||
| e75340b473 | |||
| 2a4c223ec7 | |||
| 1ee4d84f07 | |||
| 6bd40ca219 | |||
| b879cf3d04 | |||
| bd9e5c1a64 | |||
| 9271a0c900 | |||
| af2f044f5a | |||
| 0caba222ef | |||
| 6d73f5bfe6 | |||
| ef8f40c21b | |||
| 0232879245 | |||
| 2726b4e865 | |||
| e126d35249 | |||
| d7ae8cd699 | |||
| 2f96d8bf76 | |||
| e129c71b4f | |||
| a02d70389d | |||
| 0d4922ce49 | |||
| eaeff78924 | |||
| e2f3982e2c | |||
| a73ac2bdbb | |||
| 95de732e55 | |||
| b2383236ca | |||
| 4b98cc25c8 | |||
| 90780c4de8 | |||
| 6f6e046c53 | |||
| 8cd64eaad1 | |||
| e620395416 | |||
| 0fbcbcdb2e | |||
| 674f5dfd75 | |||
| 7d430c8067 | |||
| 5f114c1d74 | |||
| ad01ef19f4 | |||
| 59e8f4572c | |||
| 97e91698fb | |||
| af0294198a | |||
| 421fdcce96 | |||
| bb63ad9715 | |||
| 3c90a79c57 | |||
| 8e29c530ed | |||
| b573b7a052 | |||
| 926184110b | |||
| bf8ede852d | |||
| f73db4394b | |||
| bff91f9927 | |||
| 6d726266fd | |||
| 2962330bb1 | |||
| 067993bb11 | |||
| e4dd00c8f5 | |||
| e714ff22e2 | |||
| 3bbd161cfd | |||
| 6d7be63f59 | |||
| b9d0dfb9a2 | |||
| dce483060f | |||
| c32b9182d9 | |||
| a4d4ef0e7f | |||
| 9a5c96b2b1 | |||
| 0a6ca58299 | |||
| 688195fc46 | |||
| 99eb0bbafc | |||
| 16de8b3f19 | |||
| 580008663b | |||
| 52c424c5eb | |||
| 836195e59c | |||
| be09a59e05 | |||
| 373a169bd2 | |||
| 00536c6c5b | |||
| cdd3a859ef | |||
| 5276fc0d6f | |||
| 6a2882f978 | |||
| 8874547353 | |||
| 2864caad80 | |||
| d998660aa1 | |||
| 7e5f3b35e9 | |||
| 01fea7c407 | |||
| 13bfee1aa4 | |||
| 79688a09f2 | |||
| b2ff219624 | |||
| 66929c5935 | |||
| 5286ef8439 | |||
| fe068df711 | |||
| da41646073 | |||
| 46e19ae579 | |||
| 77dc49b3a3 | |||
| 33910673ec | |||
| 19dce78457 | |||
| 112b2d173a | |||
| b825880c40 |
@@ -20,8 +20,8 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on the 1st and 15th of every month at 09:00 UTC
|
||||
schedule:
|
||||
- cron: '0 2 1,15 * *'
|
||||
# schedule:
|
||||
# - cron: '0 2 1,15 * *'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
+48
@@ -0,0 +1,48 @@
|
||||
# Security Policy
|
||||
|
||||
## Project Status & Philosophy
|
||||
|
||||
`lerobot` has so far been primarily a research and prototyping tool, which is why deployment security hasn’t been a strong focus until now. As `lerobot` continues to be adopted and deployed in production, we are paying much closer attention to these kinds of issues.
|
||||
|
||||
Fortunately, being an open-source project, the community can also help by reporting and fixing vulnerabilities. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/huggingface/lerobot/security/advisories/new) tab.
|
||||
|
||||
The `lerobot` team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
|
||||
|
||||
#### Hugging Face Security Team
|
||||
|
||||
Since this project is part of the Hugging Face ecosystem, feel free to submit vulnerability reports directly to: **[security@huggingface.co](mailto:security@huggingface.co)**. Someone from the HF security team will review the report and recommend next steps.
|
||||
|
||||
#### Open Source Disclosures
|
||||
|
||||
If reporting a vulnerability specific to the open-source codebase (and not the underlying Hub infrastructure), you may also use [Huntr](https://huntr.com), a vulnerability disclosure program for open source software.
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Currently, we treat `lerobot` as a rolling release. We prioritize security updates for the latest available version (`main` branch).
|
||||
|
||||
| Version | Supported |
|
||||
| -------- | --------- |
|
||||
| Latest | ✅ |
|
||||
| < Latest | ❌ |
|
||||
|
||||
## Secure Usage Guidelines
|
||||
|
||||
`lerobot` is tightly coupled to the Hugging Face Hub for sharing data and pretrained policies. When downloading artifacts uploaded by others, you expose yourself to risks. Please read below for recommendations to keep your runtime and robot environment safe.
|
||||
|
||||
### Remote Artefacts (Weights & Policies)
|
||||
|
||||
Models and policies uploaded to the Hugging Face Hub come in different formats. We heavily recommend uploading and downloading models in the [`safetensors`](https://github.com/huggingface/safetensors) format.
|
||||
|
||||
`safetensors` was developed specifically to prevent arbitrary code execution on your system, which is critical when running software on physical hardware/robots.
|
||||
|
||||
To avoid loading models from unsafe formats (e.g., `pickle`), you should ensure you are prioritizing `safetensors` files.
|
||||
|
||||
### Remote Code
|
||||
|
||||
Some models or environments on the Hub may require `trust_remote_code=True` to run custom architecture code.
|
||||
|
||||
Please **always** verify the content of the modeling files when using this argument. We recommend setting a specific `revision` (commit hash) when loading remote code to ensure you protect yourself from unverified updates to the repository.
|
||||
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from soundfile import read
|
||||
|
||||
from lerobot.microphones.configs import MicrophoneConfig
|
||||
from lerobot.microphones.portaudio import PortAudioMicrophone, PortAudioMicrophoneConfig
|
||||
from lerobot.microphones.utils import (
|
||||
async_microphones_start_recording,
|
||||
async_microphones_stop_recording,
|
||||
make_microphones_from_configs,
|
||||
)
|
||||
from lerobot.utils.robot_utils import (
|
||||
precise_sleep,
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
microphones_configs: dict[str, MicrophoneConfig],
|
||||
audio_chunks_number: int,
|
||||
audio_chunks_duration: float,
|
||||
repetitions: int,
|
||||
multiprocessing: bool = False,
|
||||
):
|
||||
recording_dir = Path("outputs/audio_benchmark")
|
||||
recording_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create microphones
|
||||
microphones = make_microphones_from_configs(microphones_configs)
|
||||
|
||||
# Connect microphones
|
||||
for microphone in microphones.values():
|
||||
microphone.connect()
|
||||
|
||||
all_audio_chunks = []
|
||||
for i in range(repetitions):
|
||||
print(f"Repetition {i + 1}/{repetitions}...")
|
||||
|
||||
# Create audio chunks
|
||||
audio_chunks = {}
|
||||
for microphone_key in microphones:
|
||||
audio_chunks.update({microphone_key: []})
|
||||
|
||||
# Start recording
|
||||
async_microphones_start_recording(
|
||||
microphones,
|
||||
output_files=[
|
||||
recording_dir / f"{microphone_key}_recording_{i}.wav" for microphone_key in microphones
|
||||
],
|
||||
multiprocessing=multiprocessing,
|
||||
)
|
||||
|
||||
# Record audio chunks
|
||||
for j in range(audio_chunks_number):
|
||||
precise_sleep(audio_chunks_duration)
|
||||
|
||||
for microphone_key, microphone in microphones.items():
|
||||
audio_chunk = microphone.read()
|
||||
print(f"{microphone_key} - repetition {i} - chunk {j} - samples {audio_chunk.shape[0]}")
|
||||
audio_chunks[microphone_key].append(audio_chunk)
|
||||
|
||||
# Stop recording
|
||||
async_microphones_stop_recording(microphones)
|
||||
|
||||
for microphone_key in microphones:
|
||||
audio_chunks[microphone_key] = np.concatenate(audio_chunks[microphone_key], axis=0)
|
||||
|
||||
all_audio_chunks.append(audio_chunks)
|
||||
|
||||
# Disconnect microphones
|
||||
for microphone in microphones.values():
|
||||
microphone.disconnect()
|
||||
|
||||
# Compute statistics
|
||||
cmap = plt.get_cmap("tab10")
|
||||
_, ax = plt.subplots(nrows=repetitions, ncols=len(microphones))
|
||||
chunk_length = np.zeros((repetitions, len(microphones)))
|
||||
record_length = np.zeros((repetitions, len(microphones)))
|
||||
for i in range(repetitions):
|
||||
for j, (microphone_key, microphone) in enumerate(microphones.items()):
|
||||
# Get recorded audio chunks
|
||||
recorded_audio_chunks = all_audio_chunks[i][microphone_key]
|
||||
|
||||
# Load recorded file
|
||||
recorded_data, _ = read(recording_dir / f"{microphone_key}_recording_{i}.wav")
|
||||
if recorded_data.ndim == 1:
|
||||
recorded_data = np.expand_dims(recorded_data, axis=1)
|
||||
|
||||
record_length[i, j] = recorded_data.shape[0]
|
||||
chunk_length[i, j] = recorded_audio_chunks.shape[0]
|
||||
|
||||
for k, (chunk_data, record_data) in enumerate(
|
||||
zip(recorded_audio_chunks.T, recorded_data.T, strict=False)
|
||||
):
|
||||
# Plot audio chunks and recorded data
|
||||
ax[i, j].plot(
|
||||
np.arange(0, len(chunk_data)) / microphone.sample_rate,
|
||||
chunk_data,
|
||||
label=f"audio chunks - channel {k}",
|
||||
color=cmap(2 * k),
|
||||
)
|
||||
ax[i, j].plot(
|
||||
np.arange(0, len(record_data)) / microphone.sample_rate,
|
||||
record_data,
|
||||
label=f"recorded data - channel {k}",
|
||||
linestyle="dashed",
|
||||
color=cmap(2 * k + 1),
|
||||
)
|
||||
|
||||
# Plot absolute difference (errors should be located at the end of the recordings)
|
||||
if recorded_data.shape[0] - recorded_audio_chunks.shape[0] > 0:
|
||||
chunk_data = np.append(
|
||||
chunk_data, np.zeros(int(recorded_data.shape[0] - recorded_audio_chunks.shape[0]))
|
||||
)
|
||||
else:
|
||||
record_data = np.append(
|
||||
record_data, np.zeros(int(-recorded_data.shape[0] + recorded_audio_chunks.shape[0]))
|
||||
)
|
||||
ax[i, j].plot(
|
||||
np.arange(0, len(record_data)) / microphone.sample_rate,
|
||||
np.abs(chunk_data - record_data),
|
||||
label=f"differences - channel {k}",
|
||||
color="red",
|
||||
linestyle="dotted",
|
||||
)
|
||||
ax[i, j].set_title(f"{microphone_key} - repetition {i}")
|
||||
ax[i, j].legend()
|
||||
|
||||
plt.show()
|
||||
|
||||
# Print statistics
|
||||
differences = record_length - chunk_length
|
||||
for i, (microphone_key, microphone) in enumerate(microphones.items()):
|
||||
print(
|
||||
f"Average recorded duration for {microphone_key} : {np.mean(record_length[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||
)
|
||||
print(
|
||||
f"Average chunk duration for {microphone_key} : {np.mean(chunk_length[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||
)
|
||||
print(f"Average difference for {microphone_key} : {np.mean(differences[:, i]):.3f} samples")
|
||||
print(
|
||||
f"Average difference for {microphone_key} : {np.mean(differences[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--microphones_indices",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[microphone["index"] for microphone in PortAudioMicrophone.find_microphones()],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microphones_sample_rate",
|
||||
type=float,
|
||||
nargs="+",
|
||||
default=[None] * len(PortAudioMicrophone.find_microphones()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microphones_channels",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[None] * len(PortAudioMicrophone.find_microphones()),
|
||||
)
|
||||
parser.add_argument("--audio_chunks_number", type=int, default=2)
|
||||
parser.add_argument(
|
||||
"--audio_chunks_duration",
|
||||
type=float,
|
||||
default=1.0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repetitions",
|
||||
type=int,
|
||||
default=2,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multiprocessing",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
args["microphones_configs"] = {}
|
||||
for index, sample_rate, channels in zip(
|
||||
args["microphones_indices"],
|
||||
args["microphones_sample_rate"],
|
||||
args["microphones_channels"],
|
||||
strict=False,
|
||||
):
|
||||
microphone_config = PortAudioMicrophoneConfig(
|
||||
microphone_index=index,
|
||||
sample_rate=sample_rate,
|
||||
channels=channels,
|
||||
)
|
||||
args["microphones_configs"].update({f"microphone_{index}": microphone_config})
|
||||
args.pop("microphones_indices")
|
||||
args.pop("microphones_sample_rate")
|
||||
args.pop("microphones_channels")
|
||||
|
||||
main(**args)
|
||||
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from lerobot.microphones.anyskin import AnyskinSensorConfig
|
||||
from lerobot.microphones.configs import MicrophoneConfig
|
||||
from lerobot.microphones.utils import (
|
||||
async_microphones_start_recording,
|
||||
async_microphones_stop_recording,
|
||||
make_microphones_from_configs,
|
||||
)
|
||||
from lerobot.utils.robot_utils import (
|
||||
precise_sleep,
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
sensors_configs: dict[str, MicrophoneConfig],
|
||||
multiprocessing: bool = False,
|
||||
):
|
||||
recording_dir = Path("outputs/tactile_benchmark")
|
||||
recording_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create microphones
|
||||
sensors = make_microphones_from_configs(sensors_configs)
|
||||
|
||||
# Connect microphones
|
||||
for sensor in sensors.values():
|
||||
sensor.connect()
|
||||
|
||||
# Create audio chunks
|
||||
data_chunks = {}
|
||||
for sensor_key in sensors:
|
||||
data_chunks.update({sensor_key: []})
|
||||
|
||||
# Start recording
|
||||
async_microphones_start_recording(
|
||||
sensors,
|
||||
output_files=[recording_dir / f"{sensor_key}_recording.wav" for sensor_key in sensors],
|
||||
multiprocessing=multiprocessing,
|
||||
)
|
||||
|
||||
# Record audio chunks
|
||||
precise_sleep(10.0)
|
||||
|
||||
for sensor_key, sensor in sensors.items():
|
||||
data_chunk = sensor.read()
|
||||
print(f"{sensor_key} - samples {data_chunk.shape[0]}")
|
||||
data_chunks[sensor_key].append(data_chunk)
|
||||
|
||||
# Stop recording
|
||||
async_microphones_stop_recording(sensors)
|
||||
|
||||
for sensor_key in sensors:
|
||||
data_chunks[sensor_key] = np.concatenate(data_chunks[sensor_key], axis=0)
|
||||
|
||||
# Disconnect microphones
|
||||
for sensor in sensors.values():
|
||||
sensor.disconnect()
|
||||
|
||||
for sensor_key in sensors:
|
||||
data, sample_rate = sf.read(recording_dir / f"{sensor_key}_recording.wav")
|
||||
print(f"{sensor_key} - samples {data.shape[0]}")
|
||||
print(f"{sensor_key} - sample rate {sample_rate}")
|
||||
print(f"{sensor_key} - data {data}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--sensors_ports",
|
||||
type=str,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sensors_baud_rate",
|
||||
type=int,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sensors_sample_rate",
|
||||
type=int,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sensors_channels",
|
||||
type=int,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multiprocessing",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
args["sensors_configs"] = {}
|
||||
for port, baud_rate, sample_rate, channels in zip(
|
||||
args["sensors_ports"],
|
||||
args["sensors_baud_rate"],
|
||||
args["sensors_sample_rate"],
|
||||
args["sensors_channels"],
|
||||
strict=False,
|
||||
):
|
||||
channels = [1, 2, 3, 4, 5]
|
||||
sensor_config = AnyskinSensorConfig(
|
||||
sensor_port=port,
|
||||
baud_rate=baud_rate,
|
||||
sample_rate=sample_rate,
|
||||
channels=channels,
|
||||
)
|
||||
args["sensors_configs"].update({f"sensor_{port}": sensor_config})
|
||||
args.pop("sensors_ports")
|
||||
args.pop("sensors_baud_rate")
|
||||
args.pop("sensors_sample_rate")
|
||||
args.pop("sensors_channels")
|
||||
|
||||
main(**args)
|
||||
@@ -95,26 +95,26 @@ Convert an image-based dataset to video format, creating a new LeRobotDataset wh
|
||||
# Local-only: Save to a custom output directory (no hub push)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
# Save with new repo_id (local storage)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
# Convert and push to Hugging Face Hub
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
# Convert with custom video codec and quality settings
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
@@ -124,16 +124,23 @@ lerobot-edit-dataset \
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||
|
||||
# Convert with multiple workers for parallel processing
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.num_workers 8
|
||||
|
||||
# For memory-constrained systems, users can now specify limits:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.max_episodes_per_batch 50 \
|
||||
--operation.max_frames_per_batch 10000
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
@@ -43,12 +43,13 @@ def main():
|
||||
keyboard.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
init_rerun(session_name="lekiwi_teleop", robot=robot, reset_time=True)
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting teleop loop...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -69,7 +70,7 @@ def main():
|
||||
_ = robot.send_action(action)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
log_rerun_data(observation=observation, action=action, log_time=time.perf_counter() - start)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -89,12 +89,13 @@ def main():
|
||||
teleop_device.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
init_rerun(session_name="phone_so100_teleop", robot=robot, reset_time=True)
|
||||
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -111,7 +112,7 @@ def main():
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
log_rerun_data(observation=phone_obs, action=joint_action, log_time=time.perf_counter() - start)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -94,9 +94,10 @@ def main():
|
||||
leader.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
init_rerun(session_name="so100_so100_EE_teleop", robot=follower, reset_time=True)
|
||||
|
||||
print("Starting teleop loop...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -116,7 +117,9 @@ def main():
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
log_rerun_data(
|
||||
observation=leader_ee_act, action=follower_joints_act, log_time=time.perf_counter() - start
|
||||
)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -147,6 +147,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
|
||||
audio = ["sounddevice>=0.5.1,<0.6.0", "soundfile>=0.13.1,<0.14.0", "librosa>=0.11.0,<0.12.0", "torchaudio>=2.6.0,<2.10.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
@@ -175,6 +176,7 @@ all = [
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[audio]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
"lerobot[video_benchmark]",
|
||||
|
||||
@@ -29,6 +29,7 @@ Example:
|
||||
print(lerobot.available_policies_per_env)
|
||||
print(lerobot.available_robots)
|
||||
print(lerobot.available_cameras)
|
||||
print(lerobot.available_microphones)
|
||||
print(lerobot.available_motors)
|
||||
```
|
||||
|
||||
@@ -174,6 +175,13 @@ available_cameras = [
|
||||
"intelrealsense",
|
||||
]
|
||||
|
||||
# lists all available microphones from `lerobot/microphones`
|
||||
available_microphones = [
|
||||
"portaudio",
|
||||
"touchlab",
|
||||
"anyskin",
|
||||
]
|
||||
|
||||
# lists all available motors from `lerobot/motors`
|
||||
available_motors = [
|
||||
"dynamixel",
|
||||
|
||||
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||
|
||||
@@ -47,6 +47,9 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
|
||||
@@ -151,6 +151,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def audio_features(self) -> dict[str, PolicyFeature]:
|
||||
if not self.input_features:
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
if not self.output_features:
|
||||
|
||||
@@ -20,6 +20,7 @@ from enum import Enum
|
||||
class FeatureType(str, Enum):
|
||||
STATE = "STATE"
|
||||
VISUAL = "VISUAL"
|
||||
AUDIO = "AUDIO"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
@@ -19,12 +19,15 @@ import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -32,6 +35,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
@@ -39,7 +43,7 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
@@ -108,6 +112,7 @@ def update_meta_data(
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
audios_idx,
|
||||
):
|
||||
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||
|
||||
@@ -120,7 +125,7 @@ def update_meta_data(
|
||||
meta_idx: Dictionary containing current metadata chunk and file indices.
|
||||
data_idx: Dictionary containing current data chunk and file indices.
|
||||
videos_idx: Dictionary containing current video indices and timestamps.
|
||||
|
||||
audios_idx: Dictionary containing current audio indices and timestamps.
|
||||
Returns:
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||
"""
|
||||
@@ -178,6 +183,36 @@ def update_meta_data(
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
for key, audio_idx in audios_idx.items():
|
||||
# Store original audio file indices before updating
|
||||
orig_chunk_col = f"audio/{key}/chunk_index"
|
||||
orig_file_col = f"audio/{key}/file_index"
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = audio_idx["chunk"]
|
||||
df[orig_file_col] = audio_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
src_to_offset = audio_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"audio/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"audio/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[f"audio/{key}/from_timestamp"] = (
|
||||
df[f"audio/{key}/from_timestamp"] + audio_idx["latest_duration"]
|
||||
)
|
||||
df[f"audio/{key}/to_timestamp"] = df[f"audio/{key}/to_timestamp"] + audio_idx["latest_duration"]
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
@@ -192,6 +227,7 @@ def aggregate_datasets(
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
audio_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
@@ -209,6 +245,7 @@ def aggregate_datasets(
|
||||
aggr_root: Optional root path for the aggregated dataset.
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
@@ -217,6 +254,8 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_files_size_in_mb is None:
|
||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if audio_files_size_in_mb is None:
|
||||
audio_files_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
|
||||
if chunk_size is None:
|
||||
chunk_size = DEFAULT_CHUNK_SIZE
|
||||
|
||||
@@ -229,6 +268,7 @@ def aggregate_datasets(
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
audio_keys = [key for key in features if features[key]["dtype"] == "audio"]
|
||||
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=aggr_repo_id,
|
||||
@@ -240,6 +280,7 @@ def aggregate_datasets(
|
||||
chunks_size=chunk_size,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
audio_files_size_in_mb=audio_files_size_in_mb,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
@@ -251,14 +292,18 @@ def aggregate_datasets(
|
||||
videos_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||
}
|
||||
audios_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in audio_keys
|
||||
}
|
||||
|
||||
dst_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
audios_idx = aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
@@ -326,7 +371,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
src_duration = get_media_duration_in_s(src_path, media_type="video")
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
|
||||
if not dst_path.exists():
|
||||
@@ -365,7 +410,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
@@ -380,6 +425,101 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
return videos_idx
|
||||
|
||||
|
||||
def aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size):
|
||||
"""Aggregates audio files from a source dataset into the destination dataset.
|
||||
|
||||
Handles audio file concatenation and rotation based on file size limits.
|
||||
Creates new audio files when size limits are exceeded.
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
audio_idx: Dictionary tracking audio chunk and file indices.
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated audio_idx with current chunk and file indices.
|
||||
"""
|
||||
for key in audios_idx:
|
||||
audios_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
audios_idx[key]["src_to_offset"] = {}
|
||||
|
||||
for key, audio_idx in audios_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
(chunk, file)
|
||||
for chunk, file in zip(
|
||||
src_meta.episodes[f"audio/{key}/chunk_index"],
|
||||
src_meta.episodes[f"audio/{key}/file_index"],
|
||||
strict=False,
|
||||
)
|
||||
}
|
||||
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
|
||||
|
||||
chunk_idx = audio_idx["chunk"]
|
||||
file_idx = audio_idx["file"]
|
||||
current_offset = audio_idx["latest_duration"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
src_path = src_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
file_index=src_file_idx,
|
||||
)
|
||||
|
||||
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
src_duration = get_media_duration_in_s(src_path, media_type="audio")
|
||||
|
||||
if not dst_path.exists():
|
||||
# Store offset before incrementing
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
audios_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= audio_files_size_in_mb:
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
else:
|
||||
# Append to existing video file - use current accumulated offset
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
concatenate_media_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
current_offset += src_duration
|
||||
|
||||
audios_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
audios_idx[key]["chunk"] = chunk_idx
|
||||
audios_idx[key]["file"] = file_idx
|
||||
|
||||
return audios_idx
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||
|
||||
@@ -402,12 +542,21 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
}
|
||||
|
||||
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
||||
contains_images = len(dst_meta.image_keys) > 0
|
||||
|
||||
# retrieve features schema for proper image typing in parquet
|
||||
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
)
|
||||
df = pd.read_parquet(src_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read source data to preserve image format
|
||||
src_ds = datasets.Dataset.from_parquet(str(src_path))
|
||||
df = src_ds.to_pandas()
|
||||
else:
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
data_idx = append_or_create_parquet_file(
|
||||
@@ -417,14 +566,15 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
data_files_size_in_mb,
|
||||
chunk_size,
|
||||
DEFAULT_DATA_PATH,
|
||||
contains_images=len(dst_meta.image_keys) > 0,
|
||||
contains_images=contains_images,
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
)
|
||||
|
||||
return data_idx
|
||||
|
||||
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx):
|
||||
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||
|
||||
Reads source metadata files, updates all indices and timestamps,
|
||||
@@ -436,6 +586,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
meta_idx: Dictionary tracking metadata chunk and file indices.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
videos_idx: Dictionary tracking video indices and timestamps.
|
||||
audios_idx: Dictionary tracking audio indices and timestamps.
|
||||
|
||||
Returns:
|
||||
dict: Updated meta_idx with current chunk and file indices.
|
||||
@@ -459,6 +610,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
audios_idx,
|
||||
)
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
@@ -475,7 +627,8 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
# Increment latest_duration by the total duration added from this source dataset
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
for k in audios_idx:
|
||||
audios_idx[k]["latest_duration"] += audios_idx[k]["episode_duration"]
|
||||
return meta_idx
|
||||
|
||||
|
||||
@@ -488,6 +641,7 @@ def append_or_create_parquet_file(
|
||||
default_path: str,
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
):
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -503,6 +657,7 @@ def append_or_create_parquet_file(
|
||||
default_path: Format string for generating file paths.
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
|
||||
Returns:
|
||||
dict: Updated index dictionary with current chunk and file indices.
|
||||
@@ -512,7 +667,7 @@ def append_or_create_parquet_file(
|
||||
if not dst_path.exists():
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path)
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx
|
||||
@@ -527,12 +682,17 @@ def append_or_create_parquet_file(
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read existing data to preserve image format
|
||||
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
|
||||
existing_df = existing_ds.to_pandas()
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
target_path = dst_path
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path)
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchcodec
|
||||
from numpy import ceil
|
||||
|
||||
CHANNELS_LAYOUTS_MAPPING = {
|
||||
1: "mono",
|
||||
2: "stereo",
|
||||
3: "2.1",
|
||||
4: "3.1",
|
||||
5: "4.1",
|
||||
6: "5.1",
|
||||
7: "6.1",
|
||||
8: "7.1",
|
||||
16: "hexadecagonal",
|
||||
24: "22.2",
|
||||
}
|
||||
|
||||
|
||||
def decode_audio(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
start_time_s: float | None = 0.0,
|
||||
backend: str | None = "torchcodec",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes audio using the specified backend.
|
||||
Args:
|
||||
audio_path (Path): Path to the audio file.
|
||||
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
|
||||
duration (float): Duration of the audio chunks in seconds.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded audio chunks.
|
||||
|
||||
Currently supports torchaudio.
|
||||
"""
|
||||
if backend == "torchcodec":
|
||||
return decode_audio_torchcodec(audio_path, timestamps, duration, start_time_s)
|
||||
elif backend == "torchaudio":
|
||||
return decode_audio_torchaudio(audio_path, timestamps, duration, start_time_s)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_audio_torchcodec(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
start_time_s: float | None = 0.0,
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# TODO(CarolinePascal) : add channels selection
|
||||
audio_decoder = torchcodec.decoders.AudioDecoder(audio_path)
|
||||
audio_sample_rate = audio_decoder.metadata.sample_rate
|
||||
audio_channels = audio_decoder.metadata.num_channels
|
||||
# TODO(CarolinePascal) : assert ts < total record duration
|
||||
|
||||
audio_chunks = []
|
||||
timestamps = [
|
||||
timestamp + start_time_s for timestamp in timestamps
|
||||
] # Add an offset of start_time_s to each timestamp
|
||||
for ts in timestamps:
|
||||
current_audio_chunk = audio_decoder.get_samples_played_in_range(
|
||||
start_seconds=max(0.0, ts - duration), stop_seconds=ts
|
||||
)
|
||||
|
||||
current_audio_chunk_data = current_audio_chunk.data
|
||||
|
||||
# Case where the requested audio chunk starts before the beginning of the audio stream
|
||||
if ts - duration < 0:
|
||||
# No useful audio sample has been recorded
|
||||
if ts < 1 / audio_sample_rate:
|
||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||
current_audio_chunk_data = torch.zeros(
|
||||
(audio_channels, int(ceil(duration * audio_sample_rate)))
|
||||
)
|
||||
# At least one useful audio sample has been recorded
|
||||
else:
|
||||
# Pad the beginning of the audio chunk with zeros
|
||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||
current_audio_chunk_data = torch.nn.functional.pad(
|
||||
current_audio_chunk_data,
|
||||
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
|
||||
)
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(
|
||||
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
|
||||
)
|
||||
|
||||
audio_chunks.append(current_audio_chunk_data)
|
||||
|
||||
audio_chunks = torch.stack(audio_chunks)
|
||||
|
||||
assert len(timestamps) == len(audio_chunks)
|
||||
return audio_chunks
|
||||
|
||||
|
||||
def decode_audio_torchaudio(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
start_time_s: float | None = 0.0,
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# TODO(CarolinePascal) : add channels selection
|
||||
audio_path = str(audio_path)
|
||||
|
||||
reader = torchaudio.io.StreamReader(src=audio_path)
|
||||
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
|
||||
audio_channels = reader.get_src_stream_info(reader.default_audio_stream).num_channels
|
||||
# TODO(CarolinePascal) : assert ts < total record duration
|
||||
|
||||
# TODO(CarolinePascal) : sort timestamps ?
|
||||
|
||||
reader.add_basic_audio_stream(
|
||||
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
|
||||
buffer_chunk_size=-1, # No dropping frames
|
||||
format="fltp", # Format as float32
|
||||
)
|
||||
|
||||
audio_chunks = []
|
||||
timestamps = [
|
||||
timestamp + start_time_s for timestamp in timestamps
|
||||
] # Add an offset of start_time_s to each timestamp
|
||||
for ts in timestamps:
|
||||
reader.seek(max(0.0, ts - duration)) # Default to closest audio sample. Needs to be non-negative !
|
||||
status = reader.fill_buffer()
|
||||
if status != 0:
|
||||
# Should not happen, but just in case
|
||||
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
|
||||
|
||||
current_audio_chunk = reader.pop_chunks()[0]
|
||||
current_audio_chunk_data = current_audio_chunk.t() # Channel first format
|
||||
|
||||
# Case where the requested audio chunk starts before the beginning of the audio stream
|
||||
if ts - duration < 0:
|
||||
# No useful audio sample has been recorded
|
||||
if ts < 1 / audio_sample_rate:
|
||||
current_audio_chunk_data = torch.zeros(
|
||||
(audio_channels, int(ceil(duration * audio_sample_rate)))
|
||||
)
|
||||
# At least one useful audio sample has been recorded
|
||||
else:
|
||||
# Remove the superfluous last samples of the audio chunk
|
||||
current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
|
||||
# Pad the beginning of the audio chunk with zeros
|
||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||
current_audio_chunk_data = torch.nn.functional.pad(
|
||||
current_audio_chunk_data,
|
||||
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
|
||||
)
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(
|
||||
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
|
||||
)
|
||||
|
||||
audio_chunks.append(current_audio_chunk_data)
|
||||
|
||||
audio_chunks = torch.stack(audio_chunks)
|
||||
|
||||
assert len(timestamps) == len(audio_chunks)
|
||||
return audio_chunks
|
||||
|
||||
|
||||
def encode_audio(
|
||||
input_path: Path | str,
|
||||
output_path: Path | str,
|
||||
codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
|
||||
bit_rate: int | None = None,
|
||||
sample_rate: int | None = None,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Encodes an audio file using ffmpeg."""
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=overwrite)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
# "While less efficient, it is generally preferable to modify logging with Python’s logging"
|
||||
logging.getLogger("libav").setLevel(log_level)
|
||||
|
||||
# Open input file
|
||||
with av.open(str(input_path), "r") as input:
|
||||
input_stream = input.streams.audio[0] # Assuming the first stream is the audio stream to be encoded
|
||||
|
||||
# Define sub-sampling options
|
||||
if sample_rate is None:
|
||||
sample_rate = input_stream.rate
|
||||
|
||||
# Create and open output file (overwrite by default)
|
||||
with av.open(str(output_path), "w") as output:
|
||||
output_stream = output.add_stream(
|
||||
codec, rate=sample_rate, layout=CHANNELS_LAYOUTS_MAPPING[input_stream.channels]
|
||||
)
|
||||
|
||||
if bit_rate is not None:
|
||||
output_stream.bit_rate = bit_rate
|
||||
|
||||
# Loop through input WAV packets and encode them
|
||||
for input_frame in input.decode(
|
||||
input_stream
|
||||
): # This step handles both demuxing and decoding under the hood
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
packet = output_stream.encode()
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Reset logging level
|
||||
if log_level is not None:
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not output_path.exists():
|
||||
raise OSError(f"Audio encoding did not work. File not found: {output_path}.")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
|
||||
# Getting audio stream information
|
||||
audio_info = {}
|
||||
with av.open(str(video_path), "r") as audio_file:
|
||||
try:
|
||||
audio_stream = audio_file.streams.audio[0]
|
||||
except IndexError:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
return {"has_audio": False}
|
||||
|
||||
audio_info["audio.channels"] = audio_stream.channels
|
||||
audio_info["audio.codec"] = audio_stream.codec.canonical_name
|
||||
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
|
||||
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
|
||||
audio_info["audio.bit_rate"] = audio_stream.bit_rate
|
||||
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
|
||||
# In an ideal loseless case : fixed number of bits per sample.
|
||||
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
|
||||
audio_info["audio.bit_depth"] = audio_stream.format.bits
|
||||
audio_info["audio.channel_layout"] = audio_stream.layout.name
|
||||
audio_info["has_audio"] = True
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
return audio_info
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
from lerobot.datasets.utils import load_audio_from_path, load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
@@ -245,6 +245,20 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def sample_audio_from_path(audio_path: str) -> np.ndarray:
|
||||
"""Samples audio data from an audio recording stored in a WAV file."""
|
||||
data = load_audio_from_path(audio_path)
|
||||
sampled_indices = sample_indices(len(data))
|
||||
|
||||
return data[sampled_indices]
|
||||
|
||||
|
||||
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
|
||||
"""Samples audio data from an audio recording stored in a numpy array."""
|
||||
sampled_indices = sample_indices(len(data))
|
||||
return data[sampled_indices]
|
||||
|
||||
|
||||
def _reshape_stats_by_axis(
|
||||
stats: dict[str, np.ndarray],
|
||||
axis: int | tuple[int, ...] | None,
|
||||
@@ -512,6 +526,13 @@ def compute_episode_stats(
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
elif features[key]["dtype"] == "audio":
|
||||
try:
|
||||
ep_ft_array = sample_audio_from_path(data[0])
|
||||
except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
ep_ft_array = sample_audio_from_data(data)
|
||||
axes_to_reduce = 0
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
|
||||
@@ -26,6 +26,7 @@ This module provides utilities for:
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -51,7 +52,8 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -1083,3 +1085,561 @@ def _copy_episodes_metadata_and_stats(
|
||||
else:
|
||||
if src_dataset.meta.stats:
|
||||
write_stats(src_dataset.meta.stats, dst_meta.root)
|
||||
|
||||
|
||||
def _save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def _save_batch_episodes_images(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_indices: list[int],
|
||||
num_workers: int = 4,
|
||||
) -> list[float]:
|
||||
"""Save images from multiple episodes to disk for batch video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_indices: List of episode indices to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
|
||||
Returns:
|
||||
List of episode durations in seconds
|
||||
"""
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Define function to save a single image with global frame index
|
||||
# Defined once outside the loop to avoid repeated closure creation
|
||||
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key_param]
|
||||
# Use global frame index for naming
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
episode_durations = []
|
||||
frame_idx = 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
# Get episode range
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][ep_idx]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][ep_idx]
|
||||
episode_length = to_idx - from_idx
|
||||
episode_durations.append(episode_length / dataset.fps)
|
||||
|
||||
# Get episode images
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Save images
|
||||
items = list(enumerate(episode_dataset))
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item, frame_idx, img_key) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
frame_idx += episode_length
|
||||
|
||||
return episode_durations
|
||||
|
||||
|
||||
def _iter_episode_batches(
|
||||
episode_indices: list[int],
|
||||
episode_lengths: dict[int, int],
|
||||
size_per_frame_mb: float,
|
||||
video_file_size_limit: float,
|
||||
max_episodes: int | None,
|
||||
max_frames: int | None,
|
||||
):
|
||||
"""Generator that yields batches of episode indices for video encoding.
|
||||
|
||||
Groups episodes into batches that respect size and memory constraints:
|
||||
- Stays under video file size limit
|
||||
- Respects maximum episodes per batch (if specified)
|
||||
- Respects maximum frames per batch (if specified)
|
||||
|
||||
Args:
|
||||
episode_indices: List of episode indices to batch
|
||||
episode_lengths: Dictionary mapping episode index to episode length
|
||||
size_per_frame_mb: Estimated size per frame in MB
|
||||
video_file_size_limit: Maximum video file size in MB
|
||||
max_episodes: Maximum number of episodes per batch (None = no limit)
|
||||
max_frames: Maximum number of frames per batch (None = no limit)
|
||||
|
||||
Yields:
|
||||
List of episode indices for each batch
|
||||
"""
|
||||
batch_episodes = []
|
||||
estimated_size = 0.0
|
||||
total_frames = 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
ep_length = episode_lengths[ep_idx]
|
||||
ep_estimated_size = ep_length * size_per_frame_mb
|
||||
|
||||
# we check if adding this episode would exceed any constraint
|
||||
would_exceed_size = estimated_size > 0 and estimated_size + ep_estimated_size >= video_file_size_limit
|
||||
would_exceed_episodes = max_episodes is not None and len(batch_episodes) >= max_episodes
|
||||
would_exceed_frames = max_frames is not None and total_frames + ep_length > max_frames
|
||||
|
||||
if batch_episodes and (would_exceed_size or would_exceed_episodes or would_exceed_frames):
|
||||
# yield current batch before adding this episode
|
||||
yield batch_episodes
|
||||
# start a new batch with current episode
|
||||
batch_episodes = [ep_idx]
|
||||
estimated_size = ep_estimated_size
|
||||
total_frames = ep_length
|
||||
else:
|
||||
# add to current batch
|
||||
batch_episodes.append(ep_idx)
|
||||
estimated_size += ep_estimated_size
|
||||
total_frames += ep_length
|
||||
|
||||
# yield final batch if not empty
|
||||
if batch_episodes:
|
||||
yield batch_episodes
|
||||
|
||||
|
||||
def _estimate_frame_size_via_calibration(
|
||||
dataset: LeRobotDataset,
|
||||
img_key: str,
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
|
||||
Encodes a representative sample of frames using the exact codec parameters
|
||||
to measure actual compression ratio, which is more accurate than heuristics.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images.
|
||||
img_key: Image key to calibrate (e.g., "observation.images.top").
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
Estimated size in MB per frame based on actual encoding.
|
||||
"""
|
||||
calibration_dir = temp_dir / "calibration" / img_key
|
||||
calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
# Select a representative episode (prefer middle episode if available)
|
||||
calibration_ep_idx = episode_indices[len(episode_indices) // 2]
|
||||
|
||||
# Get episode range
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][calibration_ep_idx]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][calibration_ep_idx]
|
||||
episode_length = to_idx - from_idx
|
||||
|
||||
# Use up to num_calibration_frames from this episode
|
||||
num_frames = min(num_calibration_frames, episode_length)
|
||||
|
||||
# Get frames from dataset
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
sample_indices = range(from_idx, from_idx + num_frames)
|
||||
|
||||
# Save calibration frames
|
||||
for i, idx in enumerate(sample_indices):
|
||||
img = hf_dataset[idx][img_key]
|
||||
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
# Encode calibration video
|
||||
calibration_video_path = calibration_dir / "calibration.mp4"
|
||||
encode_video_frames(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Measure actual compressed size
|
||||
video_size_bytes = calibration_video_path.stat().st_size
|
||||
video_size_mb = video_size_bytes / BYTES_PER_MIB
|
||||
size_per_frame_mb = video_size_mb / num_frames
|
||||
|
||||
logging.info(
|
||||
f" Calibration: {num_frames} frames -> {video_size_mb:.2f} MB "
|
||||
f"= {size_per_frame_mb:.4f} MB/frame for {img_key}"
|
||||
)
|
||||
|
||||
return size_per_frame_mb
|
||||
|
||||
finally:
|
||||
# Clean up calibration files
|
||||
if calibration_dir.exists():
|
||||
shutil.rmtree(calibration_dir)
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
# Video conversion constants
|
||||
BYTES_PER_KIB = 1024
|
||||
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
max_frames_per_batch: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-to-video dataset.
|
||||
|
||||
Creates a new LeRobotDataset with images encoded as videos, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process all episodes and batch encode videos
|
||||
# Use dictionary for O(1) episode metadata lookups instead of O(n) linear search
|
||||
all_episode_metadata = {}
|
||||
fps = int(dataset.fps)
|
||||
|
||||
try:
|
||||
# Build episode metadata entries first
|
||||
logging.info("Building episode metadata...")
|
||||
cumulative_frame_idx = 0
|
||||
for ep_idx in episode_indices:
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
ep_length = src_episode["length"]
|
||||
ep_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": ep_length,
|
||||
"dataset_from_index": cumulative_frame_idx,
|
||||
"dataset_to_index": cumulative_frame_idx + ep_length,
|
||||
}
|
||||
if "data/chunk_index" in src_episode:
|
||||
ep_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
ep_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
all_episode_metadata[ep_idx] = ep_meta
|
||||
cumulative_frame_idx += ep_length
|
||||
|
||||
# Process each camera and batch encode multiple episodes together
|
||||
video_file_size_limit = new_meta.video_files_size_in_mb
|
||||
|
||||
# Pre-compute episode lengths for batching
|
||||
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
|
||||
|
||||
for img_key in tqdm(img_keys, desc="Processing cameras"):
|
||||
# Estimate size per frame by encoding a small calibration sample
|
||||
# This provides accurate compression ratio for the specific codec parameters
|
||||
size_per_frame_mb = _estimate_frame_size_via_calibration(
|
||||
dataset=dataset,
|
||||
img_key=img_key,
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
chunk_idx, file_idx = 0, 0
|
||||
cumulative_timestamp = 0.0
|
||||
|
||||
# Process episodes in batches to stay under size limit
|
||||
for batch_episodes in _iter_episode_batches(
|
||||
episode_indices=episode_indices,
|
||||
episode_lengths=episode_lengths,
|
||||
size_per_frame_mb=size_per_frame_mb,
|
||||
video_file_size_limit=video_file_size_limit,
|
||||
max_episodes=max_episodes_per_batch,
|
||||
max_frames=max_frames_per_batch,
|
||||
):
|
||||
total_frames_in_batch = sum(episode_lengths[idx] for idx in batch_episodes)
|
||||
logging.info(
|
||||
f" Encoding batch of {len(batch_episodes)} episodes "
|
||||
f"({batch_episodes[0]}-{batch_episodes[-1]}) = {total_frames_in_batch} frames"
|
||||
)
|
||||
|
||||
# Save images for all episodes in this batch
|
||||
imgs_dir = temp_dir / f"batch_{chunk_idx}_{file_idx}" / img_key
|
||||
episode_durations = _save_batch_episodes_images(
|
||||
dataset=dataset,
|
||||
imgs_dir=imgs_dir,
|
||||
img_key=img_key,
|
||||
episode_indices=batch_episodes,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
# Encode all batched episodes into single video
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Update metadata for each episode in the batch
|
||||
for ep_idx, duration in zip(batch_episodes, episode_durations, strict=True):
|
||||
from_timestamp = cumulative_timestamp
|
||||
to_timestamp = cumulative_timestamp + duration
|
||||
cumulative_timestamp = to_timestamp
|
||||
|
||||
# Find episode metadata entry and add video metadata (O(1) dictionary lookup)
|
||||
ep_meta = all_episode_metadata[ep_idx]
|
||||
ep_meta[f"videos/{img_key}/chunk_index"] = chunk_idx
|
||||
ep_meta[f"videos/{img_key}/file_index"] = file_idx
|
||||
ep_meta[f"videos/{img_key}/from_timestamp"] = from_timestamp
|
||||
ep_meta[f"videos/{img_key}/to_timestamp"] = to_timestamp
|
||||
|
||||
# Move to next video file for next batch
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, new_meta.chunks_size)
|
||||
cumulative_timestamp = 0.0
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(list(all_episode_metadata.values()))
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
@@ -33,12 +33,16 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.datasets.audio_utils import decode_audio, encode_audio, get_audio_info
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
|
||||
DEFAULT_RAW_AUDIO_PATH,
|
||||
INFO_PATH,
|
||||
_validate_feature_names,
|
||||
check_delta_timestamps,
|
||||
@@ -68,13 +72,15 @@ from lerobot.datasets.utils import (
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
concatenate_media_files,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_media_duration_in_s,
|
||||
get_safe_default_codec,
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.microphones import Microphone
|
||||
from lerobot.microphones.utils import async_microphones_start_recording
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
@@ -214,6 +220,19 @@ class LeRobotDatasetMetadata:
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_audio_file_path(self, ep_index: int, audio_key: str) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"audio/{audio_key}/chunk_index"]
|
||||
file_idx = ep[f"audio/{audio_key}/file_index"]
|
||||
fpath = self.audio_path.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
@@ -224,6 +243,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def audio_path(self) -> str | None:
|
||||
"""Formattable string for the audio files."""
|
||||
return self.info["audio_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
@@ -254,6 +278,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def audio_keys(self) -> list[str]:
|
||||
"""Keys to access audio modalities."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -294,6 +323,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def audio_files_size_in_mb(self) -> int:
|
||||
"""Max size of audio file in mega bytes."""
|
||||
return self.info["audio_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
@@ -435,11 +469,27 @@ class LeRobotDatasetMetadata:
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_audio_info(self, audio_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
if audio_key is not None and audio_key not in self.audio_keys:
|
||||
raise ValueError(f"Audio key {audio_key} not found in dataset")
|
||||
|
||||
audio_keys = [audio_key] if audio_key is not None else self.audio_keys
|
||||
for key in audio_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
audio_path = self.root / self.audio_path.format(audio_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
||||
self.info["features"][key]["info"]["start_time_s"] = DEFAULT_INITIAL_AUDIO_BUFFER_DURATION
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> None:
|
||||
"""Update chunk and file size settings after dataset creation.
|
||||
|
||||
@@ -451,6 +501,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
||||
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
||||
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB. If None, keeps current value.
|
||||
"""
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
@@ -467,6 +518,11 @@ class LeRobotDatasetMetadata:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
if audio_files_size_in_mb is not None:
|
||||
if audio_files_size_in_mb <= 0:
|
||||
raise ValueError(f"audio_files_size_in_mb must be positive, got {audio_files_size_in_mb}")
|
||||
self.info["audio_files_size_in_mb"] = audio_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
|
||||
@@ -474,12 +530,13 @@ class LeRobotDatasetMetadata:
|
||||
"""Get current chunk and file size settings.
|
||||
|
||||
Returns:
|
||||
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
||||
Dict containing chunks_size, data_files_size_in_mb, video_files_size_in_mb, and audio_files_size_in_mb.
|
||||
"""
|
||||
return {
|
||||
"chunks_size": self.chunks_size,
|
||||
"data_files_size_in_mb": self.data_files_size_in_mb,
|
||||
"video_files_size_in_mb": self.video_files_size_in_mb,
|
||||
"audio_files_size_in_mb": self.audio_files_size_in_mb,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
@@ -506,6 +563,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
@@ -529,6 +587,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
audio_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
@@ -564,7 +623,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
download_audio: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
):
|
||||
@@ -598,6 +659,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
task-conditioned training.
|
||||
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
||||
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
||||
- audio (optional) from which audio is loaded to be synchronous with data from parquet files.
|
||||
|
||||
A typical LeRobotDataset looks like this from its root path:
|
||||
.
|
||||
@@ -623,19 +685,37 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ └── tasks.parquet
|
||||
└── videos
|
||||
├── observation.images.laptop
|
||||
├── videos
|
||||
│ ├── observation.images.laptop
|
||||
│ │ ├── chunk-000
|
||||
│ │ │ ├── file-000.mp4
|
||||
│ │ │ ├── file-001.mp4
|
||||
│ │ │ └── ...
|
||||
│ │ ├── chunk-001
|
||||
│ │ │ └── ...
|
||||
│ │ └── ...
|
||||
│ ├── observation.images.phone
|
||||
│ │ ├── chunk-000
|
||||
│ │ │ ├── file-000.mp4
|
||||
│ │ │ ├── file-001.mp4
|
||||
│ │ │ └── ...
|
||||
│ │ ├── chunk-001
|
||||
│ │ │ └── ...
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
└── audio
|
||||
├── observation.audio.laptop
|
||||
│ ├── chunk-000
|
||||
│ │ ├── file-000.mp4
|
||||
│ │ ├── file-001.mp4
|
||||
│ │ ├── file-000.m4a
|
||||
│ │ ├── file-001.m4a
|
||||
│ │ └── ...
|
||||
│ ├── chunk-001
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── observation.images.phone
|
||||
├── observation.audio.phone
|
||||
│ ├── chunk-000
|
||||
│ │ ├── file-000.mp4
|
||||
│ │ ├── file-001.mp4
|
||||
│ │ ├── file-000.m4a
|
||||
│ │ ├── file-001.m4a
|
||||
│ │ └── ...
|
||||
│ ├── chunk-001
|
||||
│ │ └── ...
|
||||
@@ -675,8 +755,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
download_audio (bool, optional): Flag to download the audio. Defaults to True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchcodec'.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
@@ -694,6 +776,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.audio_backend = (
|
||||
audio_backend if audio_backend else "torchcodec"
|
||||
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
@@ -766,6 +851,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
push_audio: bool = True,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
@@ -774,6 +860,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
if not push_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
@@ -828,7 +916,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
def download(self, download_videos: bool = True) -> None:
|
||||
def download(self, download_videos: bool = True, download_audio: bool = True) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||
@@ -836,8 +924,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
files = None
|
||||
ignore_patterns = []
|
||||
if not download_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
if not download_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
@@ -852,6 +944,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
audio_files = [
|
||||
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
|
||||
for audio_key in self.meta.audio_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += audio_files
|
||||
|
||||
# episodes are stored in the same files, so we return unique paths only
|
||||
fpaths = list(set(fpaths))
|
||||
return fpaths
|
||||
@@ -864,7 +965,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return hf_dataset
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
"""Check if the cached dataset contains all requested episodes and their video and audio files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
return False
|
||||
|
||||
@@ -892,6 +993,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if not video_path.exists():
|
||||
return False
|
||||
|
||||
# Check if all required audio files exist
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
for ep_idx in requested_episodes:
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
|
||||
if not audio_path.exists():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
@@ -935,17 +1044,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
def _get_query_indices(
|
||||
self, abs_idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
|
||||
"""Compute query indices for delta timestamps.
|
||||
|
||||
Args:
|
||||
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
|
||||
ep_idx: The episode index.
|
||||
|
||||
Returns:
|
||||
A tuple of (query_indices, padding) where:
|
||||
- query_indices: Dict mapping keys to lists of absolute indices to query
|
||||
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
|
||||
"""
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -957,7 +1079,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self.meta.video_keys:
|
||||
for key in self.meta.video_keys + self.meta.audio_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
if self._absolute_to_relative_idx is not None:
|
||||
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
|
||||
@@ -972,7 +1094,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
"""
|
||||
Query dataset for indices across keys, skipping video keys.
|
||||
Query dataset for indices across keys, skipping video keys and audio keys.
|
||||
|
||||
Tries column-first [key][indices] for speed, falls back to row-first.
|
||||
|
||||
@@ -984,7 +1106,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
result: dict = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
if key in self.meta.video_keys or key in self.meta.audio_keys:
|
||||
continue
|
||||
# Map absolute indices to relative indices if needed
|
||||
relative_indices = (
|
||||
@@ -1019,6 +1141,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return item
|
||||
|
||||
# TODO(CarolinePascal): add variable query durations
|
||||
def _query_audio(
|
||||
self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
|
||||
) -> dict[str, torch.Tensor]:
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
item = {}
|
||||
for audio_key, query_ts in query_timestamps.items():
|
||||
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
|
||||
# Thus we load the start timestamp of the episode on this mp4 and,
|
||||
# shift the query timestamp accordingly.
|
||||
from_timestamp = ep[f"audio/{audio_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
|
||||
start_time_s = self.meta.features[audio_key]["info"].get("start_time_s", 0.0)
|
||||
audio_chunk = decode_audio(
|
||||
audio_path, shifted_query_ts, query_duration, start_time_s, self.audio_backend
|
||||
)
|
||||
item[audio_key] = audio_chunk.squeeze(0)
|
||||
|
||||
return item
|
||||
|
||||
def _ensure_hf_dataset_loaded(self):
|
||||
"""Lazy load the HF dataset only when needed for reading."""
|
||||
if self._lazy_loading or self.hf_dataset is None:
|
||||
@@ -1037,20 +1181,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
# Use the absolute index from the dataset for delta timestamp calculations
|
||||
abs_idx = item["index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
|
||||
item = {**item, **video_frames, **audio_chunks}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
@@ -1098,6 +1245,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
||||
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
@@ -1150,11 +1301,43 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
elif self.features[key]["dtype"] == "audio":
|
||||
if (
|
||||
self.meta.robot_type == "lekiwi"
|
||||
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
else: # Otherwise, only the audio file path is stored in the episode buffer
|
||||
if frame_index == 0:
|
||||
audio_path = self._get_raw_audio_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
||||
)
|
||||
self.episode_buffer[key].append(str(audio_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def add_microphone_recording(self, microphone_key: str, microphone: Microphone) -> None:
|
||||
"""
|
||||
Starts recording audio data provided by the microphone and directly writes it in a .wav file.
|
||||
"""
|
||||
|
||||
audio_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)
|
||||
microphone.start_recording(output_file=audio_file)
|
||||
|
||||
def add_microphones_recordings(self, microphones: dict[str, Microphone]) -> None:
|
||||
"""
|
||||
Starts recording audio data provided by multiple microphones and directly writes it in appropriate .wav files.
|
||||
"""
|
||||
|
||||
output_files = []
|
||||
for microphone_key in microphones:
|
||||
output_files.append(
|
||||
self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)
|
||||
)
|
||||
|
||||
async_microphones_start_recording(microphones, output_files)
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_data: dict | None = None,
|
||||
@@ -1198,6 +1381,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif ft["dtype"] == "audio":
|
||||
if (
|
||||
self.meta.robot_type == "lekiwi"
|
||||
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
@@ -1206,9 +1395,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
has_audio_keys = len(self.meta.audio_keys) > 0
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
if (has_video_keys or has_audio_keys) and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
@@ -1245,21 +1435,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# TODO(Caroline): add parallel encoding for audio as well
|
||||
for audio_key in self.meta.audio_keys:
|
||||
ep_metadata.update(self._save_episode_audio(audio_key, episode_index))
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
|
||||
if has_video_keys and use_batched_encoding:
|
||||
if (has_video_keys or has_audio_keys) and use_batched_encoding:
|
||||
# Check if we should trigger batch encoding
|
||||
self.episodes_since_last_encoding += 1
|
||||
if self.episodes_since_last_encoding == self.batch_encoding_size:
|
||||
start_ep = self.num_episodes - self.batch_encoding_size
|
||||
end_ep = self.num_episodes
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
if has_video_keys:
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
if has_audio_keys:
|
||||
self._batch_save_episode_audio(start_ep, end_ep)
|
||||
self.episodes_since_last_encoding = 0
|
||||
|
||||
if not episode_data:
|
||||
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
|
||||
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
|
||||
self.clear_episode_buffer(
|
||||
delete_images=len(self.meta.image_keys) > 0, delete_audio=len(self.meta.audio_keys) > 0
|
||||
)
|
||||
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||
"""
|
||||
@@ -1310,7 +1509,70 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
# Save the current episode's audio metadata to the dataframe
|
||||
audio_ep_metadata = {}
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
|
||||
audio_ep_metadata.pop("episode_index")
|
||||
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
episode_df = episode_df.combine_first(video_ep_df)
|
||||
episode_df = episode_df.combine_first(audio_ep_df)
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
def _batch_save_episode_audio(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||
"""
|
||||
Batch save audio for multiple episodes.
|
||||
|
||||
Args:
|
||||
start_episode: Starting episode index (inclusive)
|
||||
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
|
||||
"""
|
||||
if end_episode is None:
|
||||
end_episode = self.num_episodes
|
||||
|
||||
logging.info(
|
||||
f"Batch encoding {self.batch_encoding_size} audio for episodes {start_episode} to {end_episode - 1}"
|
||||
)
|
||||
|
||||
chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"]
|
||||
file_idx = self.meta.episodes[start_episode]["data/file_index"]
|
||||
episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
for ep_idx in range(start_episode, end_episode):
|
||||
logging.info(f"Encoding audio for episode {ep_idx}")
|
||||
|
||||
if (
|
||||
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
|
||||
or self.meta.episodes[ep_idx]["data/file_index"] != file_idx
|
||||
):
|
||||
# The current episode is in a new chunk or file.
|
||||
# Save previous episode dataframe and update the Hugging Face dataset by reloading it.
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
# Load new episode dataframe
|
||||
chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"]
|
||||
file_idx = self.meta.episodes[ep_idx]["data/file_index"]
|
||||
episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(
|
||||
chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
# Save the current episode's video metadata to the dataframe
|
||||
audio_ep_metadata = {}
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
|
||||
audio_ep_metadata.pop("episode_index")
|
||||
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
episode_df = episode_df.combine_first(audio_ep_df)
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
@@ -1421,7 +1683,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_path = temp_path
|
||||
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
|
||||
|
||||
if (
|
||||
episode_index == 0
|
||||
@@ -1467,7 +1729,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
latest_duration_in_s = 0.0
|
||||
else:
|
||||
# Update latest video file
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
[latest_path, ep_path],
|
||||
latest_path,
|
||||
)
|
||||
@@ -1489,7 +1751,79 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
}
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
def _save_episode_audio(self, audio_key: str, episode_index: int) -> dict:
|
||||
# Encode episode audio into a temporary audio file
|
||||
ep_path = self._encode_temporary_episode_audio(audio_key, episode_index)
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")
|
||||
|
||||
if (
|
||||
episode_index == 0
|
||||
or self.meta.latest_episode is None
|
||||
or f"audio/{audio_key}/chunk_index" not in self.meta.latest_episode
|
||||
):
|
||||
# Initialize indices for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self.meta.episodes is not None and len(self.meta.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
old_chunk_idx = self.meta.episodes[-1][f"audio/{audio_key}/chunk_index"]
|
||||
old_file_idx = self.meta.episodes[-1][f"audio/{audio_key}/file_index"]
|
||||
chunk_idx, file_idx = update_chunk_file_indices(
|
||||
old_chunk_idx, old_file_idx, self.meta.chunks_size
|
||||
)
|
||||
latest_duration_in_s = 0.0
|
||||
new_path = self.root / self.meta.audio_path.format(
|
||||
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
else:
|
||||
# Retrieve information from the latest updated audio file using latest_episode
|
||||
latest_ep = self.meta.latest_episode
|
||||
chunk_idx = latest_ep[f"audio/{audio_key}/chunk_index"][0]
|
||||
file_idx = latest_ep[f"audio/{audio_key}/file_index"][0]
|
||||
|
||||
latest_path = self.root / self.meta.audio_path.format(
|
||||
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
latest_duration_in_s = latest_ep[f"audio/{audio_key}/to_timestamp"][0]
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.audio_files_size_in_mb:
|
||||
# Move temporary episode audio to a new audio file in the dataset
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
new_path = self.root / self.meta.audio_path.format(
|
||||
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
latest_duration_in_s = 0.0
|
||||
else:
|
||||
# Update latest audio file
|
||||
concatenate_media_files(
|
||||
[latest_path, ep_path],
|
||||
latest_path,
|
||||
)
|
||||
|
||||
# Remove temporary directory
|
||||
shutil.rmtree(str(ep_path.parent))
|
||||
|
||||
# Update audio info (only needed when first episode is encoded since it reads from episode 0)
|
||||
if episode_index == 0:
|
||||
self.meta.update_audio_info(audio_key)
|
||||
write_info(self.meta.info, self.meta.root) # ensure audio info always written properly
|
||||
|
||||
metadata = {
|
||||
"episode_index": episode_index,
|
||||
f"audio/{audio_key}/chunk_index": chunk_idx,
|
||||
f"audio/{audio_key}/file_index": file_idx,
|
||||
f"audio/{audio_key}/from_timestamp": latest_duration_in_s,
|
||||
f"audio/{audio_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
|
||||
}
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True, delete_audio: bool = True) -> None:
|
||||
# Clean up image files for the current episode buffer
|
||||
if delete_images:
|
||||
# Wait for the async image writer to finish
|
||||
@@ -1498,11 +1832,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for cam_key in self.meta.camera_keys:
|
||||
for cam_key in self.meta.image_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
# Clean up audio files for the current episode buffer
|
||||
if delete_audio:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
|
||||
if audio_file.is_file():
|
||||
audio_file.unlink()
|
||||
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
@@ -1539,6 +1883,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||
|
||||
def _encode_temporary_episode_audio(self, audio_key: str, episode_index: int) -> Path:
|
||||
"""
|
||||
Use ffmpeg to convert raw audio files into m4a audio files.
|
||||
Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since audio encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{audio_key}_{episode_index:03d}.m4a"
|
||||
raw_audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
|
||||
encode_audio(raw_audio_file, temp_path, overwrite=True)
|
||||
raw_audio_file.unlink()
|
||||
return temp_path
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -1552,6 +1908,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
) -> "LeRobotDataset":
|
||||
@@ -1596,6 +1953,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
obj.audio_backend = (
|
||||
audio_backend if audio_backend is not None else "torchcodec"
|
||||
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||
return obj
|
||||
|
||||
|
||||
@@ -1616,6 +1976,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
@@ -1633,6 +1994,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
audio_backend=audio_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
@@ -36,6 +36,7 @@ from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from soundfile import read
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -50,6 +51,7 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
@@ -57,13 +59,19 @@ STATS_PATH = "meta/stats.json"
|
||||
EPISODES_DIR = "meta/episodes"
|
||||
DATA_DIR = "data"
|
||||
VIDEO_DIR = "videos"
|
||||
AUDIO_DIR = "audio"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_AUDIO_PATH = AUDIO_DIR + "/{audio_key}/" + CHUNK_FILE_PATTERN + ".m4a"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
DEFAULT_RAW_AUDIO_PATH = "raw_audio/{audio_key}/episode_{episode_index:06d}.wav"
|
||||
|
||||
DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds
|
||||
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION = 1.0 # seconds
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
@@ -408,6 +416,16 @@ def load_image_as_numpy(
|
||||
return img_array
|
||||
|
||||
|
||||
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
|
||||
audio_data, _ = read(fpath, dtype="float32")
|
||||
|
||||
# Fill missing channel dimension when loading mono audio data
|
||||
if audio_data.ndim == 1:
|
||||
audio_data = np.expand_dims(audio_data, axis=1)
|
||||
|
||||
return audio_data
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
@@ -576,7 +594,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if ft["dtype"] == "video" or ft["dtype"] == "audio":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -639,7 +657,12 @@ def hw_to_dataset_features(
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
cam_fts = {
|
||||
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 3
|
||||
}
|
||||
mic_fts = {
|
||||
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 2
|
||||
}
|
||||
|
||||
if joint_fts and prefix == ACTION:
|
||||
features[prefix] = {
|
||||
@@ -662,6 +685,14 @@ def hw_to_dataset_features(
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
for key, parameters in mic_fts.items():
|
||||
features[f"{prefix}.audio.{key}"] = {
|
||||
"dtype": "audio",
|
||||
"shape": (len(parameters[1]),),
|
||||
"names": ["channels"],
|
||||
"info": {"sample_rate": parameters[0]},
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
|
||||
@@ -691,6 +722,8 @@ def build_dataset_frame(
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
elif ft["dtype"] == "audio":
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.audio.")]
|
||||
|
||||
return frame
|
||||
|
||||
@@ -724,6 +757,10 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif ft["dtype"] == "audio":
|
||||
type = FeatureType.AUDIO
|
||||
if len(shape) != 2:
|
||||
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
@@ -802,6 +839,7 @@ def create_empty_dataset_info(
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
@@ -811,6 +849,10 @@ def create_empty_dataset_info(
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
chunks_size (int | None): The maximum number of files per chunk directory.
|
||||
data_files_size_in_mb (int | None): The maximum size for data files in MB.
|
||||
video_files_size_in_mb (int | None): The maximum size for video files in MB.
|
||||
audio_files_size_in_mb (int | None): The maximum size for audio files in MB.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
@@ -824,10 +866,12 @@ def create_empty_dataset_info(
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"audio_files_size_in_mb": audio_files_size_in_mb or DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"audio_path": DEFAULT_AUDIO_PATH,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
@@ -1051,6 +1095,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "audio":
|
||||
return validate_feature_audio(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
@@ -1117,6 +1163,23 @@ def validate_feature_image_or_video(
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c = expected_shape
|
||||
if (len(actual_shape) != 2 and len(actual_shape) != 1) or actual_shape[-1] != c[
|
||||
-1
|
||||
]: # The number of frames might be different
|
||||
error_message += (
|
||||
f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
|
||||
)
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str) -> str:
|
||||
"""Validate a feature that is expected to be a string.
|
||||
|
||||
@@ -1172,12 +1235,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
)
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -59,6 +59,8 @@ from requests import HTTPError
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -79,7 +81,7 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -311,12 +313,12 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
|
||||
|
||||
# Check if adding this episode would exceed the limit
|
||||
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Size limit would be exceeded, save current accumulation WITHOUT this episode
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
@@ -352,7 +354,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
|
||||
# Write remaining videos if any
|
||||
if paths_to_cat:
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
@@ -367,8 +369,124 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def get_audio_keys(root):
|
||||
info = load_info(root)
|
||||
features = info["features"]
|
||||
audio_keys = [key for key, ft in features.items() if ft["dtype"] == "audio"]
|
||||
return audio_keys
|
||||
|
||||
|
||||
def convert_audios(root: Path, new_root: Path, audio_file_size_in_mb: int):
|
||||
logging.info(f"Converting audios from {root} to {new_root}")
|
||||
|
||||
audio_keys = get_audio_keys(root)
|
||||
if len(audio_keys) == 0:
|
||||
return None
|
||||
|
||||
audio_keys = sorted(audio_keys)
|
||||
|
||||
eps_metadata_per_mic = []
|
||||
for microphone in audio_keys:
|
||||
eps_metadata = convert_audios_of_microphone(root, new_root, microphone, audio_file_size_in_mb)
|
||||
eps_metadata_per_mic.append(eps_metadata)
|
||||
|
||||
num_eps_per_mic = [len(eps_mic_map) for eps_mic_map in eps_metadata_per_mic]
|
||||
if len(set(num_eps_per_mic)) != 1:
|
||||
raise ValueError(f"All microphones dont have same number of episodes ({num_eps_per_mic}).")
|
||||
|
||||
episodes_metadata = []
|
||||
num_microphones = len(audio_keys)
|
||||
num_episodes = num_eps_per_mic[0]
|
||||
for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert audios"):
|
||||
# Sanity check
|
||||
ep_ids = [
|
||||
eps_metadata_per_mic[mic_idx][ep_idx]["episode_index"] for mic_idx in range(num_microphones)
|
||||
]
|
||||
ep_ids += [ep_idx]
|
||||
if len(set(ep_ids)) != 1:
|
||||
raise ValueError(f"All episode indices need to match ({ep_ids}).")
|
||||
|
||||
ep_dict = {}
|
||||
for mic_idx in range(num_microphones):
|
||||
ep_dict.update(eps_metadata_per_mic[mic_idx][ep_idx])
|
||||
episodes_metadata.append(ep_dict)
|
||||
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def convert_audios_of_microphone(root: Path, new_root: Path, audio_key: str, audio_file_size_in_mb: int):
|
||||
# Access old paths to m4a
|
||||
audios_dir = root / "audio"
|
||||
ep_paths = sorted(audios_dir.glob(f"*/{audio_key}/*.m4a"))
|
||||
|
||||
ep_idx = 0
|
||||
chunk_idx = 0
|
||||
file_idx = 0
|
||||
size_in_mb = 0
|
||||
duration_in_s = 0.0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert audios of {audio_key}"):
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")
|
||||
|
||||
# Check if adding this episode would exceed the limit
|
||||
if size_in_mb + ep_size_in_mb >= audio_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Size limit would be exceeded, save current accumulation WITHOUT this episode
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_AUDIO_PATH.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
)
|
||||
|
||||
# Update episodes metadata for the file we just saved
|
||||
for i, _ in enumerate(paths_to_cat):
|
||||
past_ep_idx = ep_idx - len(paths_to_cat) + i
|
||||
episodes_metadata[past_ep_idx][f"audio/{audio_key}/chunk_index"] = chunk_idx
|
||||
episodes_metadata[past_ep_idx][f"audio/{audio_key}/file_index"] = file_idx
|
||||
|
||||
# Move to next file and start fresh with current episode
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
size_in_mb = 0
|
||||
duration_in_s = 0.0
|
||||
paths_to_cat = []
|
||||
|
||||
# Add current episode metadata
|
||||
ep_metadata = {
|
||||
"episode_index": ep_idx,
|
||||
f"audio/{audio_key}/chunk_index": chunk_idx, # Will be updated when file is saved
|
||||
f"audio/{audio_key}/file_index": file_idx, # Will be updated when file is saved
|
||||
f"audio/{audio_key}/from_timestamp": duration_in_s,
|
||||
f"audio/{audio_key}/to_timestamp": duration_in_s + ep_duration_in_s,
|
||||
}
|
||||
episodes_metadata.append(ep_metadata)
|
||||
|
||||
# Add current episode to accumulation
|
||||
paths_to_cat.append(ep_path)
|
||||
size_in_mb += ep_size_in_mb
|
||||
duration_in_s += ep_duration_in_s
|
||||
ep_idx += 1
|
||||
|
||||
# Write remaining videos if any
|
||||
if paths_to_cat:
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_AUDIO_PATH.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
)
|
||||
|
||||
# Update episodes metadata for the final file
|
||||
for i, _ in enumerate(paths_to_cat):
|
||||
past_ep_idx = ep_idx - len(paths_to_cat) + i
|
||||
episodes_metadata[past_ep_idx][f"audio/{audio_key}/chunk_index"] = chunk_idx
|
||||
episodes_metadata[past_ep_idx][f"audio/{audio_key}/file_index"] = file_idx
|
||||
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None, episodes_audios=None
|
||||
):
|
||||
num_episodes = len(episodes_metadata)
|
||||
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
|
||||
@@ -392,16 +510,30 @@ def generate_episode_metadata_dict(
|
||||
ep_video = episodes_videos[i]
|
||||
ep_ids_set.add(ep_video["episode_index"])
|
||||
|
||||
if episodes_audios is None:
|
||||
ep_audio = {}
|
||||
else:
|
||||
ep_audio = episodes_audios[i]
|
||||
ep_ids_set.add(ep_audio["episode_index"])
|
||||
|
||||
if len(ep_ids_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
|
||||
|
||||
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
|
||||
ep_dict = {
|
||||
**ep_metadata,
|
||||
**ep_video,
|
||||
**ep_audio,
|
||||
**ep_legacy_metadata,
|
||||
**flatten_dict({"stats": ep_stats}),
|
||||
}
|
||||
ep_dict["meta/episodes/chunk_index"] = 0
|
||||
ep_dict["meta/episodes/file_index"] = 0
|
||||
yield ep_dict
|
||||
|
||||
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
|
||||
def convert_episodes_metadata(
|
||||
root, new_root, episodes_metadata, episodes_video_metadata=None, episodes_audio_metadata=None
|
||||
):
|
||||
logging.info(f"Converting episodes metadata from {root} to {new_root}")
|
||||
|
||||
episodes_legacy_metadata = legacy_load_episodes(root)
|
||||
@@ -410,13 +542,19 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
|
||||
if episodes_video_metadata is not None:
|
||||
num_eps_set.add(len(episodes_video_metadata))
|
||||
if episodes_audio_metadata is not None:
|
||||
num_eps_set.add(len(episodes_audio_metadata))
|
||||
|
||||
if len(num_eps_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
|
||||
|
||||
ds_episodes = Dataset.from_generator(
|
||||
lambda: generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
|
||||
episodes_legacy_metadata,
|
||||
episodes_metadata,
|
||||
episodes_stats,
|
||||
episodes_video_metadata,
|
||||
episodes_audio_metadata,
|
||||
)
|
||||
)
|
||||
write_episodes(ds_episodes, new_root)
|
||||
@@ -425,20 +563,22 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
write_stats(stats, new_root)
|
||||
|
||||
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb, audio_file_size_in_mb):
|
||||
info = load_info(root)
|
||||
info["codebase_version"] = V30
|
||||
del info["total_chunks"]
|
||||
del info["total_videos"]
|
||||
info["data_files_size_in_mb"] = data_file_size_in_mb
|
||||
info["video_files_size_in_mb"] = video_file_size_in_mb
|
||||
info["audio_files_size_in_mb"] = audio_file_size_in_mb
|
||||
info["data_path"] = DEFAULT_DATA_PATH
|
||||
info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
|
||||
info["audio_path"] = DEFAULT_AUDIO_PATH if info["audio_path"] is not None else None
|
||||
info["fps"] = int(info["fps"])
|
||||
logging.info(f"Converting info from {root} to {new_root}")
|
||||
for key in info["features"]:
|
||||
if info["features"][key]["dtype"] == "video":
|
||||
# already has fps in video_info
|
||||
if info["features"][key]["dtype"] == "video" or info["features"][key]["dtype"] == "audio":
|
||||
# already has fps in video_info or audio_info
|
||||
continue
|
||||
info["features"][key]["fps"] = info["fps"]
|
||||
write_info(info, new_root)
|
||||
@@ -449,6 +589,7 @@ def convert_dataset(
|
||||
branch: str | None = None,
|
||||
data_file_size_in_mb: int | None = None,
|
||||
video_file_size_in_mb: int | None = None,
|
||||
audio_file_size_in_mb: int | None = None,
|
||||
root: str | Path | None = None,
|
||||
push_to_hub: bool = True,
|
||||
force_conversion: bool = False,
|
||||
@@ -457,6 +598,8 @@ def convert_dataset(
|
||||
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_file_size_in_mb is None:
|
||||
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if audio_file_size_in_mb is None:
|
||||
audio_file_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
|
||||
|
||||
# First check if the dataset already has a v3.0 version
|
||||
if root is None and not force_conversion:
|
||||
@@ -498,7 +641,10 @@ def convert_dataset(
|
||||
convert_tasks(root, new_root)
|
||||
episodes_metadata = convert_data(root, new_root, data_file_size_in_mb)
|
||||
episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb)
|
||||
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
|
||||
episodes_audios_metadata = convert_audios(root, new_root, audio_file_size_in_mb)
|
||||
convert_episodes_metadata(
|
||||
root, new_root, episodes_metadata, episodes_videos_metadata, episodes_audios_metadata
|
||||
)
|
||||
|
||||
shutil.move(str(root), str(old_root))
|
||||
shutil.move(str(new_root), str(root))
|
||||
@@ -511,7 +657,7 @@ def convert_dataset(
|
||||
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.delete_files(
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*", "audio/chunk*"],
|
||||
repo_id=repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
@@ -549,6 +695,12 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="File size in MB. Defaults to 100 for data and 500 for videos.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio-file-size-in-mb",
|
||||
type=int,
|
||||
default=None,
|
||||
help="File size in MB. Defaults to 100 for audio.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
|
||||
@@ -397,42 +397,42 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
def concatenate_media_files(
|
||||
input_media_paths: list[Path | str], output_media_path: Path, overwrite: bool = True
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
Concatenate multiple media files (video & audio) into a single media file using pyav.
|
||||
|
||||
This function takes a list of video input file paths and concatenates them into a single
|
||||
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
||||
This function takes a list of input media file paths and concatenates them into a single
|
||||
output media file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
||||
concatenation without re-encoding.
|
||||
|
||||
Args:
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
input_media_paths: Ordered list of input media file paths to concatenate.
|
||||
output_media_path: Path to the output media file.
|
||||
overwrite: Whether to overwrite the output media file if it already exists. Default is True.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
|
||||
- Creates a temporary .ffconcat file and container audio/video file that are cleaned up after use.
|
||||
- Uses ffmpeg's concat demuxer which requires all input media files to have the same
|
||||
codec, resolution, and frame rate for proper concatenation.
|
||||
"""
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
output_media_path = Path(output_media_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
if output_media_path.exists() and not overwrite:
|
||||
logging.warning(f"Media file already exists: {output_media_path}. Skipping concatenation.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_media_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
if len(input_media_paths) == 0:
|
||||
raise FileNotFoundError("No input media paths provided.")
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
# Create a temporary .ffconcat file to list the input media paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
for input_path in input_video_paths:
|
||||
for input_path in input_media_paths:
|
||||
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
|
||||
tmp_concatenate_file.flush()
|
||||
tmp_concatenate_path = tmp_concatenate_file.name
|
||||
@@ -442,11 +442,12 @@ def concatenate_video_files(
|
||||
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
||||
) # safe = 0 allows absolute paths as well as relative paths
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
# Using an intermediate container to store the concatenated media file is necessary to avoid inplace concatenation read-write race conditions.
|
||||
with tempfile.NamedTemporaryFile(suffix=output_media_path.suffix, delete=False) as tmp_named_file:
|
||||
tmp_output_media_path = tmp_named_file.name
|
||||
|
||||
output_container = av.open(
|
||||
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
||||
tmp_output_media_path, mode="w", options={"movflags": "faststart"}
|
||||
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
|
||||
# Replicate input streams in output container
|
||||
@@ -461,6 +462,7 @@ def concatenate_video_files(
|
||||
stream_map[input_stream.index].time_base = input_stream.time_base
|
||||
|
||||
# Demux + remux packets (no re-encode)
|
||||
last_dts = None
|
||||
for packet in input_container.demux():
|
||||
# Skip packets from un-mapped streams
|
||||
if packet.stream.index not in stream_map:
|
||||
@@ -469,6 +471,16 @@ def concatenate_video_files(
|
||||
# Skip demux flushing packets
|
||||
if packet.dts is None:
|
||||
continue
|
||||
else:
|
||||
# Enforce strictly increasing decoding timestamps (DTS)
|
||||
if last_dts is not None and packet.dts <= last_dts:
|
||||
shift = last_dts - packet.dts + 1
|
||||
packet.dts += shift
|
||||
packet.pts += shift # Presenting timestamps (PTS) are the same as DTS here
|
||||
logging.warning(
|
||||
f"Non-monotonic DTS; previous: {last_dts}, current: {packet.dts - shift}; changing to {packet.dts}. This may result in incorrect timestamps in the output file."
|
||||
)
|
||||
last_dts = packet.dts
|
||||
|
||||
output_stream = stream_map[packet.stream.index]
|
||||
packet.stream = output_stream
|
||||
@@ -476,7 +488,7 @@ def concatenate_video_files(
|
||||
|
||||
input_container.close()
|
||||
output_container.close()
|
||||
shutil.move(tmp_output_video_path, output_video_path)
|
||||
shutil.move(tmp_output_media_path, output_media_path)
|
||||
Path(tmp_concatenate_path).unlink()
|
||||
|
||||
|
||||
@@ -512,38 +524,6 @@ with warnings.catch_warnings():
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
|
||||
# Getting audio stream information
|
||||
audio_info = {}
|
||||
with av.open(str(video_path), "r") as audio_file:
|
||||
try:
|
||||
audio_stream = audio_file.streams.audio[0]
|
||||
except IndexError:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
return {"has_audio": False}
|
||||
|
||||
audio_info["audio.channels"] = audio_stream.channels
|
||||
audio_info["audio.codec"] = audio_stream.codec.canonical_name
|
||||
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
|
||||
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
|
||||
audio_info["audio.bit_rate"] = audio_stream.bit_rate
|
||||
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
|
||||
# In an ideal loseless case : fixed number of bits per sample.
|
||||
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
|
||||
audio_info["audio.bit_depth"] = audio_stream.format.bits
|
||||
audio_info["audio.channel_layout"] = audio_stream.layout.name
|
||||
audio_info["has_audio"] = True
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
return audio_info
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
@@ -573,9 +553,6 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
@@ -590,22 +567,22 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
Get the duration of a media file (video & audio) in seconds using PyAV.
|
||||
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
media_path: Path to the media file.
|
||||
|
||||
Returns:
|
||||
Duration of the video in seconds.
|
||||
Duration of the media file in seconds.
|
||||
"""
|
||||
with av.open(str(video_path)) as container:
|
||||
# Get the first video stream
|
||||
video_stream = container.streams.video[0]
|
||||
with av.open(str(media_path)) as container:
|
||||
# Get the first stream
|
||||
stream = container.streams.video[0] if media_type == "video" else container.streams.audio[0]
|
||||
# Calculate duration: stream.duration * stream.time_base gives duration in seconds
|
||||
if video_stream.duration is not None:
|
||||
duration = float(video_stream.duration * video_stream.time_base)
|
||||
if stream.duration is not None:
|
||||
duration = float(stream.duration * stream.time_base)
|
||||
else:
|
||||
# Fallback to container duration if stream duration is not available
|
||||
duration = float(container.duration / av.time_base)
|
||||
@@ -614,12 +591,12 @@ def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
|
||||
class VideoEncodingManager:
|
||||
"""
|
||||
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
|
||||
Context manager that ensures proper video and audio encoding and data cleanup even if exceptions occur.
|
||||
|
||||
This manager handles:
|
||||
- Batch encoding for any remaining episodes when recording interrupted
|
||||
- Cleaning up temporary image files from interrupted episodes
|
||||
- Removing empty image directories
|
||||
- Cleaning up temporary image and audio files from interrupted episodes
|
||||
- Removing empty image and audio directories
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset instance
|
||||
@@ -646,6 +623,7 @@ class VideoEncodingManager:
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
self.dataset._batch_save_episode_video(start_ep, end_ep)
|
||||
self.dataset._batch_save_episode_audio(start_ep, end_ep)
|
||||
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
@@ -662,6 +640,15 @@ class VideoEncodingManager:
|
||||
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
for key in self.dataset.meta.audio_keys:
|
||||
audio_file = self.dataset._get_raw_audio_file_path(
|
||||
episode_index=interrupted_episode_index, audio_key=key
|
||||
)
|
||||
if audio_file.exists():
|
||||
logging.debug(
|
||||
f"Cleaning up interrupted episode audio for episode {interrupted_episode_index}, microphone {key}"
|
||||
)
|
||||
audio_file.unlink()
|
||||
|
||||
# Clean up any remaining images directory if it's empty
|
||||
img_dir = self.dataset.root / "images"
|
||||
@@ -675,4 +662,16 @@ class VideoEncodingManager:
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
# Clean up any remaining audio directory if it's empty
|
||||
audio_dir = self.dataset.root / "raw_audio"
|
||||
# Check for any remaining WAV files
|
||||
wav_files = list(audio_dir.rglob("*.wav"))
|
||||
if len(wav_files) == 0:
|
||||
# Only remove the raw_audio directory if no WAV files remain
|
||||
if audio_dir.exists():
|
||||
shutil.rmtree(audio_dir)
|
||||
logging.debug("Cleaned up empty audio directory")
|
||||
else:
|
||||
logging.debug(f"Audio directory is not empty, containing {len(wav_files)} WAV files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -293,9 +293,9 @@ class LiberoEnv(gym.Env):
|
||||
def reset(self, seed=None, **kwargs):
|
||||
super().reset(seed=seed)
|
||||
self._env.seed(seed)
|
||||
if self.init_states and self._init_states is not None:
|
||||
self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
raw_obs = self._env.reset()
|
||||
if self.init_states and self._init_states is not None:
|
||||
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
|
||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configs import MicrophoneConfig
|
||||
from .microphone import Microphone
|
||||
from .utils import make_microphones_from_configs
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_anyskin import AnyskinSensorConfig
|
||||
from .sensor_anyskin import AnyskinSensor
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import MicrophoneConfig
|
||||
|
||||
|
||||
@MicrophoneConfig.register_subclass("anyskin")
|
||||
@dataclass
|
||||
class AnyskinSensorConfig(MicrophoneConfig):
|
||||
"""Configuration class for Anyskin tactile sensors (technically not a microphone, but behaves like one acquisition-wise).
|
||||
|
||||
This class provides configuration options for Anyskin tactile sensors, including serial port, sample rate and channels.
|
||||
|
||||
Example configurations:
|
||||
```python
|
||||
# Basic configurations
|
||||
AnyskinSensorConfig("/dev/ttyACM0", 16000) # Serial port /dev/ttyACM0, 16000Hz
|
||||
AnyskinSensorConfig("/dev/ttyACM1", 44100) # Serial port /dev/ttyACM1, 44100Hz
|
||||
```
|
||||
|
||||
Attributes:
|
||||
sensor_port: Serial port of the tactile sensor.
|
||||
baud_rate: Baud rate of the tactile sensor.
|
||||
sample_rate: Sample rate in Hz for the tactile sensor.
|
||||
channels: List of channel numbers to use for the tactile sensor.
|
||||
"""
|
||||
|
||||
sensor_port: str
|
||||
baud_rate: int = 115_200
|
||||
sensor_id: int = 0
|
||||
burst_mode: bool = True
|
||||
temp_filtered: bool = False
|
||||
@@ -0,0 +1,473 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the AnyskinSensor class for capturing tactile data from Anyskin tactile sensors.
|
||||
"""
|
||||
|
||||
from doctest import master
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import (
|
||||
Event as process_Event,
|
||||
JoinableQueue as process_Queue,
|
||||
Process,
|
||||
)
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.hub import T
|
||||
import numpy as np
|
||||
from serial import Serial, serialutil
|
||||
from soundfile import SoundFile
|
||||
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
from ..microphone import Microphone
|
||||
from .configuration_anyskin import AnyskinSensorConfig
|
||||
|
||||
from anyskin import AnySkinBase, AnySkinDummy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MAGNETS_CHANNELS = 5
|
||||
|
||||
class AnyskinSensor(Microphone):
|
||||
"""
|
||||
The AnyskinSensor class handles all Anyskin tactile sensors.
|
||||
|
||||
A AnyskinSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
from lerobot.common.robot_devices.microphones.configs import AnyskinSensorConfig
|
||||
|
||||
config = AnyskinSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
|
||||
microphone = AnyskinSensor(config)
|
||||
|
||||
microphone.connect()
|
||||
microphone.start_recording("some/output/file.wav")
|
||||
...
|
||||
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
|
||||
...
|
||||
microphone.stop_recording()
|
||||
microphone.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: AnyskinSensorConfig):
|
||||
""" "
|
||||
Initializes the AnyskinSensor instance.
|
||||
|
||||
Args:
|
||||
config: The configuration settings for the sensor.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# Sensor port
|
||||
self.sensor_port = config.sensor_port
|
||||
|
||||
# Baud rate
|
||||
self.baud_rate = config.baud_rate
|
||||
|
||||
# Input audio recording process and events
|
||||
self.record_process = None
|
||||
self.record_stop_event = process_Event()
|
||||
self.record_start_event = process_Event()
|
||||
self.record_close_event = process_Event()
|
||||
self.record_is_started_event = process_Event()
|
||||
self.audio_callback_start_event = process_Event()
|
||||
|
||||
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
|
||||
self.write_queue = process_Queue()
|
||||
|
||||
# SharedArray to store audio from the recording process.
|
||||
self.read_shared_array = None
|
||||
self.local_read_shared_array = None
|
||||
# Thread/Process to handle data writing in a separate thread/process (safely)
|
||||
self.write_thread = None
|
||||
self.write_stop_event = None
|
||||
self.write_is_started_event = None
|
||||
|
||||
self.logs = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.sensor_port})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the sensor is currently connected.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is connected and ready to start recording,
|
||||
False otherwise.
|
||||
"""
|
||||
return self.record_process is not None and self.record_process.is_alive()
|
||||
|
||||
@property
|
||||
def is_recording(self) -> bool:
|
||||
"""Check if the sensor is currently recording.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is recording, False otherwise.
|
||||
"""
|
||||
return self.record_is_started_event.is_set()
|
||||
|
||||
@property
|
||||
def is_writing(self) -> bool:
|
||||
"""Check if the sensor is currently writing to a file.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is writing to a file, False otherwise.
|
||||
"""
|
||||
return self.write_thread is not None and self.write_is_started_event.is_set()
|
||||
|
||||
@staticmethod
|
||||
def find_microphones() -> list[dict[str, Any]]:
|
||||
"""Detects available sensors connected to the system.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains information about a detected sensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def connect(self) -> None:
|
||||
"""
|
||||
Establish connection to the sensor.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")
|
||||
|
||||
# Create or reset queue and shared array
|
||||
self.read_shared_array = SharedArray(
|
||||
shape=(self.sample_rate * 10, len(self.channels)),
|
||||
dtype=np.dtype("int16"),
|
||||
)
|
||||
self.local_read_shared_array = self.read_shared_array.get_local_array()
|
||||
self.write_queue = process_Queue()
|
||||
|
||||
# Reset events
|
||||
self.record_start_event.clear()
|
||||
self.record_stop_event.clear()
|
||||
self.record_close_event.clear()
|
||||
self.record_is_started_event.clear()
|
||||
self.audio_callback_start_event.clear()
|
||||
|
||||
# Create and start an audio input stream with a recording callback
|
||||
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the busy_wait function.
|
||||
process_init_event = process_Event()
|
||||
self.record_process = Process(
|
||||
target=self._record_process,
|
||||
args=(
|
||||
self.sensor_port,
|
||||
self.baud_rate,
|
||||
self.channels,
|
||||
process_init_event,
|
||||
self.record_start_event,
|
||||
self.record_stop_event,
|
||||
self.record_close_event,
|
||||
self.record_is_started_event,
|
||||
self.audio_callback_start_event,
|
||||
self.write_queue,
|
||||
self.read_shared_array,
|
||||
),
|
||||
)
|
||||
self.record_process.daemon = True
|
||||
self.record_process.start()
|
||||
|
||||
is_init = process_init_event.wait(
|
||||
timeout=5.0
|
||||
) # Wait for the recording process to be started, and to potentially raise an error on failure.
|
||||
if not self.is_connected or not is_init:
|
||||
raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
|
||||
def _record_process(
|
||||
sensor_port,
|
||||
baud_rate,
|
||||
channels,
|
||||
process_init_event,
|
||||
record_start_event,
|
||||
record_stop_event,
|
||||
record_close_event,
|
||||
record_is_started_event,
|
||||
audio_callback_start_event,
|
||||
write_queue,
|
||||
read_shared_array,
|
||||
) -> None:
|
||||
channels_index = np.array(channels) - 1
|
||||
local_read_shared_array = read_shared_array.get_local_array()
|
||||
|
||||
def tactile_callback(tactile_sensor: AnySkinBase):
|
||||
"""
|
||||
Parse the tactile data from the raw input data.
|
||||
"""
|
||||
if audio_callback_start_event.is_set():
|
||||
timestamp, indata = tactile_sensor.get_sample()
|
||||
indata = indata.reshape(-1, MAX_MAGNETS_CHANNELS)
|
||||
write_queue.put_nowait(indata[:, channels_index])
|
||||
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
|
||||
|
||||
try:
|
||||
tactile_sensor = AnySkinBase(
|
||||
num_mags=MAX_MAGNETS_CHANNELS,
|
||||
port=sensor_port,
|
||||
baudrate=baud_rate,
|
||||
burst_mode=True,
|
||||
device_id=0, #TODO(CarolinePascal): create an abstract increasing id for each sensor
|
||||
temp_filtered=False,
|
||||
) #TODO(CarolinePascal): add timeout on serial connection ?
|
||||
except (serialutil.SerialException, AttributeError) as e:
|
||||
raise RuntimeError(f"Error connecting sensor connected to {sensor_port}: {e}")
|
||||
|
||||
process_init_event.set()
|
||||
|
||||
while True:
|
||||
start_flag = record_start_event.wait(timeout=0.1)
|
||||
if record_close_event.is_set():
|
||||
break
|
||||
elif not start_flag:
|
||||
continue
|
||||
record_is_started_event.set()
|
||||
while not record_stop_event.is_set():
|
||||
tactile_callback(tactile_sensor) # Initial flush is already done in the constructor.
|
||||
record_is_started_event.clear()
|
||||
tactile_sensor.close() # Closes the inherited serial connection.
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect the sensor and release any resources.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
|
||||
if self.is_recording:
|
||||
self.stop_recording()
|
||||
|
||||
self.record_close_event.set()
|
||||
self.read_shared_array.delete()
|
||||
self.write_queue.close()
|
||||
self.record_process.join()
|
||||
|
||||
if self.is_connected:
|
||||
raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
def start_recording(
|
||||
self,
|
||||
output_file: str | Path | None = None,
|
||||
multiprocessing: bool | None = False,
|
||||
overwrite: bool | None = True,
|
||||
barrier: Barrier | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Start recording tactile data from the sensor.
|
||||
|
||||
Args:
|
||||
output_file: Optional path to save the recorded tactile data.
|
||||
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
|
||||
overwrite: If True, overwrites existing files at output_file path.
|
||||
barrier: If not None, ensures that multiple sensors start recording at the same time.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if self.is_recording:
|
||||
raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")
|
||||
|
||||
# Reset queue and shared memory
|
||||
self.read_shared_array.reset()
|
||||
self._clear_queue(self.write_queue)
|
||||
|
||||
# Reset stop event
|
||||
self.record_stop_event.clear()
|
||||
|
||||
# Write recordings into a file if output_file is provided
|
||||
if output_file is not None:
|
||||
output_file = Path(output_file)
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if output_file.exists():
|
||||
if overwrite:
|
||||
output_file.unlink()
|
||||
else:
|
||||
raise FileExistsError(
|
||||
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
|
||||
)
|
||||
|
||||
if multiprocessing:
|
||||
self.write_stop_event = process_Event()
|
||||
self.write_is_started_event = process_Event()
|
||||
self.write_thread = Process(
|
||||
target=AnyskinSensor._write_loop,
|
||||
args=(
|
||||
self.write_queue,
|
||||
self.write_stop_event,
|
||||
self.write_is_started_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.write_stop_event = thread_Event()
|
||||
self.write_is_started_event = thread_Event()
|
||||
self.write_thread = Thread(
|
||||
target=AnyskinSensor._write_loop,
|
||||
args=(
|
||||
self.write_queue,
|
||||
self.write_stop_event,
|
||||
self.write_is_started_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
self.write_thread.daemon = True
|
||||
self.write_thread.start()
|
||||
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
|
||||
|
||||
self.record_start_event.set() # Start the input audio stream process
|
||||
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
|
||||
|
||||
if barrier is not None:
|
||||
barrier.wait() # Wait for multiple input audio streams to be started at the same time
|
||||
|
||||
self.audio_callback_start_event.set()
|
||||
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
|
||||
if output_file is not None and not self.is_writing:
|
||||
raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")
|
||||
|
||||
def _read(self) -> np.ndarray:
|
||||
"""
|
||||
Thread/Process-safe callback to read available audio data
|
||||
"""
|
||||
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
"""Capture and return a single audio chunk from the sensor.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured audio chunk as a numpy array.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
tactile_readings = self._read()
|
||||
|
||||
# log the number of seconds it took to read the audio chunk
|
||||
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the audio chunk was received
|
||||
self.logs["timestamp_utc"] = time.perf_counter()
|
||||
|
||||
return tactile_readings
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
"""Internal loop run by the background thread for asynchronous reading."""
|
||||
|
||||
def stop_recording(self) -> None:
|
||||
"""Stop recording audio from the sensor."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||
|
||||
self.audio_callback_start_event.clear()
|
||||
self.record_start_event.clear() # Ensures the audio stream is not started again !
|
||||
self.record_stop_event.set()
|
||||
|
||||
self.read_shared_array.reset()
|
||||
self._clear_queue(self.write_queue, join_queue=True)
|
||||
|
||||
if self.is_writing:
|
||||
self.write_stop_event.set()
|
||||
self.write_thread.join()
|
||||
|
||||
timeout = 1.0
|
||||
while self.is_recording and timeout > 0:
|
||||
time.sleep(0.01)
|
||||
timeout -= 0.01
|
||||
|
||||
if self.is_recording:
|
||||
raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
|
||||
if self.is_writing:
|
||||
raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
|
||||
@staticmethod
|
||||
def _clear_queue(queue, join_queue: bool = False):
|
||||
"""
|
||||
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
queue.get_nowait()
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
if join_queue:
|
||||
queue.join()
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _write_loop(
|
||||
queue,
|
||||
write_stop_event: Event,
|
||||
write_is_started_event: Event,
|
||||
sample_rate: int,
|
||||
channels: list[int],
|
||||
output_file: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Thread/Process-safe loop to write audio data into a file.
|
||||
"""
|
||||
# Can only be run on a single process/thread for file writing safety
|
||||
with SoundFile(
|
||||
output_file,
|
||||
mode="w",
|
||||
samplerate=sample_rate,
|
||||
channels=len(channels),
|
||||
format="WAV",
|
||||
subtype="FLOAT", # Subtype for float32 values
|
||||
) as file:
|
||||
write_is_started_event.set()
|
||||
while not write_stop_event.is_set():
|
||||
try:
|
||||
file.write(
|
||||
queue.get(timeout=0.005)
|
||||
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
continue
|
||||
write_is_started_event.clear()
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class MicrophoneConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
sample_rate: int | None = None
|
||||
channels: list[int] | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
@@ -0,0 +1,140 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from threading import Barrier
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .configs import MicrophoneConfig
|
||||
|
||||
|
||||
class Microphone(abc.ABC):
|
||||
"""Base class for microphone implementations.
|
||||
|
||||
Defines a standard interface for microphone operations across different backends.
|
||||
Subclasses must implement all abstract methods.
|
||||
|
||||
Manages basic microphone properties (sample rate, channels) and core operations:
|
||||
- Connection/disconnection
|
||||
- Start/stop recording
|
||||
- Audio chunk reading
|
||||
|
||||
Attributes:
|
||||
sample_rate (int | None): Configured sample rate in Hz
|
||||
channels (list[int] | None): List of channel numbers to record
|
||||
|
||||
Example:
|
||||
class MyMicrophone(Microphone):
|
||||
def __init__(self, config): ...
|
||||
@property
|
||||
def is_connected(self) -> bool: ...
|
||||
def connect(self): ...
|
||||
# Plus other required methods
|
||||
"""
|
||||
|
||||
def __init__(self, config: MicrophoneConfig):
|
||||
"""Initialize the microphone with the given configuration.
|
||||
|
||||
Args:
|
||||
config: Microphone configuration containing sample rate and channels.
|
||||
"""
|
||||
self.sample_rate: int | None = config.sample_rate
|
||||
self.channels: list[int] | None = config.channels
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the microphone is currently connected.
|
||||
|
||||
Returns:
|
||||
bool: True if the microphone is connected and ready to start recording,
|
||||
False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_recording(self) -> bool:
|
||||
"""Check if the microphone is currently recording.
|
||||
|
||||
Returns:
|
||||
bool: True if the microphone is recording, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_writing(self) -> bool:
|
||||
"""Check if the microphone is currently writing to a file.
|
||||
|
||||
Returns:
|
||||
bool: True if the microphone is writing to a file, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def find_microphones() -> list[dict[str, Any]]:
|
||||
"""Detects available microphones connected to the system.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains information about a detected microphone.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self) -> None:
|
||||
"""Establish connection to the microphone."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def start_recording(
|
||||
self,
|
||||
output_file: str | Path | None = None,
|
||||
multiprocessing: bool | None = False,
|
||||
overwrite: bool | None = True,
|
||||
barrier: Barrier | None = None,
|
||||
) -> None:
|
||||
"""Start recording audio from the microphone.
|
||||
|
||||
Args:
|
||||
output_file: Optional path to save the recorded audio.
|
||||
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
|
||||
overwrite: If True, overwrites existing files at output_file path.
|
||||
barrier: If not None, ensures that multiple microphones start recording at the same time.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self) -> np.ndarray:
|
||||
"""Capture and return a single audio chunk from the microphone.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured audio chunk as a numpy array.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def stop_recording(self) -> None:
|
||||
"""Stop recording audio from the microphone."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect the microphone and release any resources."""
|
||||
pass
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_portaudio import PortAudioMicrophoneConfig
|
||||
from .microphone_portaudio import PortAudioMicrophone
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import MicrophoneConfig
|
||||
|
||||
|
||||
@MicrophoneConfig.register_subclass("portaudio")
|
||||
@dataclass
|
||||
class PortAudioMicrophoneConfig(MicrophoneConfig):
|
||||
"""Configuration class for PortAudio-based microphone devices.
|
||||
|
||||
This class provides configuration options for microphones accessed through PortAudio with the sounddevice Python package.
|
||||
including device index, sample rate and channels.
|
||||
|
||||
Example configurations:
|
||||
```python
|
||||
# Basic configurations
|
||||
PortAudioMicrophoneConfig(0, 16000, [1]) # Device index 0, 16000Hz, mono
|
||||
PortAudioMicrophoneConfig(1, 44100, [1, 2]) # Device index 1, 44100Hz, stereo
|
||||
```
|
||||
|
||||
Attributes:
|
||||
microphone_index: Device index for the microphone.
|
||||
sample_rate: Sample rate in Hz for the microphone.
|
||||
channels: List of channel numbers to use for the microphone.
|
||||
"""
|
||||
|
||||
microphone_index: int
|
||||
@@ -0,0 +1,394 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from threading import Event, Thread
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from sounddevice import PortAudioError
|
||||
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
|
||||
# --- Interface definitions for InputStream ---
|
||||
class IInputStream(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
samplerate: float | None = None,
|
||||
blocksize: int | None = None,
|
||||
device: int | str | None = None,
|
||||
channels: int | None = None,
|
||||
dtype: str | np.dtype | None = None,
|
||||
latency: float | str | None = None,
|
||||
callback: Callable[[Any, int, Any, Any], None] | None = None,
|
||||
):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def stop(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ISounddeviceSDK(abc.ABC):
|
||||
"""Interface defining the contract for the Sounddevice SDK."""
|
||||
|
||||
InputStream: type[IInputStream]
|
||||
|
||||
@abc.abstractmethod
|
||||
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
|
||||
# --- Real SDK Adapter ---
|
||||
|
||||
|
||||
class SounddeviceSDKAdapter(ISounddeviceSDK):
|
||||
"""Adapts the real sounddevice library to the ISounddeviceSDK interface."""
|
||||
|
||||
_sounddevice = None
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
import sounddevice
|
||||
|
||||
SounddeviceSDKAdapter._sounddevice = sounddevice
|
||||
except ImportError as e:
|
||||
raise ImportError("sounddevice library not found") from e
|
||||
|
||||
# --- Inner Class Implementation ---
|
||||
class RealInputStream(IInputStream):
|
||||
def __init__(
|
||||
self,
|
||||
samplerate: int | None = None,
|
||||
blocksize: int | None = None,
|
||||
device: int | None = None,
|
||||
channels: int | None = None,
|
||||
dtype: str | np.dtype | None = None,
|
||||
latency: float | str | None = None,
|
||||
callback: Callable[[Any, int, Any, Any], None] | None = None,
|
||||
):
|
||||
import sounddevice
|
||||
|
||||
self._input_stream = sounddevice.InputStream(
|
||||
samplerate=samplerate,
|
||||
blocksize=blocksize,
|
||||
device=device,
|
||||
channels=channels,
|
||||
dtype=dtype,
|
||||
latency=latency,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
self._input_stream.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._input_stream.stop()
|
||||
|
||||
def close(self) -> None:
|
||||
self._input_stream.close()
|
||||
|
||||
def __del__(self):
|
||||
self._input_stream.stop()
|
||||
self._input_stream.close()
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
return self._input_stream.active
|
||||
|
||||
@property
|
||||
def stopped(self) -> bool:
|
||||
return self._input_stream.stopped
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._input_stream.closed
|
||||
|
||||
InputStream = RealInputStream
|
||||
|
||||
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
|
||||
return SounddeviceSDKAdapter._sounddevice.query_devices(device, kind)
|
||||
|
||||
|
||||
# Emulates a 48kHz stereo microphone
|
||||
VALID_DTYPE = {
|
||||
"float32",
|
||||
"int32",
|
||||
"int16",
|
||||
"int8",
|
||||
"uint8",
|
||||
np.float32,
|
||||
np.int32,
|
||||
np.int16,
|
||||
np.int8,
|
||||
np.uint8,
|
||||
}
|
||||
VALID_LATENCY = {"low", "high"}
|
||||
|
||||
VALID_DEVICES = [
|
||||
{
|
||||
"index": 0,
|
||||
"name": "Built-in Microphone",
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 2,
|
||||
"max_output_channels": 0,
|
||||
"default_low_input_latency": 0.01,
|
||||
"default_low_output_latency": 0.001,
|
||||
"default_high_input_latency": 0.1,
|
||||
"default_high_output_latency": 0.01,
|
||||
"default_samplerate": 48000.0,
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "Built-in Output",
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 0,
|
||||
"max_output_channels": 2,
|
||||
"default_low_input_latency": 0.04,
|
||||
"default_low_output_latency": 0.04,
|
||||
"default_high_input_latency": 0.12,
|
||||
"default_high_output_latency": 0.12,
|
||||
"default_samplerate": 48000.0,
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"name": "USB Audio Device",
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 1,
|
||||
"max_output_channels": 0,
|
||||
"default_low_input_latency": 0.03,
|
||||
"default_low_output_latency": 0.01,
|
||||
"default_high_input_latency": 0.04,
|
||||
"default_high_output_latency": 0.03,
|
||||
"default_samplerate": 16000.0,
|
||||
},
|
||||
]
|
||||
|
||||
# -- Fake SDK Adapter ---
|
||||
|
||||
|
||||
class FakeSounddeviceSDKAdapter(ISounddeviceSDK):
|
||||
"""Implements the ISounddeviceSDK interface with fake behaviour for testing."""
|
||||
|
||||
# --- Inner Class Implementation ---
|
||||
class FakeInputStream(IInputStream):
|
||||
def __init__(
|
||||
self,
|
||||
samplerate: float | None = None,
|
||||
blocksize: int | None = None,
|
||||
device: int | str | None = None,
|
||||
channels: int | None = None,
|
||||
dtype: str | None = None,
|
||||
latency: str | None = None,
|
||||
callback: Callable[[Any, int, Any, Any], None] | None = None,
|
||||
):
|
||||
self.samplerate = samplerate
|
||||
self.blocksize = blocksize
|
||||
self.device = device
|
||||
self.channels = channels
|
||||
self.dtype = dtype
|
||||
self.latency = latency
|
||||
self.callback = callback
|
||||
|
||||
self._validate_settings()
|
||||
|
||||
self._active = False
|
||||
self._closed = False
|
||||
|
||||
if self.callback is not None:
|
||||
self._streaming_thread = Thread(target=self._streaming_loop, daemon=True)
|
||||
self._streaming_thread_stop_event = Event()
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
"""True when the stream is active, False otherwise."""
|
||||
return self._active
|
||||
|
||||
@property
|
||||
def stopped(self) -> bool:
|
||||
"""True when the stream is stopped, False otherwise."""
|
||||
return not self._active
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
"""True after a call to close(), False otherwise."""
|
||||
return self._closed
|
||||
|
||||
def _get_device_info(self):
|
||||
"""Returns the device info for the device."""
|
||||
for device in VALID_DEVICES:
|
||||
if (isinstance(self.device, int) and device["index"] == self.device) or (
|
||||
isinstance(self.device, str) and device["name"] == self.device
|
||||
):
|
||||
return device
|
||||
raise PortAudioError(f"No input device matching {self.device}")
|
||||
|
||||
def _validate_device(self):
|
||||
"""Validates the device against the valid devices."""
|
||||
valid_device_indices = [device["index"] for device in VALID_DEVICES]
|
||||
valid_device_names = [device["name"] for device in VALID_DEVICES]
|
||||
|
||||
if self.device is not None:
|
||||
if isinstance(self.device, (int, str)):
|
||||
# Check if device index is valid
|
||||
if isinstance(self.device, int) and self.device not in valid_device_indices:
|
||||
raise PortAudioError(f"Error querying device {self.device}")
|
||||
|
||||
# Check if device name is valid
|
||||
if isinstance(self.device, str) and self.device not in valid_device_names:
|
||||
raise PortAudioError(f"No input device matching {self.device}")
|
||||
else:
|
||||
raise PortAudioError(f"Device must be int or str, got {type(self.device)}")
|
||||
else:
|
||||
# Default to first input device
|
||||
input_devices = [d for d in VALID_DEVICES if d["max_input_channels"] > 0]
|
||||
if input_devices:
|
||||
self.device = input_devices[0]["index"]
|
||||
|
||||
def _validate_samplerate(self):
|
||||
"""Validates the samplerate against the device's maximum samplerate."""
|
||||
device_info = self._get_device_info()
|
||||
if self.samplerate is None:
|
||||
self.samplerate = device_info["default_samplerate"]
|
||||
elif self.samplerate > device_info["default_samplerate"] or self.samplerate < 1000:
|
||||
raise PortAudioError("Error opening InputStream: Invalid sample rate")
|
||||
|
||||
def _validate_channels(self):
|
||||
"""Validates the channels against the device's maximum channels."""
|
||||
device_info = self._get_device_info()
|
||||
if self.channels is None:
|
||||
self.channels = device_info["max_input_channels"]
|
||||
elif self.channels > device_info["max_input_channels"] or self.channels < 1:
|
||||
raise PortAudioError("Error opening InputStream: Invalid number of channels")
|
||||
|
||||
def _validate_dtype(self):
|
||||
"""Validates the dtype against the valid dtypes."""
|
||||
if self.dtype is not None:
|
||||
if self.dtype not in VALID_DTYPE:
|
||||
raise PortAudioError("Invalid input sample format")
|
||||
else:
|
||||
self.dtype = "float32" # Default dtype
|
||||
|
||||
def _validate_latency(self):
|
||||
"""Validates the latency against the valid latencies."""
|
||||
if self.latency is not None:
|
||||
if self.latency not in VALID_LATENCY:
|
||||
raise PortAudioError("Invalid latency")
|
||||
else:
|
||||
self.latency = "low" # Default latency
|
||||
|
||||
if isinstance(self.latency, str):
|
||||
device_info = self._get_device_info()
|
||||
if self.latency == "low":
|
||||
self.latency = device_info["default_low_input_latency"]
|
||||
elif self.latency == "high":
|
||||
self.latency = device_info["default_high_input_latency"]
|
||||
|
||||
def _validate_settings(self):
|
||||
"""Validates the input parameters against available devices and valid options."""
|
||||
self._validate_device()
|
||||
self._validate_samplerate()
|
||||
self._validate_channels()
|
||||
self._validate_dtype()
|
||||
self._validate_latency()
|
||||
|
||||
def _simulated_audio_data(self) -> np.ndarray:
|
||||
"""Generates a simulated audio signal for testing purposes with proper value ranges."""
|
||||
duration_samples = int(self.samplerate * self.latency)
|
||||
|
||||
# Generate output according to dtype
|
||||
if self.dtype in {"float32", np.float32}:
|
||||
# Generate values between -1 and 1 for float32
|
||||
data = np.random.uniform(-1.0, 1.0, (duration_samples, self.channels)).astype(self.dtype)
|
||||
else:
|
||||
# Use np.iinfo to get proper range for integer types
|
||||
info = np.iinfo(self.dtype)
|
||||
data = np.random.randint(
|
||||
info.min, info.max + 1, (duration_samples, self.channels), dtype=self.dtype
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def _streaming_loop(self):
|
||||
if self.callback is not None:
|
||||
while not self._streaming_thread_stop_event.is_set():
|
||||
precise_sleep(self.latency)
|
||||
tmp_data = self._simulated_audio_data()
|
||||
self.callback(
|
||||
tmp_data,
|
||||
len(tmp_data),
|
||||
time.perf_counter(),
|
||||
None,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the fake input stream."""
|
||||
if not self.active and self.callback is not None:
|
||||
self._streaming_thread.start()
|
||||
self._active = True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the fake input stream."""
|
||||
if self.callback is not None:
|
||||
self._streaming_thread_stop_event.set()
|
||||
self._streaming_thread.join()
|
||||
self._active = False
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the fake input stream."""
|
||||
if self.active and self.callback is not None:
|
||||
self.stop()
|
||||
self._active = False
|
||||
self._closed = True
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
InputStream = FakeInputStream
|
||||
|
||||
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Returns a realistic list of audio devices including speakers and microphones."""
|
||||
if device is not None:
|
||||
# Return specific device
|
||||
for valid_device in VALID_DEVICES:
|
||||
if (isinstance(device, int) and valid_device["index"] == device) or (
|
||||
isinstance(device, str) and valid_device["name"] == device
|
||||
):
|
||||
return valid_device
|
||||
raise PortAudioError(f"Error querying device {device}")
|
||||
|
||||
elif kind is not None:
|
||||
for valid_device in VALID_DEVICES:
|
||||
if (
|
||||
valid_device["max_input_channels"] > 0
|
||||
and kind == "input"
|
||||
or valid_device["max_output_channels"] > 0
|
||||
and kind == "output"
|
||||
):
|
||||
return valid_device
|
||||
raise PortAudioError(f"No {kind} device found")
|
||||
|
||||
return VALID_DEVICES
|
||||
@@ -0,0 +1,566 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the PortAudioMicrophone class for capturing audio from microphones using the PortAudio library through the sounddevice Python package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import (
|
||||
Event as process_Event,
|
||||
JoinableQueue as process_Queue,
|
||||
Process,
|
||||
)
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from soundfile import SoundFile
|
||||
|
||||
from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
from ..microphone import Microphone
|
||||
from .configuration_portaudio import PortAudioMicrophoneConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PortAudioMicrophone(Microphone):
|
||||
"""
|
||||
The PortAudioMicrophone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
|
||||
|
||||
A PortAudioMicrophone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
from lerobot.common.robot_devices.microphones.configs import PortAudioMicrophoneConfig
|
||||
|
||||
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
|
||||
microphone = PortAudioMicrophone(config)
|
||||
|
||||
microphone.connect()
|
||||
microphone.start_recording("some/output/file.wav")
|
||||
...
|
||||
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
|
||||
...
|
||||
microphone.stop_recording()
|
||||
microphone.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: PortAudioMicrophoneConfig, sounddevice_sdk: ISounddeviceSDK = None):
|
||||
"""
|
||||
Initializes the PortAudioMicrophone instance.
|
||||
|
||||
Args:
|
||||
config: The configuration settings for the microphone.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
if sounddevice_sdk is None:
|
||||
self.sounddevice_sdk = SounddeviceSDKAdapter()
|
||||
else:
|
||||
self.sounddevice_sdk = sounddevice_sdk
|
||||
|
||||
# Microphone index
|
||||
self.microphone_index = config.microphone_index
|
||||
|
||||
# Input audio recording process and events
|
||||
self.record_process = None
|
||||
self.record_stop_event = process_Event()
|
||||
self.record_start_event = process_Event()
|
||||
self.record_close_event = process_Event()
|
||||
self.record_is_started_event = process_Event()
|
||||
self.audio_callback_start_event = process_Event()
|
||||
|
||||
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
|
||||
self.write_queue = process_Queue()
|
||||
|
||||
# SharedArray to store audio from the recording process.
|
||||
self.read_shared_array = None
|
||||
self.local_read_shared_array = None
|
||||
# Thread/Process to handle data writing in a separate thread/process (safely)
|
||||
self.write_thread = None
|
||||
self.write_stop_event = None
|
||||
self.write_is_started_event = None
|
||||
|
||||
self.logs = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.microphone_index})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.record_process is not None and self.record_process.is_alive()
|
||||
|
||||
@property
|
||||
def is_recording(self) -> bool:
|
||||
return self.record_is_started_event.is_set()
|
||||
|
||||
@property
|
||||
def is_writing(self) -> bool:
|
||||
return self.write_thread is not None and self.write_is_started_event.is_set()
|
||||
|
||||
@staticmethod
|
||||
def find_microphones(
|
||||
device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
|
||||
) -> list[dict[str, Any]] | dict[str, Any]:
|
||||
"""
|
||||
Detects available microphones connected to the system.
|
||||
|
||||
Args:
|
||||
device: The device to find microphones for. If None, all microphones are found.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
|
||||
"""
|
||||
|
||||
if sounddevice_sdk is None:
|
||||
sounddevice_sdk = SounddeviceSDKAdapter()
|
||||
|
||||
found_microphones_info = []
|
||||
|
||||
devices = sounddevice_sdk.query_devices()
|
||||
for d in devices:
|
||||
if d["max_input_channels"] > 0:
|
||||
microphone_info = {
|
||||
"index": d["index"],
|
||||
"name": d["name"],
|
||||
"sample_rate": int(d["default_samplerate"]),
|
||||
"channels": np.arange(1, d["max_input_channels"] + 1),
|
||||
}
|
||||
|
||||
if device is None or (
|
||||
(isinstance(device, int) and d["index"] == device)
|
||||
or (isinstance(device, str) and d["name"] == device)
|
||||
):
|
||||
found_microphones_info.append(microphone_info)
|
||||
|
||||
if device is not None:
|
||||
if len(found_microphones_info) == 0:
|
||||
raise RuntimeError(f"No microphone found for device {device}")
|
||||
else:
|
||||
return found_microphones_info[0]
|
||||
|
||||
if len(found_microphones_info) == 0:
|
||||
logger.warning("No microphone found !")
|
||||
|
||||
return found_microphones_info
|
||||
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Validates the microphone index, sample rate and channels settings specified in the constructor's config to the un-connected microphone.
|
||||
|
||||
This method actually checks the specified settings and fills the sample rate and channels settings if not specified before attempting to start a PortAudio stream.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If one of the specified settings is not compatible with the microphone.
|
||||
DeviceAlreadyConnectedError: If the microphone is connected when attempting to configure settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"Cannot configure settings for {self} as it is already connected."
|
||||
)
|
||||
|
||||
self._validate_microphone_index()
|
||||
self._validate_sample_rate()
|
||||
self._validate_channels()
|
||||
|
||||
def _validate_microphone_index(self) -> None:
|
||||
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
|
||||
|
||||
try:
|
||||
PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
|
||||
) from e
|
||||
|
||||
def _validate_sample_rate(self) -> None:
|
||||
"""Validates the sample rate against the actual microphone's default sample rate."""
|
||||
|
||||
actual_sample_rate = PortAudioMicrophone.find_microphones(
|
||||
self.microphone_index, self.sounddevice_sdk
|
||||
)["sample_rate"]
|
||||
|
||||
if self.sample_rate is not None:
|
||||
try:
|
||||
self.sample_rate = int(self.sample_rate)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
|
||||
) from e
|
||||
|
||||
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
|
||||
raise RuntimeError(
|
||||
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
|
||||
)
|
||||
else:
|
||||
if self.sample_rate < actual_sample_rate:
|
||||
logger.warning(
|
||||
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||
)
|
||||
else:
|
||||
self.sample_rate = actual_sample_rate
|
||||
|
||||
def _validate_channels(self) -> None:
|
||||
"""Validates the channels against the actual microphone's maximum input channels."""
|
||||
|
||||
actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
|
||||
"channels"
|
||||
]
|
||||
|
||||
if self.channels is not None and len(self.channels) > 0:
|
||||
if not all(channel in actual_channels for channel in self.channels):
|
||||
raise RuntimeError(
|
||||
f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
|
||||
)
|
||||
else:
|
||||
self.channels = actual_channels
|
||||
|
||||
# Get channels index instead of number for slicing
|
||||
self.channels_index = np.array(self.channels) - 1
|
||||
|
||||
def connect(self) -> None:
|
||||
"""
|
||||
Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
# Create or reset queue and shared array
|
||||
self.read_shared_array = SharedArray(
|
||||
shape=(self.sample_rate * 10, len(self.channels)),
|
||||
dtype=np.dtype("float32"),
|
||||
)
|
||||
self.local_read_shared_array = self.read_shared_array.get_local_array()
|
||||
self.write_queue = process_Queue()
|
||||
|
||||
# Reset events
|
||||
self.record_start_event.clear()
|
||||
self.record_stop_event.clear()
|
||||
self.record_close_event.clear()
|
||||
self.record_is_started_event.clear()
|
||||
self.audio_callback_start_event.clear()
|
||||
|
||||
# Create and start an audio input stream with a recording callback
|
||||
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
|
||||
process_init_event = process_Event()
|
||||
self.record_process = Process(
|
||||
target=self._record_process,
|
||||
args=(
|
||||
self.microphone_index,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
process_init_event,
|
||||
self.record_start_event,
|
||||
self.record_stop_event,
|
||||
self.record_close_event,
|
||||
self.record_is_started_event,
|
||||
self.audio_callback_start_event,
|
||||
self.write_queue,
|
||||
self.read_shared_array,
|
||||
self.sounddevice_sdk,
|
||||
),
|
||||
)
|
||||
self.record_process.daemon = True
|
||||
self.record_process.start()
|
||||
|
||||
is_init = process_init_event.wait(
|
||||
timeout=5.0
|
||||
) # Wait for the recording process to be started, and to potentially raise an error on failure.
|
||||
if not self.is_connected or not is_init:
|
||||
raise RuntimeError(f"Error connecting microphone {self.microphone_index}.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects the microphone and stops the recording.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
|
||||
if self.is_recording:
|
||||
self.stop_recording()
|
||||
|
||||
self.record_close_event.set()
|
||||
self.read_shared_array.delete()
|
||||
self.write_queue.close()
|
||||
self.record_process.join()
|
||||
|
||||
if self.is_connected:
|
||||
raise RuntimeError(f"Error disconnecting microphone {self.microphone_index}.")
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
def _read(self) -> np.ndarray:
|
||||
"""
|
||||
Thread/Process-safe callback to read available audio data
|
||||
"""
|
||||
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
"""
|
||||
Reads the last audio chunk recorded by the microphone, e.g. all samples recorded since the last read or since the beginning of the recording.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Microphone {self.microphone_index} is not recording.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
audio_readings = self._read()
|
||||
|
||||
# log the number of seconds it took to read the audio chunk
|
||||
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the audio chunk was received
|
||||
self.logs["timestamp_utc"] = time.perf_counter()
|
||||
|
||||
return audio_readings
|
||||
|
||||
@staticmethod
|
||||
def _record_process(
|
||||
microphone_index,
|
||||
sample_rate,
|
||||
channels,
|
||||
process_init_event,
|
||||
record_start_event,
|
||||
record_stop_event,
|
||||
record_close_event,
|
||||
record_is_started_event,
|
||||
audio_callback_start_event,
|
||||
write_queue,
|
||||
read_shared_array,
|
||||
sounddevice_sdk,
|
||||
) -> None:
|
||||
"""
|
||||
Process callback used to create an unpickable sounddevice audio input stream with a recording callback and start, stop and close it based on multiprocessing events.
|
||||
"""
|
||||
|
||||
channels_index = np.array(channels) - 1
|
||||
local_read_shared_array = read_shared_array.get_local_array()
|
||||
|
||||
def audio_callback(indata, frames, timestamp, status) -> None:
|
||||
"""
|
||||
Low-level sounddevice callback.
|
||||
"""
|
||||
if status:
|
||||
logger.warning(status)
|
||||
if audio_callback_start_event.is_set():
|
||||
write_queue.put_nowait(indata[:, channels_index])
|
||||
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
|
||||
|
||||
# Create the audio stream
|
||||
# InputStream must be instantiated in the process as it is not pickable.
|
||||
stream = sounddevice_sdk.InputStream(
|
||||
device=microphone_index,
|
||||
samplerate=sample_rate,
|
||||
channels=max(channels),
|
||||
dtype="float32",
|
||||
blocksize=0, # Varying input buffer length, but no additional latency
|
||||
latency="low", # Low latency mode (not enabled by default !)
|
||||
# never_drop_input=True, # Disabled as it generates an error for some devices
|
||||
callback=audio_callback,
|
||||
)
|
||||
process_init_event.set()
|
||||
|
||||
while True:
|
||||
start_flag = record_start_event.wait(timeout=0.1)
|
||||
if record_close_event.is_set():
|
||||
break
|
||||
elif not start_flag:
|
||||
continue
|
||||
stream.start()
|
||||
record_is_started_event.set()
|
||||
record_stop_event.wait()
|
||||
stream.stop() # stream.stop() waits for all buffers to be processed, stream.abort() flushes the buffers !
|
||||
record_is_started_event.clear()
|
||||
stream.close()
|
||||
|
||||
def start_recording(
|
||||
self,
|
||||
output_file: str | None = None,
|
||||
multiprocessing: bool | None = False,
|
||||
overwrite: bool | None = True,
|
||||
barrier: Barrier | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
if self.is_recording:
|
||||
raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
|
||||
|
||||
# Reset queue and shared memory
|
||||
self.read_shared_array.reset()
|
||||
self._clear_queue(self.write_queue)
|
||||
|
||||
# Reset stop event
|
||||
self.record_stop_event.clear()
|
||||
|
||||
# Write recordings into a file if output_file is provided
|
||||
if output_file is not None:
|
||||
output_file = Path(output_file)
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if output_file.exists():
|
||||
if overwrite:
|
||||
output_file.unlink()
|
||||
else:
|
||||
raise FileExistsError(
|
||||
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
|
||||
)
|
||||
|
||||
if multiprocessing:
|
||||
self.write_stop_event = process_Event()
|
||||
self.write_is_started_event = process_Event()
|
||||
self.write_thread = Process(
|
||||
target=PortAudioMicrophone._write_loop,
|
||||
args=(
|
||||
self.write_queue,
|
||||
self.write_stop_event,
|
||||
self.write_is_started_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.write_stop_event = thread_Event()
|
||||
self.write_is_started_event = thread_Event()
|
||||
self.write_thread = Thread(
|
||||
target=PortAudioMicrophone._write_loop,
|
||||
args=(
|
||||
self.write_queue,
|
||||
self.write_stop_event,
|
||||
self.write_is_started_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
self.write_thread.daemon = True
|
||||
self.write_thread.start()
|
||||
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
|
||||
|
||||
self.record_start_event.set() # Start the input audio stream process
|
||||
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
|
||||
|
||||
if barrier is not None:
|
||||
barrier.wait() # Wait for multiple input audio streams to be started at the same time
|
||||
|
||||
self.audio_callback_start_event.set()
|
||||
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Error starting recording for microphone {self.microphone_index}.")
|
||||
if output_file is not None and not self.is_writing:
|
||||
raise RuntimeError(f"Error starting writing for microphone {self.microphone_index}.")
|
||||
|
||||
def stop_recording(self) -> None:
|
||||
"""
|
||||
Stops the recording of the microphones.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise DeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
|
||||
|
||||
self.audio_callback_start_event.clear()
|
||||
self.record_start_event.clear() # Ensures the audio stream is not started again !
|
||||
self.record_stop_event.set()
|
||||
|
||||
# Wait for the stream to be stopped (might lead to race condition if the stream is not properly stopped on array reset and queue clearing)
|
||||
timeout = 1.0
|
||||
while self.is_recording and timeout > 0:
|
||||
time.sleep(0.01)
|
||||
timeout -= 0.01
|
||||
|
||||
self.read_shared_array.reset()
|
||||
self._clear_queue(self.write_queue, join_queue=True)
|
||||
|
||||
if self.is_writing:
|
||||
self.write_stop_event.set()
|
||||
self.write_thread.join()
|
||||
|
||||
if self.is_recording:
|
||||
raise RuntimeError(f"Error stopping recording for microphone {self.microphone_index}.")
|
||||
if self.is_writing:
|
||||
raise RuntimeError(f"Error stopping writing for microphone {self.microphone_index}.")
|
||||
|
||||
@staticmethod
|
||||
def _write_loop(
|
||||
queue,
|
||||
write_stop_event: Event,
|
||||
write_is_started_event: Event,
|
||||
sample_rate: int,
|
||||
channels: list[int],
|
||||
output_file: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Thread/Process-safe loop to write audio data into a file.
|
||||
"""
|
||||
# Can only be run on a single process/thread for file writing safety
|
||||
with SoundFile(
|
||||
output_file,
|
||||
mode="w",
|
||||
samplerate=sample_rate,
|
||||
channels=len(channels),
|
||||
format="WAV",
|
||||
subtype="FLOAT", # By default, a much lower quality WAV file is created !
|
||||
) as file:
|
||||
write_is_started_event.set()
|
||||
while not write_stop_event.is_set():
|
||||
try:
|
||||
file.write(
|
||||
queue.get(timeout=0.005)
|
||||
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
continue
|
||||
write_is_started_event.clear()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
|
||||
@staticmethod
|
||||
def _clear_queue(queue, join_queue: bool = False):
|
||||
"""
|
||||
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
queue.get_nowait()
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
if join_queue:
|
||||
queue.join()
|
||||
return
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_touchlab import TouchLabSensorConfig
|
||||
from .sensor_touchlab import TouchLabSensor
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import MicrophoneConfig
|
||||
|
||||
|
||||
@MicrophoneConfig.register_subclass("touchlab")
|
||||
@dataclass
|
||||
class TouchLabSensorConfig(MicrophoneConfig):
|
||||
"""Configuration class for TouchLab tactile sensors (technically not a microphone, but behaves like one acquisition-wise).
|
||||
|
||||
This class provides configuration options for TouchLab tactile sensors, including serial port, sample rate and channels.
|
||||
|
||||
Example configurations:
|
||||
```python
|
||||
# Basic configurations
|
||||
TouchLabSensorConfig("/dev/ttyACM0", 16000) # Serial port /dev/ttyACM0, 16000Hz
|
||||
TouchLabSensorConfig("/dev/ttyACM1", 44100) # Serial port /dev/ttyACM1, 44100Hz
|
||||
```
|
||||
|
||||
Attributes:
|
||||
sensor_port: Serial port of the tactile sensor.
|
||||
baud_rate: Baud rate of the tactile sensor.
|
||||
sample_rate: Sample rate in Hz for the tactile sensor.
|
||||
channels: List of channel numbers to use for the tactile sensor.
|
||||
"""
|
||||
|
||||
sensor_port: str
|
||||
baud_rate: int = 115_200
|
||||
@@ -0,0 +1,469 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the TouchLabSensor class for capturing tactile data from TouchLab tactile sensors.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import (
|
||||
Event as process_Event,
|
||||
JoinableQueue as process_Queue,
|
||||
Process,
|
||||
)
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from serial import Serial
|
||||
from soundfile import SoundFile
|
||||
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
from ..microphone import Microphone
|
||||
from .configuration_touchlab import TouchLabSensorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_SERIAL_READ_SIZE = 512
|
||||
|
||||
|
||||
class TouchLabSensor(Microphone):
|
||||
"""
|
||||
The TouchLabSensor class handles all TouchLab tactile sensors.
|
||||
|
||||
A TouchLabSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
from lerobot.common.robot_devices.microphones.configs import TouchLabSensorConfig
|
||||
|
||||
config = TouchLabSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
|
||||
microphone = TouchLabSensor(config)
|
||||
|
||||
microphone.connect()
|
||||
microphone.start_recording("some/output/file.wav")
|
||||
...
|
||||
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
|
||||
...
|
||||
microphone.stop_recording()
|
||||
microphone.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: TouchLabSensorConfig):
|
||||
""" "
|
||||
Initializes the TouchLabSensor instance.
|
||||
|
||||
Args:
|
||||
config: The configuration settings for the sensor.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# Sensor port
|
||||
self.sensor_port = config.sensor_port
|
||||
|
||||
# Baud rate
|
||||
self.baud_rate = config.baud_rate
|
||||
|
||||
# Input audio recording process and events
|
||||
self.record_process = None
|
||||
self.record_stop_event = process_Event()
|
||||
self.record_start_event = process_Event()
|
||||
self.record_close_event = process_Event()
|
||||
self.record_is_started_event = process_Event()
|
||||
self.audio_callback_start_event = process_Event()
|
||||
|
||||
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
|
||||
self.write_queue = process_Queue()
|
||||
|
||||
# SharedArray to store audio from the recording process.
|
||||
self.read_shared_array = None
|
||||
self.local_read_shared_array = None
|
||||
# Thread/Process to handle data writing in a separate thread/process (safely)
|
||||
self.write_thread = None
|
||||
self.write_stop_event = None
|
||||
self.write_is_started_event = None
|
||||
|
||||
self.logs = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.sensor_port})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the sensor is currently connected.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is connected and ready to start recording,
|
||||
False otherwise.
|
||||
"""
|
||||
return self.record_process is not None and self.record_process.is_alive()
|
||||
|
||||
@property
|
||||
def is_recording(self) -> bool:
|
||||
"""Check if the sensor is currently recording.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is recording, False otherwise.
|
||||
"""
|
||||
return self.record_is_started_event.is_set()
|
||||
|
||||
@property
|
||||
def is_writing(self) -> bool:
|
||||
"""Check if the sensor is currently writing to a file.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is writing to a file, False otherwise.
|
||||
"""
|
||||
return self.write_thread is not None and self.write_is_started_event.is_set()
|
||||
|
||||
@staticmethod
|
||||
def find_microphones() -> list[dict[str, Any]]:
|
||||
"""Detects available sensors connected to the system.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains information about a detected sensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def connect(self) -> None:
|
||||
"""
|
||||
Establish connection to the sensor.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")
|
||||
|
||||
# Create or reset queue and shared array
|
||||
self.read_shared_array = SharedArray(
|
||||
shape=(self.sample_rate * 10, len(self.channels)),
|
||||
dtype=np.dtype("int16"),
|
||||
)
|
||||
self.local_read_shared_array = self.read_shared_array.get_local_array()
|
||||
self.write_queue = process_Queue()
|
||||
|
||||
# Reset events
|
||||
self.record_start_event.clear()
|
||||
self.record_stop_event.clear()
|
||||
self.record_close_event.clear()
|
||||
self.record_is_started_event.clear()
|
||||
self.audio_callback_start_event.clear()
|
||||
|
||||
# Create and start an audio input stream with a recording callback
|
||||
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
|
||||
process_init_event = process_Event()
|
||||
self.record_process = Process(
|
||||
target=self._record_process,
|
||||
args=(
|
||||
self.sensor_port,
|
||||
self.baud_rate,
|
||||
self.channels,
|
||||
process_init_event,
|
||||
self.record_start_event,
|
||||
self.record_stop_event,
|
||||
self.record_close_event,
|
||||
self.record_is_started_event,
|
||||
self.audio_callback_start_event,
|
||||
self.write_queue,
|
||||
self.read_shared_array,
|
||||
),
|
||||
)
|
||||
self.record_process.daemon = True
|
||||
self.record_process.start()
|
||||
|
||||
is_init = process_init_event.wait(
|
||||
timeout=5.0
|
||||
) # Wait for the recording process to be started, and to potentially raise an error on failure.
|
||||
if not self.is_connected or not is_init:
|
||||
raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
|
||||
def _record_process(
|
||||
sensor_port,
|
||||
baud_rate,
|
||||
channels,
|
||||
process_init_event,
|
||||
record_start_event,
|
||||
record_stop_event,
|
||||
record_close_event,
|
||||
record_is_started_event,
|
||||
audio_callback_start_event,
|
||||
write_queue,
|
||||
read_shared_array,
|
||||
) -> None:
|
||||
channels_index = np.array(channels) - 1
|
||||
local_read_shared_array = read_shared_array.get_local_array()
|
||||
|
||||
def tactile_callback(serial_connection):
|
||||
"""
|
||||
Parse the tactile data from the raw input data.
|
||||
"""
|
||||
buffer = serial_connection.readline()
|
||||
|
||||
if audio_callback_start_event.is_set():
|
||||
strings = buffer.decode("utf8").split(",")
|
||||
num_taxels = len(strings)
|
||||
|
||||
if num_taxels > 0 and num_taxels < MAX_SERIAL_READ_SIZE: # Make sure we didn't read rubbish
|
||||
indata = np.empty((1, num_taxels))
|
||||
for i in range(num_taxels):
|
||||
indata[0, i] = int(strings[i])
|
||||
|
||||
write_queue.put_nowait(indata[:, channels_index])
|
||||
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
|
||||
|
||||
process_init_event.set()
|
||||
|
||||
while True:
|
||||
start_flag = record_start_event.wait(timeout=0.1)
|
||||
if record_close_event.is_set():
|
||||
break
|
||||
elif not start_flag:
|
||||
continue
|
||||
|
||||
with Serial(sensor_port, baud_rate, timeout=0.5) as serial_connection:
|
||||
serial_connection.flush()
|
||||
record_is_started_event.set()
|
||||
while not record_stop_event.is_set():
|
||||
tactile_callback(serial_connection)
|
||||
record_is_started_event.clear()
|
||||
serial_connection.close()
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect the sensor and release any resources.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
|
||||
if self.is_recording:
|
||||
self.stop_recording()
|
||||
|
||||
self.record_close_event.set()
|
||||
self.read_shared_array.delete()
|
||||
self.write_queue.close()
|
||||
self.record_process.join()
|
||||
|
||||
if self.is_connected:
|
||||
raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
def start_recording(
|
||||
self,
|
||||
output_file: str | Path | None = None,
|
||||
multiprocessing: bool | None = False,
|
||||
overwrite: bool | None = True,
|
||||
barrier: Barrier | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Start recording tactile data from the sensor.
|
||||
|
||||
Args:
|
||||
output_file: Optional path to save the recorded tactile data.
|
||||
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
|
||||
overwrite: If True, overwrites existing files at output_file path.
|
||||
barrier: If not None, ensures that multiple sensors start recording at the same time.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if self.is_recording:
|
||||
raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")
|
||||
|
||||
# Reset queue and shared memory
|
||||
self.read_shared_array.reset()
|
||||
self._clear_queue(self.write_queue)
|
||||
|
||||
# Reset stop event
|
||||
self.record_stop_event.clear()
|
||||
|
||||
# Write recordings into a file if output_file is provided
|
||||
if output_file is not None:
|
||||
output_file = Path(output_file)
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if output_file.exists():
|
||||
if overwrite:
|
||||
output_file.unlink()
|
||||
else:
|
||||
raise FileExistsError(
|
||||
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
|
||||
)
|
||||
|
||||
if multiprocessing:
|
||||
self.write_stop_event = process_Event()
|
||||
self.write_is_started_event = process_Event()
|
||||
self.write_thread = Process(
|
||||
target=TouchLabSensor._write_loop,
|
||||
args=(
|
||||
self.write_queue,
|
||||
self.write_stop_event,
|
||||
self.write_is_started_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.write_stop_event = thread_Event()
|
||||
self.write_is_started_event = thread_Event()
|
||||
self.write_thread = Thread(
|
||||
target=TouchLabSensor._write_loop,
|
||||
args=(
|
||||
self.write_queue,
|
||||
self.write_stop_event,
|
||||
self.write_is_started_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
self.write_thread.daemon = True
|
||||
self.write_thread.start()
|
||||
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
|
||||
|
||||
self.record_start_event.set() # Start the input audio stream process
|
||||
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
|
||||
|
||||
if barrier is not None:
|
||||
barrier.wait() # Wait for multiple input audio streams to be started at the same time
|
||||
|
||||
self.audio_callback_start_event.set()
|
||||
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
|
||||
if output_file is not None and not self.is_writing:
|
||||
raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")
|
||||
|
||||
def _read(self) -> np.ndarray:
|
||||
"""
|
||||
Thread/Process-safe callback to read available audio data
|
||||
"""
|
||||
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
"""Capture and return a single audio chunk from the sensor.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured audio chunk as a numpy array.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
tactile_readings = self._read()
|
||||
|
||||
# log the number of seconds it took to read the audio chunk
|
||||
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the audio chunk was received
|
||||
self.logs["timestamp_utc"] = time.perf_counter()
|
||||
|
||||
return tactile_readings
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
"""Internal loop run by the background thread for asynchronous reading."""
|
||||
|
||||
def stop_recording(self) -> None:
|
||||
"""Stop recording audio from the sensor."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||
|
||||
self.audio_callback_start_event.clear()
|
||||
self.record_start_event.clear() # Ensures the audio stream is not started again !
|
||||
self.record_stop_event.set()
|
||||
|
||||
self.read_shared_array.reset()
|
||||
self._clear_queue(self.write_queue, join_queue=True)
|
||||
|
||||
if self.is_writing:
|
||||
self.write_stop_event.set()
|
||||
self.write_thread.join()
|
||||
|
||||
timeout = 1.0
|
||||
while self.is_recording and timeout > 0:
|
||||
time.sleep(0.01)
|
||||
timeout -= 0.01
|
||||
|
||||
if self.is_recording:
|
||||
raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
|
||||
if self.is_writing:
|
||||
raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
|
||||
@staticmethod
|
||||
def _clear_queue(queue, join_queue: bool = False):
|
||||
"""
|
||||
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
queue.get_nowait()
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
if join_queue:
|
||||
queue.join()
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _write_loop(
|
||||
queue,
|
||||
write_stop_event: Event,
|
||||
write_is_started_event: Event,
|
||||
sample_rate: int,
|
||||
channels: list[int],
|
||||
output_file: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Thread/Process-safe loop to write audio data into a file.
|
||||
"""
|
||||
# Can only be run on a single process/thread for file writing safety
|
||||
with SoundFile(
|
||||
output_file,
|
||||
mode="w",
|
||||
samplerate=sample_rate,
|
||||
channels=len(channels),
|
||||
format="WAV",
|
||||
subtype="PCM_16", # Subtype for int16 values
|
||||
) as file:
|
||||
write_is_started_event.set()
|
||||
while not write_stop_event.is_set():
|
||||
try:
|
||||
file.write(
|
||||
queue.get(timeout=0.005)
|
||||
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
continue
|
||||
write_is_started_event.clear()
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from multiprocessing import Barrier
|
||||
from threading import Thread
|
||||
|
||||
from .configs import MicrophoneConfig
|
||||
from .microphone import Microphone
|
||||
|
||||
|
||||
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfig]) -> dict[str, Microphone]:
|
||||
microphones = {}
|
||||
|
||||
for key, cfg in microphone_configs.items():
|
||||
if cfg.type == "portaudio":
|
||||
from .portaudio import PortAudioMicrophone
|
||||
|
||||
microphones[key] = PortAudioMicrophone(cfg)
|
||||
elif cfg.type == "touchlab":
|
||||
from .touchlab import TouchLabSensor
|
||||
|
||||
microphones[key] = TouchLabSensor(cfg)
|
||||
elif cfg.type == "anyskin":
|
||||
from .anyskin import AnyskinSensor
|
||||
|
||||
microphones[key] = AnyskinSensor(cfg)
|
||||
else:
|
||||
raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
|
||||
|
||||
return microphones
|
||||
|
||||
|
||||
def async_microphones_start_recording(
|
||||
microphones: dict[str, Microphone],
|
||||
output_files: list[str | None] | None = None,
|
||||
multiprocessing: bool = False,
|
||||
overwrite: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Starts recording on multiple microphones asynchronously to avoid delays.
|
||||
|
||||
Args:
|
||||
microphones: A dictionary of microphones.
|
||||
output_files: A list of output files.
|
||||
multiprocessing: If True, enables multiprocessing for recording.
|
||||
overwrite: If True, overwrites existing files at output_file path.
|
||||
"""
|
||||
|
||||
start_recording_threads = []
|
||||
if output_files is None:
|
||||
output_files = [None] * len(microphones)
|
||||
|
||||
barrier = Barrier(len(microphones))
|
||||
|
||||
for microphone, output_file in zip(microphones.values(), output_files, strict=False):
|
||||
start_recording_threads.append(
|
||||
Thread(target=microphone.start_recording, args=(output_file, multiprocessing, overwrite, barrier))
|
||||
)
|
||||
|
||||
for thread in start_recording_threads:
|
||||
thread.start()
|
||||
for thread in start_recording_threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
def async_microphones_stop_recording(microphones: dict[str, Microphone]) -> None:
|
||||
"""
|
||||
Stops recording on multiple microphones asynchronously to avoid delays.
|
||||
|
||||
Args:
|
||||
microphones: A dictionary of microphones.
|
||||
"""
|
||||
|
||||
stop_recording_threads = []
|
||||
|
||||
for microphone in microphones.values():
|
||||
stop_recording_threads.append(Thread(target=microphone.stop_recording))
|
||||
|
||||
for thread in stop_recording_threads:
|
||||
thread.start()
|
||||
for thread in stop_recording_threads:
|
||||
thread.join()
|
||||
@@ -205,6 +205,7 @@ MODEL_BAUDRATE_TABLE = {
|
||||
|
||||
# Sign-Magnitude encoding bits
|
||||
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||
"Present_Load": 10,
|
||||
"Homing_Offset": 11,
|
||||
"Goal_Position": 15,
|
||||
"Goal_Velocity": 15,
|
||||
|
||||
@@ -32,7 +32,7 @@ import serial
|
||||
from deepdiff import DeepDiff
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
@@ -411,6 +411,7 @@ class MotorsBus(abc.ABC):
|
||||
"""bool: `True` if the underlying serial port is open."""
|
||||
return self.port_handler.is_open
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Open the serial port and initialise communication.
|
||||
|
||||
@@ -422,10 +423,6 @@ class MotorsBus(abc.ABC):
|
||||
DeviceAlreadyConnectedError: The port is already open.
|
||||
ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
|
||||
)
|
||||
|
||||
self._connect(handshake)
|
||||
self.set_timeout()
|
||||
@@ -447,6 +444,7 @@ class MotorsBus(abc.ABC):
|
||||
def _handshake(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Close the serial port (optionally disabling torque first).
|
||||
|
||||
@@ -455,10 +453,6 @@ class MotorsBus(abc.ABC):
|
||||
closing the port. This can prevent damaging motors if they are left applying resisting torque
|
||||
after disconnect.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
|
||||
)
|
||||
|
||||
if disable_torque:
|
||||
self.port_handler.clearPort()
|
||||
@@ -907,6 +901,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def read(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -927,10 +922,6 @@ class MotorsBus(abc.ABC):
|
||||
Returns:
|
||||
Value: Raw or normalised value depending on *normalize*.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
@@ -981,6 +972,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return value, comm, error
|
||||
|
||||
@check_if_not_connected
|
||||
def write(
|
||||
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||
) -> None:
|
||||
@@ -999,10 +991,6 @@ class MotorsBus(abc.ABC):
|
||||
normalize (bool, optional): Enable or disable normalisation. Defaults to `True`.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
@@ -1044,6 +1032,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return comm, error
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -1063,10 +1052,6 @@ class MotorsBus(abc.ABC):
|
||||
Returns:
|
||||
dict[str, Value]: Mapping *motor name → value*.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
self._assert_protocol_is_compatible("sync_read")
|
||||
|
||||
@@ -1139,6 +1124,7 @@ class MotorsBus(abc.ABC):
|
||||
# for id_ in motor_ids:
|
||||
# value = self.sync_reader.getData(id_, address, length)
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -1160,10 +1146,6 @@ class MotorsBus(abc.ABC):
|
||||
normalize (bool, optional): If `True` (default) convert values from the user range to raw units.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
|
||||
@@ -98,6 +98,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"AUDIO": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
@@ -108,6 +109,10 @@ class ACTConfig(PreTrainedConfig):
|
||||
vision_backbone: str = "resnet18"
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
replace_final_stride_with_dilation: int = False
|
||||
# Audio backbone.
|
||||
audio_backbone: str = vision_backbone
|
||||
pretrained_backbone_weights_audio: str | None = None
|
||||
replace_final_stride_with_dilation_audio: int = False
|
||||
# Transformer layers.
|
||||
pre_norm: bool = False
|
||||
dim_model: int = 512
|
||||
@@ -170,8 +175,10 @@ class ACTConfig(PreTrainedConfig):
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
if not (self.image_features or self.audio_features) and not self.env_state_feature:
|
||||
raise ValueError(
|
||||
"You must provide at least one image/audio or the environment state among the inputs."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
|
||||
@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
@@ -106,6 +106,8 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
action = self.temporal_ensembler.update(actions)
|
||||
@@ -331,12 +333,26 @@ class ACT(nn.Module):
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
|
||||
# Backbone for audio feature extraction.
|
||||
if self.config.audio_features:
|
||||
audio_backbone_model = getattr(torchvision.models, config.audio_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio],
|
||||
weights=config.pretrained_backbone_weights_audio,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
# feature map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.audio_backbone = IntermediateLayerGetter(
|
||||
audio_backbone_model, return_layers={"layer4": "feature_map"}
|
||||
)
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = ACTEncoder(config)
|
||||
self.decoder = ACTDecoder(config)
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels), (audio_feature)].
|
||||
if self.config.robot_state_feature:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
@@ -350,6 +366,10 @@ class ACT(nn.Module):
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
if self.config.audio_features:
|
||||
self.encoder_audio_feat_input_proj = nn.Conv2d(
|
||||
audio_backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
n_1d_tokens = 1 # for the latent
|
||||
if self.config.robot_state_feature:
|
||||
@@ -359,6 +379,8 @@ class ACT(nn.Module):
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.config.image_features:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
if self.config.audio_features:
|
||||
self.encoder_audio_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
@@ -483,6 +505,21 @@ class ACT(nn.Module):
|
||||
encoder_in_tokens.extend(list(cam_features))
|
||||
encoder_in_pos_embed.extend(list(cam_pos_embed))
|
||||
|
||||
if self.config.audio_features:
|
||||
for audio in batch[OBS_AUDIO]:
|
||||
audio_features = self.audio_backbone(audio)["feature_map"]
|
||||
audio_pos_embed = self.encoder_audio_feat_pos_embed(audio_features).to(
|
||||
dtype=audio_features.dtype
|
||||
)
|
||||
audio_features = self.encoder_audio_feat_input_proj(audio_features)
|
||||
|
||||
# Rearrange features to (sequence, batch, dim).
|
||||
audio_features = einops.rearrange(audio_features, "b c h w -> (h w) b c")
|
||||
audio_pos_embed = einops.rearrange(audio_pos_embed, "b c h w -> (h w) b c")
|
||||
|
||||
encoder_in_tokens.extend(list(audio_features))
|
||||
encoder_in_pos_embed.extend(list(audio_pos_embed))
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
|
||||
|
||||
@@ -17,9 +17,11 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
AudioProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -63,6 +65,15 @@ def make_act_pre_post_processors(
|
||||
stats=dataset_stats,
|
||||
device=config.device,
|
||||
),
|
||||
AudioProcessorStep(
|
||||
output_height=224,
|
||||
output_width=224,
|
||||
output_channels=3,
|
||||
input_audio_chunk_duration=DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
input_sample_rate=48000,
|
||||
intermediate_sample_rate=16000,
|
||||
n_fft=1024,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
|
||||
@@ -32,16 +32,22 @@ Notes:
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.groot_n1 import GR00TN15
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
|
||||
class GrootPolicy(PreTrainedPolicy):
|
||||
@@ -90,6 +96,129 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: GrootConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = True,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Load Groot policy from pretrained model.
|
||||
|
||||
Handles two cases:
|
||||
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
|
||||
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
|
||||
|
||||
Args:
|
||||
pretrained_name_or_path: Path to the GR00T model or fine-tuned checkpoint
|
||||
config: Optional GrootConfig. If None, loads from checkpoint or creates default
|
||||
force_download: Force download even if cached
|
||||
resume_download: Resume interrupted download
|
||||
proxies: Proxy settings
|
||||
token: HuggingFace authentication token
|
||||
cache_dir: Cache directory path
|
||||
local_files_only: Only use local files
|
||||
revision: Specific model revision
|
||||
strict: Strict state dict loading
|
||||
**kwargs: Additional arguments (passed to config)
|
||||
|
||||
Returns:
|
||||
Initialized GrootPolicy instance with loaded model
|
||||
"""
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
print(
|
||||
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
is_finetuned_checkpoint = False
|
||||
|
||||
# Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors)
|
||||
try:
|
||||
if os.path.isdir(model_id):
|
||||
is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE))
|
||||
else:
|
||||
# Try to download the safetensors file to check if it exists
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=False, # Just check, don't force download
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
is_finetuned_checkpoint = True
|
||||
except HfHubHTTPError:
|
||||
is_finetuned_checkpoint = False
|
||||
except Exception:
|
||||
is_finetuned_checkpoint = False
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
|
||||
|
||||
# Add minimal visual feature required for validation
|
||||
# validate_features() will automatically add state and action features
|
||||
# These are placeholders - actual robot features come from the preprocessor
|
||||
if not config.input_features:
|
||||
config.input_features = {
|
||||
f"{OBS_IMAGES}.camera": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Default image size from config
|
||||
),
|
||||
}
|
||||
else:
|
||||
# Override the base_model_path with the provided path
|
||||
config.base_model_path = str(pretrained_name_or_path)
|
||||
|
||||
# Pass through any additional config overrides from kwargs
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
@@ -1297,3 +1297,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
@@ -1270,3 +1270,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0.5 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import builtins
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
from importlib.resources import files
|
||||
@@ -265,3 +266,166 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
|
||||
def wrap_with_peft(
|
||||
self,
|
||||
peft_config=None,
|
||||
peft_cli_overrides: dict | None = None,
|
||||
) -> "PreTrainedPolicy":
|
||||
"""
|
||||
Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
|
||||
|
||||
This method is the single entry point for PEFT integration. Subclasses should
|
||||
override `_get_default_peft_targets()` to provide default target modules, and
|
||||
`_validate_peft_config()` for policy-specific validation.
|
||||
|
||||
Args:
|
||||
peft_config: Optional PEFT adapter configuration (e.g., LoraConfig).
|
||||
If provided, used directly (with CLI overrides applied).
|
||||
peft_cli_overrides: Optional dict of CLI overrides (method_type, target_modules, r, etc.)
|
||||
These are merged with policy defaults to build the final config.
|
||||
"""
|
||||
from peft import get_peft_model
|
||||
|
||||
# If user provided a complete config, use it directly (with overrides)
|
||||
if peft_config is not None:
|
||||
final_config = peft_config
|
||||
if peft_cli_overrides:
|
||||
final_config = self._apply_peft_cli_overrides(final_config, peft_cli_overrides)
|
||||
else:
|
||||
# Build config from defaults + CLI overrides
|
||||
final_config = self._build_peft_config(peft_cli_overrides or {})
|
||||
|
||||
# Validate the configuration
|
||||
self._validate_peft_config(final_config)
|
||||
|
||||
# Freeze base parameters, only adapter params will be trained
|
||||
for p in self.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
# Store pretrained path for PEFT's base_model_name_or_path
|
||||
if self.config.pretrained_path:
|
||||
self.name_or_path = str(self.config.pretrained_path)
|
||||
|
||||
# Wrap with PEFT
|
||||
peft_model = get_peft_model(self, final_config)
|
||||
|
||||
# Mark config as using PEFT for proper loading later
|
||||
peft_model.config.use_peft = True
|
||||
|
||||
logging.info(f"Wrapped {self.name} with PEFT ({type(final_config).__name__})")
|
||||
return peft_model
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any] | None:
|
||||
"""
|
||||
Return default PEFT target modules for this policy.
|
||||
|
||||
Override this in subclasses to provide policy-specific defaults. These defaults
|
||||
are PEFT-method agnostic - they only specify which modules to target.
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
def _validate_peft_config(self, peft_config) -> None:
|
||||
"""
|
||||
Validate the PEFT configuration for this policy.
|
||||
|
||||
Override this in subclasses to add policy-specific validation or warnings.
|
||||
The default implementation checks that a pretrained_path exists.
|
||||
|
||||
Args:
|
||||
peft_config: The PEFT configuration to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration is invalid.
|
||||
"""
|
||||
if not self.config.pretrained_path:
|
||||
raise ValueError(
|
||||
"Training from scratch using PEFT is unlikely to yield good results. "
|
||||
"Supply a `policy.pretrained_path` to fine-tune an existing model."
|
||||
)
|
||||
|
||||
def _preprocess_peft_cli_overrides(self, cli_overrides: dict, peft_method_type) -> dict:
|
||||
"""
|
||||
Preprocess CLI overrides: rename keys and handle method-specific init_type.
|
||||
|
||||
Args:
|
||||
cli_overrides: Dict of CLI options (will be copied, not mutated).
|
||||
peft_method_type: The PeftType enum value for the PEFT method.
|
||||
|
||||
Returns:
|
||||
Preprocessed dict with renamed keys and init_type mapped to method-specific key.
|
||||
"""
|
||||
from peft import PeftType
|
||||
|
||||
cli_overrides = cli_overrides.copy()
|
||||
|
||||
# Handle the full_training_modules -> modules_to_save rename
|
||||
if "full_training_modules" in cli_overrides:
|
||||
cli_overrides["modules_to_save"] = cli_overrides.pop("full_training_modules")
|
||||
|
||||
# Remove method_type as it's handled separately
|
||||
cli_overrides.pop("method_type", None)
|
||||
|
||||
# Handle init_type specially based on PEFT method
|
||||
init_type = cli_overrides.pop("init_type", None)
|
||||
if init_type is not None:
|
||||
if peft_method_type == PeftType.LORA:
|
||||
cli_overrides["init_lora_weights"] = init_type
|
||||
elif peft_method_type == PeftType.MISS:
|
||||
cli_overrides["init_weights"] = init_type
|
||||
else:
|
||||
raise ValueError(f"Init type '{init_type}' unknown for PEFT method {peft_method_type}.")
|
||||
|
||||
return cli_overrides
|
||||
|
||||
def _build_peft_config(self, cli_overrides: dict):
|
||||
"""Build a PEFT config from policy defaults and CLI overrides."""
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
|
||||
|
||||
# Determine PEFT method type (default to LORA)
|
||||
method_type_str = cli_overrides.get("method_type") or "lora"
|
||||
peft_method_type = PeftType[method_type_str.upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
|
||||
# Preprocess CLI overrides
|
||||
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
|
||||
|
||||
# Start with policy defaults, apply CLI overrides
|
||||
config_dict = dict(self._get_default_peft_targets() or {})
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
|
||||
# Ensure we have target_modules
|
||||
if not config_dict.get("target_modules"):
|
||||
raise ValueError(
|
||||
f"Policy '{self.name}' does not define default target_modules. "
|
||||
"Please pass --peft.target_modules explicitly."
|
||||
)
|
||||
|
||||
return peft_config_cls(**config_dict)
|
||||
|
||||
def _apply_peft_cli_overrides(self, peft_config, cli_overrides: dict):
|
||||
"""Apply CLI overrides to an existing PEFT config."""
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
|
||||
|
||||
# Get method type from existing config or CLI override
|
||||
method_type_str = cli_overrides.get("method_type")
|
||||
if method_type_str:
|
||||
peft_method_type = PeftType[method_type_str.upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
else:
|
||||
peft_method_type = PeftType(peft_config.peft_type)
|
||||
peft_config_cls = type(peft_config)
|
||||
|
||||
# Preprocess CLI overrides
|
||||
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
|
||||
|
||||
# Start with existing config, apply CLI overrides
|
||||
config_dict = {k: v for k, v in dataclasses.asdict(peft_config).items() if not k.startswith("_")}
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
|
||||
return peft_config_cls(**config_dict)
|
||||
|
||||
@@ -480,6 +480,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for SmolVLA fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
def _validate_peft_config(self, peft_config) -> None:
|
||||
"""Validate PEFT configuration for SmolVLA."""
|
||||
super()._validate_peft_config(peft_config)
|
||||
if not self.config.load_vlm_weights:
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. "
|
||||
"Set `load_vlm_weights=True` to fine-tune the existing policy."
|
||||
)
|
||||
|
||||
|
||||
def pad_tensor(tensor, max_len, pad_value=0):
|
||||
"""
|
||||
|
||||
@@ -106,7 +106,7 @@ def prepare_observation_for_inference(
|
||||
This function takes a dictionary of NumPy arrays, performs necessary
|
||||
preprocessing, and prepares it for model inference. The steps include:
|
||||
1. Converting NumPy arrays to PyTorch tensors.
|
||||
2. Normalizing and permuting image data (if any).
|
||||
2. Normalizing and permuting image data and audio data (if any).
|
||||
3. Adding a batch dimension to each tensor.
|
||||
4. Moving all tensors to the specified compute device.
|
||||
5. Adding task and robot type information to the dictionary.
|
||||
@@ -129,6 +129,9 @@ def prepare_observation_for_inference(
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
elif "audio" in name:
|
||||
observation[name] = observation[name].type(torch.float32)
|
||||
observation[name] = observation[name].permute(1, 0).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .audio_processor import AudioProcessorStep
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
@@ -80,6 +81,7 @@ __all__ = [
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
"AddTeleopEventsAsInfoStep",
|
||||
"AudioProcessorStep",
|
||||
"ComplementaryDataProcessorStep",
|
||||
"batch_to_transition",
|
||||
"create_transition",
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
from torchaudio.functional import amplitude_to_DB
|
||||
from torchaudio.transforms import MelSpectrogram, Resample
|
||||
from torchvision.transforms import Compose, Lambda, Resize
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
from lerobot.utils.constants import OBS_AUDIO
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="audio_processor")
|
||||
class AudioProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes audio waveform data into a mel-spectrogram image representation.
|
||||
|
||||
**Audio Processing:**
|
||||
- Averages waveform data over all channels.
|
||||
- Resamples the waveform to 16kHz.
|
||||
- Converts the waveform to a mel-spectrogram.
|
||||
- Converts the mel-spectrogram to decibels.
|
||||
- Resizes the mel-spectrogram to 224×224.
|
||||
- Converts the mel-spectrogram to a channel-first, normalized tensor.
|
||||
|
||||
Attributes:
|
||||
output_height: Height of the output mel-spectrogram image in pixels.
|
||||
output_width: Width of the output mel-spectrogram image in pixels.
|
||||
output_channels: Number of channels in the output image (3 for RGB-like format).
|
||||
input_audio_chunk_duration: Duration of the input audio chunk in seconds.
|
||||
input_sample_rate: Original sample rate of the input audio in Hz.
|
||||
|
||||
intermediate_sample_rate: Reduced intermediate sample rate in Hz.
|
||||
Downsampling improves the temporal resolution but reduces the frequency range.
|
||||
n_fft: Size of the FFT window for spectrogram computation.
|
||||
Increasing the window size increases the frequency resolution but decreases the temporal resolution.
|
||||
|
||||
hop_length: Number of samples between successive frames, computed automatically to match the output_width.
|
||||
Decreasing the hop length increases the temporal resolution but decreases the frequency resolution.
|
||||
n_mels: Number of mel filter banks, computed automatically to match the output_height.
|
||||
Increasing the number of banks increases the number of rows in the spectrogram and the frequency resolution.
|
||||
mel_spectrogram_transform: The complete audio processing pipeline.
|
||||
"""
|
||||
|
||||
output_height: int = 224
|
||||
output_width: int = 224
|
||||
output_channels: int = 3
|
||||
input_audio_chunk_duration: float = DEFAULT_AUDIO_CHUNK_DURATION
|
||||
|
||||
input_sample_rate: int = 48000
|
||||
intermediate_sample_rate: int = 16000
|
||||
|
||||
n_fft: int = 1024
|
||||
|
||||
# Parameters computed from other parameters at initialization
|
||||
hop_length: int = field(init=False)
|
||||
n_mels: int = field(init=False)
|
||||
mel_spectrogram_transform: Compose = field(init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.hop_length = int(
|
||||
self.intermediate_sample_rate * self.input_audio_chunk_duration
|
||||
- self.n_fft // self.output_width
|
||||
- 1
|
||||
)
|
||||
self.n_mels = self.output_height
|
||||
|
||||
self.mel_spectrogram_transform = Compose(
|
||||
[
|
||||
Lambda(lambda x: x.mean(dim=1)), # Average over all channels (second dimension after batch)
|
||||
Resample(orig_freq=self.input_sample_rate, new_freq=self.intermediate_sample_rate),
|
||||
MelSpectrogram(
|
||||
sample_rate=self.intermediate_sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
power=2, # Power spectrum
|
||||
),
|
||||
Lambda(
|
||||
lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0)
|
||||
), # Convert to decibels
|
||||
Resize(
|
||||
(self.output_height, self.output_width)
|
||||
), # Resize spectrogram to output_height×output_width
|
||||
Lambda(
|
||||
lambda x: x.unsqueeze(1).expand(-1, self.output_channels, -1, -1)
|
||||
), # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width].
|
||||
]
|
||||
)
|
||||
|
||||
def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Processes audio data contained in the provided observation.
|
||||
"""
|
||||
processed_obs = observation.copy()
|
||||
|
||||
# Process single audio observation
|
||||
if OBS_AUDIO in processed_obs:
|
||||
audio_data = processed_obs[OBS_AUDIO]
|
||||
if isinstance(audio_data, Tensor) and audio_data.dim() == 3: # Batch, Channels, Samples
|
||||
processed_obs[OBS_AUDIO] = self.mel_spectrogram_transform(audio_data)
|
||||
|
||||
# Process multiple audio observations
|
||||
for key, value in processed_obs.items():
|
||||
if (
|
||||
key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 3
|
||||
): # Batch, Channels, Samples
|
||||
processed_obs[key] = self.mel_spectrogram_transform(value)
|
||||
|
||||
return processed_obs
|
||||
|
||||
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
return self._process_observation(observation)
|
||||
@@ -25,7 +25,7 @@ from dataclasses import dataclass, field
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .core import EnvTransition, PolicyAction
|
||||
from .pipeline import (
|
||||
@@ -88,6 +88,8 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
- State vectors (1D tensors).
|
||||
- Single images (3D tensors).
|
||||
- Dictionaries of multiple images (3D tensors).
|
||||
- Single audio waveforms (2D tensors).
|
||||
- Dictionaries of multiple audio waveforms (2D tensors).
|
||||
"""
|
||||
|
||||
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
@@ -117,6 +119,18 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
|
||||
# Process single audio observation - add batch dim if 2D
|
||||
if OBS_AUDIO in observation:
|
||||
audio_value = observation[OBS_AUDIO]
|
||||
if isinstance(audio_value, Tensor) and audio_value.dim() == 2:
|
||||
observation[OBS_AUDIO] = audio_value.unsqueeze(0)
|
||||
|
||||
# Process multiple audio observations - add batch dim if 2D
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 2:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
|
||||
return observation
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -34,6 +34,13 @@ class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
raise ValueError(
|
||||
f"Specifying '{attr}' is required for the camera to be used in a robot"
|
||||
)
|
||||
if hasattr(self, "microphones") and self.microphones:
|
||||
for _, config in self.microphones.items():
|
||||
for attr in ["sample_rate", "channels"]:
|
||||
if getattr(config, attr) is None:
|
||||
raise ValueError(
|
||||
f"Specifying '{attr}' is required for the microphone to be used in a robot"
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
|
||||
@@ -24,7 +24,8 @@ import numpy as np
|
||||
import requests
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||
@@ -99,6 +100,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
"""Check if robot is connected to SDK."""
|
||||
return self._is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""Connect to robot via Frodobots SDK.
|
||||
|
||||
@@ -109,8 +111,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
DeviceAlreadyConnectedError: If robot is already connected
|
||||
DeviceNotConnectedError: If cannot connect to SDK server
|
||||
"""
|
||||
if self._is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
|
||||
|
||||
# Verify SDK is running and accessible
|
||||
try:
|
||||
@@ -197,6 +197,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: float,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
@@ -223,8 +224,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Robot telemetry is retrieved from /data endpoint.
|
||||
All SDK values are normalized to appropriate ranges for dataset recording.
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
observation = {}
|
||||
|
||||
@@ -255,6 +254,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
return observation
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Send action to robot via SDK.
|
||||
|
||||
@@ -272,8 +272,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Actions are sent to SDK via POST /control endpoint.
|
||||
SDK expects commands in range [-1, 1].
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Extract action values and convert to float
|
||||
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
|
||||
@@ -291,6 +289,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: angular,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from robot.
|
||||
|
||||
@@ -299,8 +298,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Stop the robot before disconnecting
|
||||
try:
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.microphones import MicrophoneConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -31,6 +32,8 @@ class HopeJrHandConfig(RobotConfig):
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.side not in ["right", "left"]:
|
||||
@@ -49,3 +52,5 @@ class HopeJrArmConfig(RobotConfig):
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -82,13 +82,12 @@ class HopeJrArm(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect(handshake=False)
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -128,10 +127,8 @@ class HopeJrArm(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
|
||||
@@ -147,12 +144,17 @@ class HopeJrArm(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
@@ -165,10 +167,8 @@ class HopeJrArm(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_hope_jr import HopeJrHandConfig
|
||||
@@ -118,10 +118,8 @@ class HopeJrHand(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
self.calibrate()
|
||||
@@ -159,10 +157,8 @@ class HopeJrHand(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read hand position
|
||||
@@ -179,20 +175,23 @@ class HopeJrHand(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.microphones import MicrophoneConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -35,5 +36,8 @@ class KochFollowerConfig(RobotConfig):
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# microphones
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
@@ -19,13 +19,14 @@ import time
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.microphones.utils import make_microphones_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -61,6 +62,7 @@ class KochFollower(Robot):
|
||||
calibration=self.calibration,
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
self.microphones = make_microphones_from_configs(config.microphones)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
@@ -72,9 +74,16 @@ class KochFollower(Robot):
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@property
|
||||
def _microphones_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||
for mic in self.microphones
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
@@ -82,15 +91,18 @@ class KochFollower(Robot):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
return (
|
||||
self.bus.is_connected
|
||||
and all(cam.is_connected for cam in self.cameras.values())
|
||||
and all(mic.is_connected for mic in self.microphones.values())
|
||||
)
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -102,6 +114,9 @@ class KochFollower(Robot):
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
for mic in self.microphones.values():
|
||||
mic.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -182,10 +197,8 @@ class KochFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -200,8 +213,16 @@ class KochFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -215,8 +236,6 @@ class KochFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -231,12 +250,12 @@ class KochFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
for mic in self.microphones.values():
|
||||
mic.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -16,6 +16,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras.configs import CameraConfig, Cv2Rotation
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.microphones import MicrophoneConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -45,6 +46,8 @@ class LeKiwiConfig(RobotConfig):
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
@@ -92,5 +95,7 @@ class LeKiwiClientConfig(RobotConfig):
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
polling_timeout_ms: int = 15
|
||||
connect_timeout_s: int = 5
|
||||
|
||||
@@ -23,13 +23,14 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.microphones.utils import make_microphones_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -73,6 +74,7 @@ class LeKiwi(Robot):
|
||||
self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")]
|
||||
self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")]
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
self.microphones = make_microphones_from_configs(config.microphones)
|
||||
|
||||
@property
|
||||
def _state_ft(self) -> dict[str, type]:
|
||||
@@ -97,9 +99,16 @@ class LeKiwi(Robot):
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@property
|
||||
def _microphones_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||
for mic in self.microphones
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._state_ft, **self._cameras_ft}
|
||||
return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
@@ -107,12 +116,14 @@ class LeKiwi(Robot):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
return (
|
||||
self.bus.is_connected
|
||||
and all(cam.is_connected for cam in self.cameras.values())
|
||||
and all(mic.is_connected for mic in self.microphones.values())
|
||||
)
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -123,6 +134,9 @@ class LeKiwi(Robot):
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
for mic in self.microphones.values():
|
||||
mic.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -339,10 +353,8 @@ class LeKiwi(Robot):
|
||||
"theta.vel": theta,
|
||||
} # m/s and deg/s
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read actuators position for arm and vel for base
|
||||
start = time.perf_counter()
|
||||
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||
@@ -368,8 +380,16 @@ class LeKiwi(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration.
|
||||
|
||||
@@ -383,8 +403,6 @@ class LeKiwi(Robot):
|
||||
Returns:
|
||||
RobotAction: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
arm_goal_pos = {k: v for k, v in action.items() if k.endswith(".pos")}
|
||||
base_goal_vel = {k: v for k, v in action.items() if k.endswith(".vel")}
|
||||
@@ -412,13 +430,13 @@ class LeKiwi(Robot):
|
||||
self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5)
|
||||
logger.info("Base motors stopped")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.stop_base()
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
for mic in self.microphones.values():
|
||||
mic.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -18,13 +18,15 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from time import perf_counter
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_lekiwi import LeKiwiClientConfig
|
||||
@@ -57,8 +59,9 @@ class LeKiwiClient(Robot):
|
||||
self.zmq_observation_socket = None
|
||||
|
||||
self.last_frames = {}
|
||||
|
||||
self.last_remote_state = {}
|
||||
self.last_frame_timestamp = None
|
||||
self.last_frame_delay = 0.0
|
||||
|
||||
# Define three speed levels and a current index
|
||||
self.speed_levels = [
|
||||
@@ -96,9 +99,13 @@ class LeKiwiClient(Robot):
|
||||
def _cameras_ft(self) -> dict[str, tuple[int, int, int]]:
|
||||
return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
|
||||
|
||||
@cached_property
|
||||
def _microphones_ft(self) -> dict[str, tuple]:
|
||||
return {name: (cfg.sample_rate, cfg.channels) for name, cfg in self.config.microphones.items()}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._state_ft, **self._cameras_ft}
|
||||
return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
@@ -112,14 +119,10 @@ class LeKiwiClient(Robot):
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
"""Establishes ZMQ sockets with the remote mobile robot"""
|
||||
|
||||
if self._is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
"LeKiwi Daemon is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
zmq = self._zmq
|
||||
self.zmq_context = zmq.Context()
|
||||
self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH)
|
||||
@@ -138,6 +141,7 @@ class LeKiwiClient(Robot):
|
||||
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
|
||||
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
|
||||
|
||||
self.last_frame_timestamp = perf_counter()
|
||||
self._is_connected = True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
@@ -170,6 +174,8 @@ class LeKiwiClient(Robot):
|
||||
if last_msg is None:
|
||||
logging.warning("Poller indicated data, but failed to retrieve message.")
|
||||
|
||||
self.last_frame_delay = perf_counter() - self.last_frame_timestamp
|
||||
self.last_frame_timestamp = perf_counter()
|
||||
return last_msg
|
||||
|
||||
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
|
||||
@@ -206,14 +212,16 @@ class LeKiwiClient(Robot):
|
||||
|
||||
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
|
||||
|
||||
# Decode images
|
||||
# Decode images and audio data
|
||||
current_frames: dict[str, np.ndarray] = {}
|
||||
for cam_name, image_b64 in observation.items():
|
||||
if cam_name not in self._cameras_ft:
|
||||
continue
|
||||
frame = self._decode_image_from_b64(image_b64)
|
||||
if frame is not None:
|
||||
current_frames[cam_name] = frame
|
||||
for frame_name, frame_data in observation.items():
|
||||
if frame_name in self._cameras_ft:
|
||||
image = self._decode_image_from_b64(frame_data)
|
||||
if image is not None:
|
||||
current_frames[frame_name] = image
|
||||
elif frame_name in self._microphones_ft:
|
||||
if frame_data is not None:
|
||||
current_frames[frame_name] = frame_data
|
||||
|
||||
return current_frames, obs_dict
|
||||
|
||||
@@ -252,23 +260,32 @@ class LeKiwiClient(Robot):
|
||||
|
||||
return new_frames, new_state
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
and a camera frame. Receives over ZMQ, translate to body-frame vel
|
||||
and cameras and microphones data. Receives over ZMQ, translate to body-frame vel
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
|
||||
|
||||
frames, obs_dict = self._get_data()
|
||||
|
||||
# Loop over each configured camera
|
||||
for cam_name, frame in frames.items():
|
||||
if frame is None:
|
||||
logging.warning("Frame is None")
|
||||
frame = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||
obs_dict[cam_name] = frame
|
||||
# Loop over each configured camera and microphone
|
||||
for frame_name, frame_data in frames.items():
|
||||
if frame_data is None:
|
||||
if frame_name in self._cameras_ft:
|
||||
logging.warning("Image frame is None")
|
||||
image = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||
obs_dict[frame_name] = image
|
||||
elif frame_name in self._microphones_ft:
|
||||
logging.warning("Audio frame is None")
|
||||
obs_dict[frame_name] = np.zeros(
|
||||
(
|
||||
int(self._microphones_ft[frame_name][0] * self.last_frame_delay),
|
||||
self._microphones_ft[frame_name][1],
|
||||
),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -307,6 +324,7 @@ class LeKiwiClient(Robot):
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||
|
||||
@@ -318,10 +336,6 @@ class LeKiwiClient(Robot):
|
||||
Returns:
|
||||
np.ndarray: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
self.zmq_cmd_socket.send_string(json.dumps(action)) # action is in motor space
|
||||
|
||||
@@ -332,13 +346,10 @@ class LeKiwiClient(Robot):
|
||||
action_sent[ACTION] = actions
|
||||
return action_sent
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
"""Cleans ZMQ comms"""
|
||||
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"LeKiwi is not connected. You need to run `robot.connect()` before disconnecting."
|
||||
)
|
||||
self.zmq_observation_socket.close()
|
||||
self.zmq_cmd_socket.close()
|
||||
self.zmq_context.term()
|
||||
|
||||
@@ -26,7 +26,7 @@ from lerobot.motors.dynamixel import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -84,6 +84,7 @@ class OmxFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
For OMX robots that come pre-calibrated:
|
||||
@@ -91,8 +92,6 @@ class OmxFollower(Robot):
|
||||
- This allows using pre-calibrated robots without manual calibration
|
||||
- If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095)
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -165,10 +164,8 @@ class OmxFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -185,6 +182,7 @@ class OmxFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -198,8 +196,6 @@ class OmxFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -214,10 +210,8 @@ class OmxFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -18,6 +18,7 @@ from dataclasses import dataclass, field
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.microphones import MicrophoneConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -39,6 +40,9 @@ class SOFollowerConfig:
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# microphones
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@@ -20,13 +20,14 @@ from functools import cached_property
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.microphones.utils import make_microphones_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -62,6 +63,7 @@ class SOFollower(Robot):
|
||||
calibration=self.calibration,
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
self.microphones = make_microphones_from_configs(config.microphones)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
@@ -73,9 +75,16 @@ class SOFollower(Robot):
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@property
|
||||
def _microphones_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||
for mic in self.microphones
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
@@ -83,15 +92,18 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
return (
|
||||
self.bus.is_connected
|
||||
and all(cam.is_connected for cam in self.cameras.values())
|
||||
and all(mic.is_connected for mic in self.microphones.values())
|
||||
)
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -103,6 +115,9 @@ class SOFollower(Robot):
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
for mic in self.microphones.values():
|
||||
mic.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -176,10 +191,8 @@ class SOFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -194,8 +207,16 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -209,8 +230,6 @@ class SOFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -225,13 +244,13 @@ class SOFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
for mic in self.microphones.values():
|
||||
mic.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
|
||||
@@ -33,6 +33,9 @@ import draccus
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
|
||||
@@ -66,23 +66,23 @@ Remove camera feature:
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Convert image dataset to video format (saves locally):
|
||||
Convert image dataset to video format and save locally:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset and save with new repo_id:
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert and push to hub:
|
||||
Convert image dataset to video format and push to hub:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Using JSON config file:
|
||||
@@ -92,24 +92,19 @@ Using JSON config file:
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import write_stats, write_tasks
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -138,8 +133,8 @@ class RemoveFeatureConfig:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvertToVideoConfig:
|
||||
type: str = "convert_to_video"
|
||||
class ConvertImageToVideoConfig:
|
||||
type: str = "convert_image_to_video"
|
||||
output_dir: str | None = None
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
@@ -148,12 +143,16 @@ class ConvertToVideoConfig:
|
||||
fast_decode: int = 0
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
max_episodes_per_batch: int | None = None
|
||||
max_frames_per_batch: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
|
||||
operation: (
|
||||
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig
|
||||
)
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
push_to_hub: bool = False
|
||||
@@ -297,362 +296,7 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def encode_episode_videos(
|
||||
dataset: LeRobotDataset,
|
||||
new_meta: LeRobotDatasetMetadata,
|
||||
episode_index: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
temp_dir: Path,
|
||||
num_image_workers: int = 4,
|
||||
) -> dict[str, dict]:
|
||||
"""Encode videos for a single episode and return video metadata.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images
|
||||
new_meta: Metadata object for the new video dataset
|
||||
episode_index: Episode index to process
|
||||
vcodec: Video codec
|
||||
pix_fmt: Pixel format
|
||||
g: Group of pictures size
|
||||
crf: Constant rate factor
|
||||
fast_decode: Fast decode tuning
|
||||
temp_dir: Temporary directory for images
|
||||
num_image_workers: Number of workers for saving images
|
||||
|
||||
Returns:
|
||||
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
|
||||
"""
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
video_metadata = {}
|
||||
fps = int(dataset.fps) # Convert to int for PyAV compatibility
|
||||
episode_length = dataset.meta.episodes["length"][episode_index]
|
||||
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
|
||||
|
||||
for img_key in img_keys:
|
||||
# Save images temporarily
|
||||
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
|
||||
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
|
||||
|
||||
# Determine chunk and file indices
|
||||
# For simplicity, we'll put each episode in its own file
|
||||
chunk_idx = episode_index // new_meta.chunks_size
|
||||
file_idx = episode_index % new_meta.chunks_size
|
||||
|
||||
# Create video path in the new dataset structure
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encode video
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Store video metadata
|
||||
video_metadata[img_key] = {
|
||||
f"videos/{img_key}/chunk_index": chunk_idx,
|
||||
f"videos/{img_key}/file_index": file_idx,
|
||||
f"videos/{img_key}/from_timestamp": 0.0,
|
||||
f"videos/{img_key}/to_timestamp": episode_duration,
|
||||
}
|
||||
|
||||
return video_metadata
|
||||
|
||||
|
||||
def convert_dataset_to_videos(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-based dataset to video-based dataset.
|
||||
|
||||
Creates a new LeRobotDataset with videos instead of images, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process each episode
|
||||
all_episode_metadata = []
|
||||
|
||||
try:
|
||||
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
|
||||
# Get episode metadata from source
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
|
||||
# Encode videos for this episode
|
||||
video_metadata = encode_episode_videos(
|
||||
dataset=dataset,
|
||||
new_meta=new_meta,
|
||||
episode_index=ep_idx,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
temp_dir=temp_dir,
|
||||
num_image_workers=num_workers,
|
||||
)
|
||||
|
||||
# Build episode metadata
|
||||
episode_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": src_episode["length"],
|
||||
"dataset_from_index": ep_idx * src_episode["length"],
|
||||
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
|
||||
}
|
||||
|
||||
# Add video metadata
|
||||
for img_key in img_keys:
|
||||
episode_meta.update(video_metadata[img_key])
|
||||
|
||||
# Add data chunk/file info (using same structure as source)
|
||||
if "data/chunk_index" in src_episode:
|
||||
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
episode_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
|
||||
all_episode_metadata.append(episode_meta)
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(all_episode_metadata)
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
from lerobot.datasets.utils import write_info
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
@@ -664,8 +308,12 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
if cfg.new_repo_id:
|
||||
# Use new_repo_id for both local storage and hub push
|
||||
output_repo_id = cfg.new_repo_id
|
||||
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
|
||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
|
||||
# Place new dataset as a sibling to the original dataset
|
||||
# Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir)
|
||||
# Extract just the dataset name (after last slash) for the local directory
|
||||
local_dir_name = cfg.new_repo_id.split("/")[-1]
|
||||
output_dir = dataset.root.parent / local_dir_name
|
||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}")
|
||||
elif output_dir_config:
|
||||
# Use custom output directory for local-only storage
|
||||
output_dir = Path(output_dir_config)
|
||||
@@ -675,12 +323,15 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
else:
|
||||
# Auto-generate name: append "_video" to original repo_id
|
||||
output_repo_id = f"{cfg.repo_id}_video"
|
||||
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
|
||||
# Place new dataset as a sibling to the original dataset
|
||||
# Extract just the dataset name (after last slash) for the local directory
|
||||
local_dir_name = output_repo_id.split("/")[-1]
|
||||
output_dir = dataset.root.parent / local_dir_name
|
||||
logging.info(f"Saving to auto-generated location: {output_dir}")
|
||||
|
||||
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
||||
|
||||
new_dataset = convert_dataset_to_videos(
|
||||
new_dataset = convert_image_to_video_dataset(
|
||||
dataset=dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
@@ -691,6 +342,8 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
|
||||
max_frames_per_batch=getattr(cfg.operation, "max_frames_per_batch", None),
|
||||
)
|
||||
|
||||
logging.info("Video dataset created successfully!")
|
||||
@@ -718,8 +371,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_merge(cfg)
|
||||
elif operation_type == "remove_feature":
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "convert_to_video":
|
||||
handle_convert_to_video(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown operation type: {operation_type}\n"
|
||||
|
||||
@@ -64,11 +64,14 @@ lerobot-record \
|
||||
|
||||
import logging
|
||||
import time
|
||||
from copy import copy
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
)
|
||||
@@ -81,8 +84,23 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
|
||||
build_dataset_frame,
|
||||
combine_feature_dicts,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.microphones import (
|
||||
MicrophoneConfig, # noqa: F401
|
||||
)
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.microphones.utils import (
|
||||
async_microphones_start_recording,
|
||||
async_microphones_stop_recording,
|
||||
)
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
@@ -120,6 +138,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.utils.audio_utils import rolling_vstack
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import (
|
||||
init_keyboard_listener,
|
||||
@@ -279,6 +298,13 @@ def record_loop(
|
||||
display_data: bool = False,
|
||||
display_compressed_images: bool = False,
|
||||
):
|
||||
if display_data:
|
||||
init_rerun(
|
||||
session_name="recording",
|
||||
robot=robot,
|
||||
reset_time=True,
|
||||
)
|
||||
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
@@ -313,6 +339,36 @@ def record_loop(
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Create a buffer for audio observations (shifting window of fixed size over audio samples)
|
||||
if robot.microphones and (policy is not None or dataset is not None):
|
||||
audio_buffer = {
|
||||
microphone_name: np.zeros(
|
||||
(int(microphone.sample_rate * DEFAULT_AUDIO_CHUNK_DURATION), len(microphone.channels))
|
||||
)
|
||||
for microphone_name, microphone in robot.microphones.items()
|
||||
}
|
||||
|
||||
if (
|
||||
dataset is not None and robot.name != "lekiwi"
|
||||
): # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
|
||||
dataset.add_microphones_recordings(robot.microphones)
|
||||
else:
|
||||
async_microphones_start_recording(robot.microphones)
|
||||
|
||||
# Fill audio buffers if needed
|
||||
if (
|
||||
robot.microphones
|
||||
and (policy is not None or dataset is not None)
|
||||
and DEFAULT_INITIAL_AUDIO_BUFFER_DURATION > 0.0
|
||||
):
|
||||
# This initial wait might be longer than the audio chunk duration to
|
||||
# (1) ensure that the audio buffers are filled with enough data
|
||||
# (2) add additional initial samples to the dataset in case of variable audio chunk duration during training
|
||||
precise_sleep(DEFAULT_INITIAL_AUDIO_BUFFER_DURATION)
|
||||
for microphone_name, microphone in robot.microphones.items():
|
||||
audio_chunk = microphone.read()
|
||||
audio_buffer[microphone_name] = rolling_vstack(audio_buffer[microphone_name], audio_chunk)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -333,8 +389,14 @@ def record_loop(
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
# Transform instantaneous audio samples into a buffer of fixed size
|
||||
buffered_observation_frame = copy(observation_frame)
|
||||
for name in audio_buffer:
|
||||
# Add the audio buffer to the observation
|
||||
buffered_observation_frame[name] = rolling_vstack(audio_buffer[name], observation_frame[name])
|
||||
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
observation=buffered_observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
@@ -389,14 +451,26 @@ def record_loop(
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(
|
||||
observation=obs_processed, action=action_values, compress_images=display_compressed_images
|
||||
observation=obs_processed,
|
||||
action=action_values,
|
||||
compress_images=display_compressed_images,
|
||||
log_time=time.perf_counter() - start_episode_t,
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
remaining_time = 1 / fps - dt_s
|
||||
if remaining_time > 0.0:
|
||||
print(f"Waiting {remaining_time:.2f} seconds to maintain {fps:.2f} Hz control loop frequency.")
|
||||
precise_sleep(remaining_time)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Inconsistent control loop frequency: {1 / dt_s:.2f} Hz < {fps:.2f} Hz. Try reducing the cameras resolution or FPS."
|
||||
)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
async_microphones_stop_recording(robot.microphones)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
@@ -61,6 +61,9 @@ import rerun as rr
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
@@ -143,8 +146,18 @@ def teleop_loop(
|
||||
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
|
||||
robot_observation_processor: An optional pipeline to process raw observations from the robot.
|
||||
"""
|
||||
if display_data:
|
||||
init_rerun(
|
||||
session_name="teleoperation",
|
||||
robot=robot,
|
||||
reset_time=True,
|
||||
)
|
||||
|
||||
display_len = max(len(key) for key in robot.action_features)
|
||||
|
||||
for _, microphone in robot.microphones.items():
|
||||
microphone.start_recording()
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
while True:
|
||||
@@ -176,6 +189,7 @@ def teleop_loop(
|
||||
observation=obs_transition,
|
||||
action=teleop_action,
|
||||
compress_images=display_compressed_images,
|
||||
log_time=time.perf_counter() - start,
|
||||
)
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
@@ -192,7 +206,10 @@ def teleop_loop(
|
||||
move_cursor_up(1)
|
||||
|
||||
if duration is not None and time.perf_counter() - start >= duration:
|
||||
return
|
||||
break
|
||||
|
||||
for _, microphone in robot.microphones.items():
|
||||
microphone.stop_recording()
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
|
||||
@@ -148,92 +148,6 @@ def update_policy(
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def get_default_peft_configuration(policy_type):
|
||||
"""Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint."""
|
||||
|
||||
common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
|
||||
if policy_type == "smolvla":
|
||||
return {
|
||||
"target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))",
|
||||
"modules_to_save": [],
|
||||
}
|
||||
elif policy_type in ("pi0", "pi05"):
|
||||
return {
|
||||
"target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))",
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
return {"modules_to_save": None}
|
||||
|
||||
|
||||
def wrap_policy_in_peft_model(cfg, policy):
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model
|
||||
|
||||
# Disable all gradients because we'll only train the parameters selected by the PEFT method.
|
||||
# Layers that should receive gradients anyway need to be listed in `modules_to_save`.
|
||||
for p in policy.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
if not cfg.policy.pretrained_path:
|
||||
raise ValueError(
|
||||
"Training from scratch using PEFT. This is unlikely to yield good results. "
|
||||
"Supply a `policy.path` to fine-tune an existing model."
|
||||
)
|
||||
|
||||
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
|
||||
logging.warning(
|
||||
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set "
|
||||
"`load_vlm_weights=True` to fine-tune the existing policy."
|
||||
)
|
||||
|
||||
peft_config_policy = get_default_peft_configuration(cfg.policy.type)
|
||||
peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}
|
||||
peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"] # compatibility with PEFT
|
||||
peft_method_type = PeftType[peft_config_cli["method_type"].upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
|
||||
# Handle specific CLI overrides
|
||||
for key in ["target_modules", "modules_to_save", "r"]:
|
||||
if peft_config_cli[key] is not None:
|
||||
peft_config_policy[key] = peft_config_cli[key]
|
||||
|
||||
if "target_modules" not in peft_config_policy:
|
||||
raise ValueError(
|
||||
f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
|
||||
)
|
||||
|
||||
# Init method depends on the used PEFT method, your specific PEFT method
|
||||
# might not be considered here, in that case an error is raised.
|
||||
if peft_config_cli["init_type"] is not None:
|
||||
if peft_method_type == "LORA":
|
||||
peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"]
|
||||
elif peft_method_type == "MISS":
|
||||
peft_config_policy["init_weights"] = peft_config_cli["init_type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}."
|
||||
)
|
||||
|
||||
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
|
||||
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
|
||||
# adapter, not the base model.
|
||||
if policy.config.pretrained_path:
|
||||
policy.name_or_path = str(policy.config.pretrained_path)
|
||||
|
||||
# Finally wrap the policy in a PEFT model
|
||||
policy = get_peft_model(
|
||||
policy,
|
||||
peft_config_cls(**peft_config_policy),
|
||||
)
|
||||
|
||||
# Make sure that the config is tagged as using PEFT so that the loading code can take the
|
||||
# appropriate steps to use the adapter weights and the PEFT config instead of the full model weights.
|
||||
policy.config.use_peft = True
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
"""
|
||||
@@ -263,8 +177,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when policy.device is set to CPU.
|
||||
# Note (maractin): cfg.policy may be None before validate() fully loads from pretrained_path
|
||||
force_cpu = cfg.policy is not None and cfg.policy.device == "cpu"
|
||||
force_cpu = cfg.policy.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
@@ -312,9 +225,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
if is_main_process:
|
||||
logging.info("Creating env")
|
||||
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
if is_main_process:
|
||||
@@ -327,7 +239,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
if cfg.peft is not None:
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
policy = wrap_policy_in_peft_model(cfg, policy)
|
||||
# Convert CLI peft config to dict for overrides
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..so_leader import SOLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -92,10 +92,8 @@ class BiSOLeader(Teleoperator):
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
@@ -86,10 +86,8 @@ class GamepadTeleop(Teleoperator):
|
||||
self.gamepad = Gamepad()
|
||||
self.gamepad.start()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Update the controller to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from pprint import pformat
|
||||
import serial
|
||||
|
||||
from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -93,10 +93,8 @@ class HomunculusArm(Teleoperator):
|
||||
with self.serial_lock:
|
||||
return self.serial.is_open and self.thread.is_alive()
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
if not self.serial.is_open:
|
||||
self.serial.open()
|
||||
self.thread.start()
|
||||
@@ -299,20 +297,16 @@ class HomunculusArm(Teleoperator):
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
joint_positions = self._read()
|
||||
return {f"{joint}.pos": pos for joint, pos in joint_positions.items()}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.stop_event.set()
|
||||
self.thread.join(timeout=1)
|
||||
self.serial.close()
|
||||
|
||||
@@ -24,7 +24,7 @@ import serial
|
||||
from lerobot.motors import MotorCalibration
|
||||
from lerobot.motors.motors_bus import MotorNormMode
|
||||
from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -119,10 +119,8 @@ class HomunculusGlove(Teleoperator):
|
||||
with self.serial_lock:
|
||||
return self.serial.is_open and self.thread.is_alive()
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
if not self.serial.is_open:
|
||||
self.serial.open()
|
||||
self.thread.start()
|
||||
@@ -325,10 +323,8 @@ class HomunculusGlove(Teleoperator):
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
joint_positions = self._read()
|
||||
return homunculus_glove_to_hope_jr_hand(
|
||||
{f"{joint}.pos": pos for joint, pos in joint_positions.items()}
|
||||
@@ -337,10 +333,8 @@ class HomunculusGlove(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.stop_event.set()
|
||||
self.thread.join(timeout=1)
|
||||
self.serial.close()
|
||||
|
||||
@@ -22,7 +22,7 @@ from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
@@ -86,12 +86,8 @@ class KeyboardTeleop(Teleoperator):
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
"Keyboard is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
if PYNPUT_AVAILABLE:
|
||||
logging.info("pynput is available - enabling local keyboard listener.")
|
||||
self.listener = keyboard.Listener(
|
||||
@@ -125,14 +121,10 @@ class KeyboardTeleop(Teleoperator):
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
|
||||
)
|
||||
|
||||
self._drain_pressed_keys()
|
||||
|
||||
# Generate action based on current key states
|
||||
@@ -144,11 +136,8 @@ class KeyboardTeleop(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`."
|
||||
)
|
||||
if self.listener is not None:
|
||||
self.listener.stop()
|
||||
|
||||
@@ -182,12 +171,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
|
||||
)
|
||||
|
||||
self._drain_pressed_keys()
|
||||
delta_x = 0.0
|
||||
delta_y = 0.0
|
||||
@@ -375,6 +360,7 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
# Only remove key if it's being released
|
||||
self.current_pressed.pop(key_char, None)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get the current action based on pressed keys.
|
||||
@@ -384,11 +370,6 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
"""
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardRoverTeleop is not connected. You need to run `connect()` before `get_action()`."
|
||||
)
|
||||
|
||||
self._drain_pressed_keys()
|
||||
|
||||
linear_velocity = 0.0
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_koch_leader import KochLeaderConfig
|
||||
@@ -69,10 +69,8 @@ class KochLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -161,10 +159,8 @@ class KochLeader(Teleoperator):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -176,9 +172,7 @@ class KochLeader(Teleoperator):
|
||||
# TODO(rcadene, aliberts): Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_omx_leader import OmxLeaderConfig
|
||||
@@ -68,10 +68,8 @@ class OmxLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -142,10 +140,8 @@ class OmxLeader(Teleoperator):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -157,9 +153,7 @@ class OmxLeader(Teleoperator):
|
||||
# TODO(rcadene, aliberts): Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -28,7 +28,7 @@ from teleop import Teleop
|
||||
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -81,10 +81,8 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self._group is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
|
||||
lookup = hebi.Lookup()
|
||||
time.sleep(2.0)
|
||||
@@ -164,10 +162,8 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
pos = ar_pos - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose or not self.is_calibrated:
|
||||
return {}
|
||||
@@ -207,10 +203,8 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._group = None
|
||||
|
||||
|
||||
@@ -230,10 +224,8 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self._teleop is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Starting teleop stream for Android...")
|
||||
self._teleop = Teleop()
|
||||
self._teleop.subscribe(self._android_callback)
|
||||
@@ -321,10 +313,8 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
return {}
|
||||
@@ -356,10 +346,8 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._teleop = None
|
||||
if self._teleop_thread and self._teleop_thread.is_alive():
|
||||
self._teleop_thread.join(timeout=1.0)
|
||||
|
||||
@@ -26,7 +26,8 @@ if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
else:
|
||||
ReachySDK = None
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
|
||||
@@ -126,10 +127,8 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
|
||||
if not self.is_connected:
|
||||
@@ -146,12 +145,10 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
joint_action: dict[str, float] = {}
|
||||
vel_action: dict[str, float] = {}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_so_leader import SOLeaderTeleopConfig
|
||||
@@ -66,10 +66,8 @@ class SOLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -139,10 +137,8 @@ class SOLeader(Teleoperator):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -154,10 +150,8 @@ class SOLeader(Teleoperator):
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def rolling_vstack(buffer: np.ndarray, new_data: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Rolling implementation of numpy.vstack to add new data in at the end of a fixed shape buffer in a rolling fashion.
|
||||
|
||||
Args:
|
||||
buffer: The *fixed* shape buffer to update.
|
||||
new_data: The new data to add to the buffer.
|
||||
|
||||
Returns:
|
||||
The updated buffer.
|
||||
"""
|
||||
|
||||
buffer_size = buffer.shape[0]
|
||||
# Remove as many old audio samples as needed
|
||||
buffer[: -len(new_data)] = buffer[len(new_data) :]
|
||||
# Add new audio samples, only the newest if the buffer is already full
|
||||
buffer[-len(new_data) :] = new_data[-buffer_size:]
|
||||
return buffer
|
||||
@@ -23,6 +23,7 @@ OBS_ENV_STATE = OBS_STR + ".environment_state"
|
||||
OBS_STATE = OBS_STR + ".state"
|
||||
OBS_IMAGE = OBS_STR + ".image"
|
||||
OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_AUDIO = OBS_STR + ".audio"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
|
||||
@@ -102,7 +102,7 @@ def predict_action(
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
# Convert to pytorch format: normalizing and permuting (channel first)
|
||||
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
||||
observation = preprocessor(observation)
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
|
||||
def check_if_not_connected(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__} is not connected. Run `.connect()` first."
|
||||
)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def check_if_already_connected(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self.__class__.__name__} is already connected.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -30,3 +30,22 @@ class DeviceAlreadyConnectedError(ConnectionError):
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class DeviceNotRecordingError(Exception):
|
||||
"""Exception raised when the robot device is not recording."""
|
||||
|
||||
def __init__(self, message="This robot device is not recording. Try calling `start_recording()` first."):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class DeviceAlreadyRecordingError(Exception):
|
||||
"""Exception raised when the robot device is already recording."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message="This robot device is already recording. Try not calling `start_recording()` twice.",
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
@@ -21,12 +21,23 @@ from typing import Any
|
||||
from draccus.choice_types import ChoiceRegistry
|
||||
|
||||
|
||||
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
||||
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
|
||||
Check if the package spec exists and grab its version to avoid importing a local directory.
|
||||
**Note:** this doesn't work for all packages.
|
||||
def is_package_available(
|
||||
pkg_name: str, import_name: str | None = None, return_version: bool = False
|
||||
) -> tuple[bool, str] | bool:
|
||||
"""
|
||||
package_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
Check if the package spec exists and grab its version to avoid importing a local directory.
|
||||
|
||||
Args:
|
||||
pkg_name: The name of the package as installed via pip (e.g. "python-can").
|
||||
import_name: The actual name used to import the package (e.g. "can").
|
||||
Defaults to pkg_name if not provided.
|
||||
return_version: Whether to return the version string.
|
||||
"""
|
||||
if import_name is None:
|
||||
import_name = pkg_name
|
||||
|
||||
# Check if the module spec exists using the import name
|
||||
package_exists = importlib.util.find_spec(import_name) is not None
|
||||
package_version = "N/A"
|
||||
if package_exists:
|
||||
try:
|
||||
@@ -37,7 +48,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
# Fallback method: Only for "torch" and versions containing "dev"
|
||||
if pkg_name == "torch":
|
||||
try:
|
||||
package = importlib.import_module(pkg_name)
|
||||
package = importlib.import_module(import_name)
|
||||
temp_version = getattr(package, "__version__", "N/A")
|
||||
# Check if the version contains "dev"
|
||||
if "dev" in temp_version:
|
||||
@@ -48,9 +59,6 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
except ImportError:
|
||||
# If the package can't be imported, it's not available
|
||||
package_exists = False
|
||||
elif pkg_name == "grpc":
|
||||
package = importlib.import_module(pkg_name)
|
||||
package_version = getattr(package, "__version__", "N/A")
|
||||
else:
|
||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||
package_exists = False
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from multiprocessing import Lock, Value, shared_memory
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SharedArray:
|
||||
"""
|
||||
A SharedArray is a numpy array shared between multiple processes in a shared_memory object.
|
||||
- Data is written to the array using the `write` method, which appends data to the array.
|
||||
- Data is read from the array (and eventually flushed) using the `read` method, which copies the _whole_ array.
|
||||
|
||||
SharedArray offers quasi-instantaneous array-wide read and flush capabilities in comparison to Queues, but has a limited size defined at initialization.
|
||||
|
||||
Example:
|
||||
_Main_process_
|
||||
shared_array = SharedArray(shape=(10, 10), dtype=np.dtype("float32"))
|
||||
local_array = shared_array.get_local_array()
|
||||
shared_array.write(local_array, np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
|
||||
_Child_process_
|
||||
local_array = shared_array.get_local_array()
|
||||
data = shared_array.read(local_array, flush=True)
|
||||
"""
|
||||
|
||||
def __init__(self, shape: tuple[int], dtype: np.dtype | str):
|
||||
"""
|
||||
Initialize a SharedArray.
|
||||
|
||||
Args:
|
||||
shape: The shape of the shared array.
|
||||
dtype: The dtype of the shared array.
|
||||
"""
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
self.shared_memory = shared_memory.SharedMemory(
|
||||
create=True, size=np.prod(shape) * np.dtype(dtype).itemsize
|
||||
)
|
||||
self.read_index = Value("i", 0)
|
||||
self.lock = Lock()
|
||||
|
||||
def get_local_array(self) -> np.ndarray:
|
||||
"""
|
||||
Get a process-local instance of the shared array.
|
||||
|
||||
Returns:
|
||||
A process-local instance of the shared array.
|
||||
"""
|
||||
return np.ndarray(self.shape, dtype=np.dtype(self.dtype), buffer=self.shared_memory.buf)
|
||||
|
||||
def delete(self):
|
||||
"""
|
||||
Delete the shared array.
|
||||
"""
|
||||
self.shared_memory.close()
|
||||
self.shared_memory.unlink()
|
||||
|
||||
def write(self, local_array: np.ndarray, data: np.ndarray):
|
||||
"""
|
||||
Write data to the shared array.
|
||||
|
||||
Args:
|
||||
local_array: The process-local instance of the shared array to write to.
|
||||
data: The data to write to the shared array.
|
||||
"""
|
||||
with self.lock:
|
||||
local_array[self.read_index.value : self.read_index.value + len(data)] = data
|
||||
self.read_index.value += len(data)
|
||||
|
||||
def read(self, local_array: np.ndarray, flush: bool = True) -> np.ndarray:
|
||||
"""
|
||||
Read data from the shared array.
|
||||
|
||||
Args:
|
||||
local_array: The process-local instance of the shared array to read from.
|
||||
flush: Whether to flush the shared array after reading.
|
||||
"""
|
||||
with self.lock:
|
||||
data = np.copy(local_array[: self.read_index.value])
|
||||
if flush:
|
||||
self.read_index.value = 0
|
||||
return data
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the read index to 0.
|
||||
"""
|
||||
with self.lock:
|
||||
self.read_index.value = 0
|
||||
@@ -14,17 +14,25 @@
|
||||
|
||||
import numbers
|
||||
import os
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots import Robot
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||
|
||||
|
||||
def init_rerun(
|
||||
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
|
||||
session_name: str = "lerobot_control_loop",
|
||||
ip: str | None = None,
|
||||
port: int | None = None,
|
||||
robot: Robot | None = None,
|
||||
reset_time: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the Rerun SDK for visualizing the control loop.
|
||||
@@ -33,16 +41,26 @@ def init_rerun(
|
||||
session_name: Name of the Rerun session.
|
||||
ip: Optional IP for connecting to a Rerun server.
|
||||
port: Optional port for connecting to a Rerun server.
|
||||
robot: A Robot object. If provided, Rerun will be initialized with a blueprint that includes the object's cameras and microphones.
|
||||
reset_time: Whether to reset the timer "episode_time" to 0.
|
||||
"""
|
||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||
rr.init(session_name)
|
||||
rr.init(
|
||||
application_id=session_name,
|
||||
recording_id=uuid4(),
|
||||
)
|
||||
if robot is not None:
|
||||
rr.send_blueprint(build_rerun_blueprint(robot))
|
||||
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
||||
if ip and port:
|
||||
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
|
||||
else:
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
if reset_time:
|
||||
rr.set_time("episode_time", timestamp=0.0)
|
||||
|
||||
|
||||
def _is_scalar(x):
|
||||
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
|
||||
@@ -50,10 +68,47 @@ def _is_scalar(x):
|
||||
)
|
||||
|
||||
|
||||
def build_rerun_blueprint(robot: Robot) -> rr.blueprint.Grid:
|
||||
""" "
|
||||
Builds a Rerun blueprint for optimized visualization of the robot's observations and actions :
|
||||
- Time series views for all scalar observations and actions (e.g. position, velocity, torque, etc.).
|
||||
- Spatial 2D views for all camera observations.
|
||||
- Time series views for all microphone observations.
|
||||
|
||||
Args:
|
||||
robot: A Robot object.
|
||||
Returns:
|
||||
A Rerun blueprint.
|
||||
"""
|
||||
contents = [
|
||||
rr.blueprint.TimeSeriesView(
|
||||
origin="data",
|
||||
plot_legend=rr.blueprint.PlotLegend(visible=True),
|
||||
)
|
||||
]
|
||||
if robot.microphones:
|
||||
contents += [
|
||||
rr.blueprint.TimeSeriesView(
|
||||
origin="audio",
|
||||
plot_legend=rr.blueprint.PlotLegend(visible=True),
|
||||
)
|
||||
]
|
||||
if robot.cameras:
|
||||
contents += [
|
||||
rr.blueprint.Spatial2DView(
|
||||
origin=OBS_PREFIX + camera_name,
|
||||
)
|
||||
for camera_name in robot.cameras
|
||||
]
|
||||
|
||||
return rr.blueprint.Grid(*contents)
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
observation: RobotObservation | None = None,
|
||||
action: RobotAction | None = None,
|
||||
compress_images: bool = False,
|
||||
log_time: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Logs observation and action data to Rerun for real-time visualization.
|
||||
@@ -72,7 +127,13 @@ def log_rerun_data(
|
||||
observation: An optional dictionary containing observation data to log.
|
||||
action: An optional dictionary containing action data to log.
|
||||
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
||||
log_time: The time to log the data in the "episode_time" timeline.
|
||||
If None, the current time is used in Rerun's default timeline.
|
||||
"""
|
||||
if log_time is None:
|
||||
log_time = time.perf_counter()
|
||||
rr.set_time("episode_time", timestamp=log_time)
|
||||
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
@@ -80,15 +141,41 @@ def log_rerun_data(
|
||||
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
rr.log("data/" + key, rr.Scalars(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
# Convert samples x channels -> channels x samples when needed
|
||||
elif arr.ndim == 2 and arr.shape[1] < arr.shape[0]:
|
||||
arr = np.transpose(arr, (1, 0))
|
||||
|
||||
if arr.ndim == 1:
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
elif arr.ndim == 2:
|
||||
for i, channel_arr in enumerate(arr):
|
||||
rr.send_columns(
|
||||
"audio/"
|
||||
+ key
|
||||
+ f"_channel_{i}", # TODO(CarolinePascal): Get actual channel number/name
|
||||
indexes=[
|
||||
rr.TimeColumn(
|
||||
"episode_time",
|
||||
timestamp=log_time
|
||||
+ np.linspace(
|
||||
-DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
0,
|
||||
len(channel_arr),
|
||||
endpoint=False,
|
||||
),
|
||||
)
|
||||
],
|
||||
columns=rr.Scalars.columns(scalars=channel_arr),
|
||||
)
|
||||
elif arr.ndim == 3:
|
||||
rr.log(key, rr.Image(arr), static=True)
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity, static=True)
|
||||
@@ -100,13 +187,13 @@ def log_rerun_data(
|
||||
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
rr.log("data/" + key, rr.Scalars(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.ndim == 1:
|
||||
for i, vi in enumerate(v):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
else:
|
||||
# Fall back to flattening higher-dimensional arrays
|
||||
flat = v.flatten()
|
||||
for i, vi in enumerate(flat):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
|
||||
@@ -62,7 +62,7 @@ class MockPolicy:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@require_package("grpc")
|
||||
@require_package("grpcio", "grpc")
|
||||
def policy_server():
|
||||
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
|
||||
# Import only when the test actually runs (after decorator check)
|
||||
|
||||
@@ -57,6 +57,8 @@ def _check_component_availability(component_type, available_components, make_com
|
||||
print("\nNo physical device detected.")
|
||||
elif isinstance(e, ValueError) and "camera_index" in str(e):
|
||||
print("\nNo physical camera detected.")
|
||||
elif isinstance(e, ValueError) and "microphone_index" in str(e):
|
||||
print("\nNo physical microphone detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
@@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert key in item, f"Video key {key} missing from item {i}"
|
||||
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||
|
||||
|
||||
def assert_image_schema_preserved(aggr_ds):
|
||||
"""Test that HuggingFace Image feature schema is preserved in aggregated parquet files.
|
||||
|
||||
This verifies the fix for a bug where image columns were written with a generic
|
||||
struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
|
||||
the proper Image() feature type, causing HuggingFace Hub viewer to display
|
||||
raw dict objects instead of image thumbnails.
|
||||
"""
|
||||
image_keys = aggr_ds.meta.image_keys
|
||||
if not image_keys:
|
||||
return
|
||||
|
||||
# Check that parquet files have proper Image schema
|
||||
data_dir = aggr_ds.root / "data"
|
||||
parquet_files = list(data_dir.rglob("*.parquet"))
|
||||
assert len(parquet_files) > 0, "No parquet files found in aggregated dataset"
|
||||
|
||||
for parquet_file in parquet_files:
|
||||
# Load with HuggingFace datasets to check schema
|
||||
ds = datasets.Dataset.from_parquet(str(parquet_file))
|
||||
|
||||
for image_key in image_keys:
|
||||
feature = ds.features.get(image_key)
|
||||
assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
|
||||
assert isinstance(feature, datasets.Image), (
|
||||
f"Image key '{image_key}' should have Image() feature type, "
|
||||
f"but got {type(feature).__name__}: {feature}. "
|
||||
"This indicates image schema was not preserved during aggregation."
|
||||
)
|
||||
|
||||
|
||||
def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
"""Test that image frames are correctly preserved after aggregation."""
|
||||
image_keys = aggr_ds.meta.image_keys
|
||||
if not image_keys:
|
||||
return
|
||||
|
||||
def images_equal(img1, img2):
|
||||
return torch.allclose(img1, img2)
|
||||
|
||||
# Test the section corresponding to the first dataset (ds_0)
|
||||
for i in range(len(ds_0)):
|
||||
assert aggr_ds[i]["index"] == i, (
|
||||
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
|
||||
)
|
||||
for key in image_keys:
|
||||
assert images_equal(aggr_ds[i][key], ds_0[i][key]), (
|
||||
f"Image frames at position {i} should be equal between aggregated and ds_0"
|
||||
)
|
||||
|
||||
# Test the section corresponding to the second dataset (ds_1)
|
||||
for i in range(len(ds_0), len(ds_0) + len(ds_1)):
|
||||
assert aggr_ds[i]["index"] == i, (
|
||||
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
|
||||
)
|
||||
for key in image_keys:
|
||||
assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
|
||||
f"Image frames at position {i} should be equal between aggregated and ds_1"
|
||||
)
|
||||
|
||||
|
||||
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||
|
||||
This test specifically verifies that:
|
||||
1. Image-based datasets can be aggregated correctly
|
||||
2. The HuggingFace Image() feature type is preserved in parquet files
|
||||
3. Image data integrity is maintained across aggregation
|
||||
4. Images can be properly decoded after aggregation
|
||||
|
||||
This catches the bug where to_parquet_with_hf_images() was not passing
|
||||
the features schema, causing image columns to be written as generic
|
||||
struct types instead of Image() types.
|
||||
"""
|
||||
ds_0_num_frames = 50
|
||||
ds_1_num_frames = 75
|
||||
ds_0_num_episodes = 2
|
||||
ds_1_num_episodes = 3
|
||||
|
||||
# Create two image-based datasets (use_videos=False)
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
)
|
||||
|
||||
# Verify source datasets have image keys
|
||||
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
|
||||
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
|
||||
|
||||
# Aggregate the datasets
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr",
|
||||
aggr_root=tmp_path / "image_aggr",
|
||||
)
|
||||
|
||||
# Load the aggregated dataset
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr")
|
||||
|
||||
# Verify aggregated dataset has image keys
|
||||
assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
|
||||
assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"
|
||||
|
||||
# Run standard aggregation assertions
|
||||
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
|
||||
expected_total_frames = ds_0_num_frames + ds_1_num_frames
|
||||
|
||||
assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
|
||||
assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Image-specific assertions
|
||||
assert_image_schema_preserved(aggr_ds)
|
||||
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Verify images can be accessed and have correct shape
|
||||
sample_item = aggr_ds[0]
|
||||
for image_key in aggr_ds.meta.image_keys:
|
||||
img = sample_item[image_key]
|
||||
assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor"
|
||||
assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
|
||||
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
|
||||
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
@@ -26,16 +26,22 @@ from lerobot.datasets.compute_stats import (
|
||||
compute_episode_stats,
|
||||
estimate_num_samples,
|
||||
get_feature_stats,
|
||||
sample_audio_from_data,
|
||||
sample_audio_from_path,
|
||||
sample_images,
|
||||
sample_indices,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_AUDIO, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
|
||||
|
||||
def mock_load_audio(path):
|
||||
return np.ones((16000, 2), dtype=np.float32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_array():
|
||||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
@@ -73,6 +79,25 @@ def test_sample_images(mock_load):
|
||||
assert len(images) == estimate_num_samples(100)
|
||||
|
||||
|
||||
@patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
|
||||
def test_sample_audio_from_path(mock_load):
|
||||
audio_path = "audio.wav"
|
||||
audio_samples = sample_audio_from_path(audio_path)
|
||||
assert isinstance(audio_samples, np.ndarray)
|
||||
assert audio_samples.shape[1] == 2
|
||||
assert audio_samples.dtype == np.float32
|
||||
assert len(audio_samples) == estimate_num_samples(16000)
|
||||
|
||||
|
||||
def test_sample_audio_from_data():
|
||||
audio_data = np.ones((16000, 2), dtype=np.float32)
|
||||
audio_samples = sample_audio_from_data(audio_data)
|
||||
assert isinstance(audio_samples, np.ndarray)
|
||||
assert audio_samples.shape[1] == 2
|
||||
assert audio_samples.dtype == np.float32
|
||||
assert len(audio_samples) == estimate_num_samples(16000)
|
||||
|
||||
|
||||
def test_get_feature_stats_images():
|
||||
data = np.random.rand(100, 3, 32, 32)
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
@@ -81,6 +106,14 @@ def test_get_feature_stats_images():
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_audio():
|
||||
data = np.random.uniform(-1, 1, (16000, 2))
|
||||
stats = get_feature_stats(data, axis=0, keepdims=True)
|
||||
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
||||
np.testing.assert_equal(stats["count"], np.array([16000]))
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
expected = {
|
||||
"min": np.array([[1, 2, 3]]),
|
||||
@@ -145,20 +178,27 @@ def test_get_feature_stats_single_value():
|
||||
def test_compute_episode_stats():
|
||||
episode_data = {
|
||||
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
|
||||
OBS_AUDIO: "audio.wav",
|
||||
OBS_STATE: np.random.rand(100, 10),
|
||||
}
|
||||
features = {
|
||||
OBS_IMAGE: {"dtype": "image"},
|
||||
OBS_AUDIO: {"dtype": "audio"},
|
||||
OBS_STATE: {"dtype": "numeric"},
|
||||
}
|
||||
|
||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
||||
with (
|
||||
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
|
||||
patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
|
||||
):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
assert OBS_IMAGE in stats and OBS_STATE in stats
|
||||
assert stats[OBS_IMAGE]["count"].item() == 100
|
||||
assert stats[OBS_STATE]["count"].item() == 100
|
||||
assert OBS_IMAGE in stats and OBS_AUDIO in stats and OBS_STATE in stats
|
||||
assert stats[OBS_IMAGE]["count"].item() == estimate_num_samples(100)
|
||||
assert stats[OBS_AUDIO]["count"].item() == estimate_num_samples(16000)
|
||||
assert stats[OBS_STATE]["count"].item() == estimate_num_samples(100)
|
||||
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
||||
assert stats[OBS_AUDIO]["mean"].shape == (1, 2)
|
||||
|
||||
|
||||
def test_assert_type_and_shape_valid():
|
||||
|
||||
@@ -29,7 +29,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1050,7 +1050,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
assert "reward" in modified_dataset.meta.features
|
||||
|
||||
|
||||
def test_convert_dataset_to_videos(tmp_path):
|
||||
def test_convert_image_to_video_dataset(tmp_path):
|
||||
"""Test converting lerobot/pusht_image dataset to video format."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
@@ -1071,7 +1071,7 @@ def test_convert_dataset_to_videos(tmp_path):
|
||||
assert "observation.image" in source_dataset.meta.features
|
||||
|
||||
# Convert to video dataset (only first 2 episodes for speed)
|
||||
video_dataset = convert_dataset_to_videos(
|
||||
video_dataset = convert_image_to_video_dataset(
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="lerobot/pusht_video",
|
||||
@@ -1113,7 +1113,7 @@ def test_convert_dataset_to_videos(tmp_path):
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
|
||||
def test_convert_dataset_to_videos_subset_episodes(tmp_path):
|
||||
def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
|
||||
"""Test converting only specific episodes from lerobot/pusht_image to video format."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
@@ -1132,7 +1132,7 @@ def test_convert_dataset_to_videos_subset_episodes(tmp_path):
|
||||
# Convert only episode 0 to video (subset of loaded episodes)
|
||||
episode_indices = [0]
|
||||
|
||||
video_dataset = convert_dataset_to_videos(
|
||||
video_dataset = convert_image_to_video_dataset(
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="lerobot/pusht_video_subset",
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from soundfile import write
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
@@ -37,6 +38,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
_encode_video_worker,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
@@ -49,7 +51,13 @@ from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_SAMPLE_RATE,
|
||||
DUMMY_AUDIO_CHANNELS,
|
||||
DUMMY_CHW,
|
||||
DUMMY_HWC,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -70,6 +78,36 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_dataset_le_kiwi(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {
|
||||
"audio": {
|
||||
"dtype": "audio",
|
||||
"shape": (1, DUMMY_AUDIO_CHANNELS),
|
||||
"names": [
|
||||
"channels",
|
||||
],
|
||||
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
|
||||
}
|
||||
}
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, robot_type="lekiwi")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {
|
||||
"audio": {
|
||||
"dtype": "audio",
|
||||
"shape": (1, DUMMY_AUDIO_CHANNELS),
|
||||
"names": [
|
||||
"channels",
|
||||
],
|
||||
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
|
||||
}
|
||||
}
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
|
||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
"""
|
||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||
@@ -352,6 +390,137 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||
image_array_to_pil_image(image)
|
||||
|
||||
|
||||
def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed for image features after saving episode."""
|
||||
# Image feature: images should be deleted after saving episode
|
||||
image_key = "image"
|
||||
features_image = {
|
||||
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
|
||||
}
|
||||
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
|
||||
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_img.save_episode()
|
||||
img_dir = ds_img._get_image_file_dir(0, image_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
|
||||
|
||||
def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed for video encoding when `batch_encoding_size == 1`."""
|
||||
# Video feature: when batch_encoding_size == 1 temporary images should be deleted
|
||||
vid_key = "video"
|
||||
features_video = {
|
||||
vid_key: {"dtype": "video", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
|
||||
}
|
||||
|
||||
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
|
||||
ds_vid.batch_encoding_size = 1
|
||||
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_vid.save_episode()
|
||||
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
|
||||
assert not vid_img_dir.exists(), (
|
||||
"Temporary image directory should be removed when batch_encoding_size == 1"
|
||||
)
|
||||
|
||||
|
||||
def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed appropriately when both image and video features are present."""
|
||||
image_key = "image"
|
||||
vid_key = "video"
|
||||
features_mixed = {
|
||||
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]},
|
||||
vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
|
||||
}
|
||||
ds_mixed = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2
|
||||
)
|
||||
ds_mixed.add_frame(
|
||||
{
|
||||
"image": np.random.rand(*DUMMY_CHW),
|
||||
"video": np.random.rand(*DUMMY_HWC),
|
||||
"task": "Dummy task",
|
||||
}
|
||||
)
|
||||
ds_mixed.save_episode()
|
||||
img_dir = ds_mixed._get_image_file_dir(0, image_key)
|
||||
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
assert vid_img_dir.exists(), (
|
||||
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_audio_array(audio_dataset_le_kiwi):
|
||||
dataset = audio_dataset_le_kiwi
|
||||
dataset.add_frame(
|
||||
{
|
||||
"audio": np.random.rand(
|
||||
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
|
||||
)
|
||||
},
|
||||
task="Dummy task",
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["audio"].shape == torch.Size(
|
||||
(
|
||||
DUMMY_AUDIO_CHANNELS,
|
||||
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_audio_array_wrong_shape(audio_dataset_le_kiwi):
|
||||
dataset = audio_dataset_le_kiwi
|
||||
with pytest.raises(ValueError):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"audio": np.random.rand(
|
||||
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS, 99
|
||||
)
|
||||
},
|
||||
task="Dummy task",
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_audio_array_wrong_channels_number(audio_dataset_le_kiwi):
|
||||
dataset = audio_dataset_le_kiwi
|
||||
with pytest.raises(ValueError):
|
||||
dataset.add_frame(
|
||||
{"audio": np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), 99)},
|
||||
task="Dummy task",
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_audio_file(audio_dataset):
|
||||
dataset = audio_dataset
|
||||
dataset.add_frame(
|
||||
{
|
||||
"audio": np.random.rand(
|
||||
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
|
||||
)
|
||||
},
|
||||
task="Dummy task",
|
||||
)
|
||||
# Create the audio file that should be created in the background by the Microphone class
|
||||
for audio_key in dataset.meta.audio_keys:
|
||||
fpath = dataset._get_raw_audio_file_path(0, audio_key)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
write(
|
||||
fpath,
|
||||
np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS),
|
||||
DEFAULT_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["audio"].shape == torch.Size(
|
||||
(
|
||||
DUMMY_AUDIO_CHANNELS,
|
||||
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
@@ -391,6 +560,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
audio_keys = dataset.meta.audio_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
@@ -433,6 +603,11 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
# test c,h,w
|
||||
assert item[key].shape[0] == 3, f"{key}"
|
||||
|
||||
for key in audio_keys:
|
||||
assert item[key].dtype == torch.float32, f"{key}"
|
||||
assert item[key].max() <= 1.0, f"{key}"
|
||||
assert item[key].min() >= -1.0, f"{key}"
|
||||
|
||||
if delta_timestamps is not None:
|
||||
# test missing keys in delta_timestamps
|
||||
for key in delta_timestamps:
|
||||
@@ -1392,3 +1567,202 @@ def test_valid_video_codecs_constant():
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 3
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Regression test for bug where delta_timestamps incorrectly marked all frames as padded when using episodes filter.
|
||||
|
||||
The bug occurred because _get_query_indices was using the relative index (idx) in the filtered dataset
|
||||
instead of the absolute index when comparing against episode boundaries (ep_start, ep_end).
|
||||
"""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
# Create 3 episodes with 10 frames each
|
||||
frames_per_episode = 10
|
||||
for ep_idx in range(3):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
|
||||
"action": torch.randn(2),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Load only episode 1 (middle episode) with delta_timestamps
|
||||
delta_ts = {"observation.state": [0.0]} # Just the current frame
|
||||
filtered_dataset = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=[1],
|
||||
delta_timestamps=delta_ts,
|
||||
)
|
||||
|
||||
# Verify the filtered dataset has the correct length
|
||||
assert len(filtered_dataset) == frames_per_episode
|
||||
|
||||
# Check that no frames are marked as padded (since delta=0 should always be valid)
|
||||
for idx in range(len(filtered_dataset)):
|
||||
frame = filtered_dataset[idx]
|
||||
assert frame["observation.state_is_pad"].item() is False, f"Frame {idx} incorrectly marked as padded"
|
||||
# Verify we're getting data from episode 1
|
||||
assert frame["episode_index"].item() == 1
|
||||
|
||||
|
||||
def test_delta_timestamps_padding_at_episode_boundaries(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that delta_timestamps correctly marks padding at episode boundaries when using episodes filter."""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "test", features=features, use_videos=False, fps=10
|
||||
)
|
||||
|
||||
# Create 3 episodes with 5 frames each
|
||||
frames_per_episode = 5
|
||||
for ep_idx in range(3):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
|
||||
"action": torch.randn(2),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Load only episode 1 with delta_timestamps that go beyond episode boundaries
|
||||
# fps=10, so 0.1s = 1 frame offset
|
||||
delta_ts = {"observation.state": [-0.2, -0.1, 0.0, 0.1, 0.2]} # -2, -1, 0, +1, +2 frames
|
||||
filtered_dataset = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=[1],
|
||||
delta_timestamps=delta_ts,
|
||||
tolerance_s=0.04, # Slightly less than half a frame at 10fps
|
||||
)
|
||||
|
||||
assert len(filtered_dataset) == frames_per_episode
|
||||
|
||||
# Check padding at the start of the episode (first frame)
|
||||
first_frame = filtered_dataset[0]
|
||||
is_pad = first_frame["observation.state_is_pad"].tolist()
|
||||
# At frame 0 of episode 1: delta -2 and -1 should be padded, 0, +1, +2 should not
|
||||
assert is_pad == [True, True, False, False, False], f"First frame padding incorrect: {is_pad}"
|
||||
|
||||
# Check middle frame (no padding expected)
|
||||
mid_frame = filtered_dataset[2]
|
||||
is_pad = mid_frame["observation.state_is_pad"].tolist()
|
||||
assert is_pad == [False, False, False, False, False], f"Middle frame padding incorrect: {is_pad}"
|
||||
|
||||
# Check padding at the end of the episode (last frame)
|
||||
last_frame = filtered_dataset[4]
|
||||
is_pad = last_frame["observation.state_is_pad"].tolist()
|
||||
# At frame 4 of episode 1: delta -2, -1, 0 should not be padded, +1, +2 should be
|
||||
assert is_pad == [False, False, False, True, True], f"Last frame padding incorrect: {is_pad}"
|
||||
|
||||
|
||||
def test_delta_timestamps_multiple_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test delta_timestamps with multiple non-consecutive episodes selected."""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "test", features=features, use_videos=False, fps=10
|
||||
)
|
||||
|
||||
# Create 5 episodes with 5 frames each
|
||||
frames_per_episode = 5
|
||||
for ep_idx in range(5):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Load episodes 1 and 3 (non-consecutive)
|
||||
delta_ts = {"observation.state": [0.0]}
|
||||
filtered_dataset = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=[1, 3],
|
||||
delta_timestamps=delta_ts,
|
||||
)
|
||||
|
||||
assert len(filtered_dataset) == 2 * frames_per_episode
|
||||
|
||||
# All frames should have valid (non-padded) data for delta=0
|
||||
for idx in range(len(filtered_dataset)):
|
||||
frame = filtered_dataset[idx]
|
||||
assert frame["observation.state_is_pad"].item() is False
|
||||
|
||||
# Verify we're getting the correct episodes
|
||||
episode_indices = [filtered_dataset[i]["episode_index"].item() for i in range(len(filtered_dataset))]
|
||||
expected_episodes = [1] * frames_per_episode + [3] * frames_per_episode
|
||||
assert episode_indices == expected_episodes
|
||||
|
||||
|
||||
def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that delta_timestamps returns the correct observation values, not just correct padding."""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (1,), "names": ["x"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "test", features=features, use_videos=False, fps=10
|
||||
)
|
||||
|
||||
# Create 2 episodes with known values
|
||||
# Episode 0: frames with values 0, 1, 2, 3, 4
|
||||
# Episode 1: frames with values 10, 11, 12, 13, 14
|
||||
frames_per_episode = 5
|
||||
for ep_idx in range(2):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
value = ep_idx * 10 + frame_idx
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([value], dtype=torch.float32),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Load episode 1 with delta that looks at previous frame
|
||||
delta_ts = {"observation.state": [-0.1, 0.0]} # Previous frame and current frame
|
||||
filtered_dataset = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=[1],
|
||||
delta_timestamps=delta_ts,
|
||||
tolerance_s=0.04,
|
||||
)
|
||||
|
||||
# Check frame 2 of episode 1 (which has absolute index 7, value 12)
|
||||
frame = filtered_dataset[2]
|
||||
state_values = frame["observation.state"].tolist()
|
||||
# Should get [11, 12] - the previous and current values within episode 1
|
||||
assert state_values == [11.0, 12.0], f"Expected [11.0, 12.0], got {state_values}"
|
||||
|
||||
# Check first frame - previous frame should be clamped to episode start (padded)
|
||||
first_frame = filtered_dataset[0]
|
||||
state_values = first_frame["observation.state"].tolist()
|
||||
is_pad = first_frame["observation.state_is_pad"].tolist()
|
||||
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
|
||||
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
|
||||
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
|
||||
|
||||
Vendored
+13
@@ -40,5 +40,18 @@ DUMMY_VIDEO_INFO = {
|
||||
"video.is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
DUMMY_MICROPHONE_FEATURES = {
|
||||
"laptop": {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None},
|
||||
"phone": {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None},
|
||||
}
|
||||
DEFAULT_SAMPLE_RATE = 48000
|
||||
DUMMY_AUDIO_CHANNELS = 2
|
||||
DUMMY_AUDIO_INFO = {
|
||||
"has_audio": True,
|
||||
"audio.sample_rate": DEFAULT_SAMPLE_RATE,
|
||||
"audio.codec": "aac",
|
||||
"audio.channels": DUMMY_AUDIO_CHANNELS,
|
||||
"audio.channel_layout": "stereo",
|
||||
}
|
||||
DUMMY_CHW = (3, 96, 128)
|
||||
DUMMY_HWC = (96, 128, 3)
|
||||
|
||||
Vendored
+18
-1
@@ -28,6 +28,7 @@ from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -42,6 +43,7 @@ from lerobot.datasets.video_utils import encode_video_frames
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_MICROPHONE_FEATURES,
|
||||
DUMMY_MOTOR_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_ROBOT_TYPE,
|
||||
@@ -130,6 +132,7 @@ def features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
@@ -141,6 +144,7 @@ def features_factory():
|
||||
return {
|
||||
**motor_features,
|
||||
**camera_ft,
|
||||
**audio_features,
|
||||
**DEFAULT_FEATURES,
|
||||
}
|
||||
|
||||
@@ -157,16 +161,19 @@ def info_factory(features_factory):
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
total_videos: int = 0,
|
||||
total_audio: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path: str = DEFAULT_DATA_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
audio_path: str = DEFAULT_AUDIO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
features = features_factory(motor_features, camera_features, audio_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
@@ -174,6 +181,7 @@ def info_factory(features_factory):
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": total_tasks,
|
||||
"total_videos": total_videos,
|
||||
"total_audio": total_audio,
|
||||
"chunks_size": chunks_size,
|
||||
"data_files_size_in_mb": data_files_size_in_mb,
|
||||
"video_files_size_in_mb": video_files_size_in_mb,
|
||||
@@ -181,6 +189,7 @@ def info_factory(features_factory):
|
||||
"splits": {},
|
||||
"data_path": data_path,
|
||||
"video_path": video_path if use_videos else None,
|
||||
"audio_path": audio_path,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
@@ -204,6 +213,14 @@ def stats_factory():
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
||||
"count": [10],
|
||||
}
|
||||
elif dtype == "audio":
|
||||
stats[key] = {
|
||||
"mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
|
||||
"max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
|
||||
"min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
|
||||
"std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
|
||||
"count": [10],
|
||||
}
|
||||
else:
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
||||
|
||||
@@ -0,0 +1,532 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from soundfile import read
|
||||
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig
|
||||
from lerobot.microphones.portaudio.interface_sounddevice_sdk import (
|
||||
FakeSounddeviceSDKAdapter,
|
||||
SounddeviceSDKAdapter,
|
||||
)
|
||||
from lerobot.microphones.portaudio.microphone_portaudio import PortAudioMicrophone
|
||||
from lerobot.microphones.utils import async_microphones_start_recording, async_microphones_stop_recording
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
MODULE_PATH = "lerobot.microphones.portaudio.microphone_portaudio"
|
||||
RECORDING_DURATION = 1.0
|
||||
|
||||
LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS = (
|
||||
os.getenv("LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS", "False").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_sdk():
|
||||
"""Fixture to provide either real or fake SDK based on environment variable."""
|
||||
if LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS:
|
||||
return SounddeviceSDKAdapter()
|
||||
else:
|
||||
return FakeSounddeviceSDKAdapter()
|
||||
|
||||
|
||||
# Configuration Tests
|
||||
|
||||
|
||||
def test_config_creation():
|
||||
"""Test creating a valid configuration."""
|
||||
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=48000, channels=[1, 2])
|
||||
assert config.microphone_index == 0
|
||||
assert config.sample_rate == 48000
|
||||
assert config.channels == [1, 2]
|
||||
|
||||
|
||||
def test_config_creation_missing_microphone_index():
|
||||
"""Test creating a configuration with missing microphone index."""
|
||||
with pytest.raises(TypeError):
|
||||
PortAudioMicrophoneConfig(sample_rate=48000, channels=[1, 2])
|
||||
|
||||
|
||||
def test_config_creation_missing_sample_rate():
|
||||
"""Test creating a configuration with missing sample rate."""
|
||||
config = PortAudioMicrophoneConfig(microphone_index=0, channels=[1, 2])
|
||||
assert config.sample_rate is None
|
||||
|
||||
|
||||
def test_config_creation_missing_channels():
|
||||
"""Test creating a configuration with missing channels."""
|
||||
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=48000)
|
||||
assert config.channels is None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_config(test_sdk):
|
||||
"""Fixture to provide a default configuration for input devices."""
|
||||
device_info = test_sdk.query_devices(kind="input")
|
||||
return PortAudioMicrophoneConfig(
|
||||
microphone_index=device_info["index"],
|
||||
sample_rate=device_info["default_samplerate"],
|
||||
channels=np.arange(device_info["max_input_channels"]) + 1,
|
||||
)
|
||||
|
||||
|
||||
# Microphone Tests
|
||||
|
||||
|
||||
def test_find_microphones(test_sdk):
|
||||
"""Test finding microphones."""
|
||||
microphones = PortAudioMicrophone.find_microphones(sounddevice_sdk=test_sdk)
|
||||
|
||||
for microphone in microphones:
|
||||
assert isinstance(microphone["index"], int)
|
||||
assert isinstance(microphone["name"], str)
|
||||
assert isinstance(microphone["sample_rate"], int)
|
||||
assert isinstance(microphone["channels"], np.ndarray)
|
||||
assert len(microphone["channels"]) > 0
|
||||
|
||||
|
||||
def test_init_defaults(default_config, test_sdk):
|
||||
"""Test microphone initialization with defaults."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
|
||||
device_info = test_sdk.query_devices(kind="input")
|
||||
assert microphone is not None
|
||||
assert microphone.microphone_index == device_info["index"]
|
||||
assert microphone.sample_rate == device_info["default_samplerate"]
|
||||
np.testing.assert_array_equal(microphone.channels, np.arange(device_info["max_input_channels"]) + 1)
|
||||
assert not microphone.is_connected
|
||||
assert not microphone.is_recording
|
||||
|
||||
|
||||
def test_connect_success(default_config, test_sdk):
|
||||
"""Test successful connection."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
|
||||
assert microphone.is_connected
|
||||
assert not microphone.is_recording
|
||||
assert not microphone.is_writing
|
||||
|
||||
|
||||
def test_connect_empty_config(default_config, test_sdk):
|
||||
"""Test connection with empty config values."""
|
||||
config = deepcopy(default_config)
|
||||
config.sample_rate = None
|
||||
config.channels = None
|
||||
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
|
||||
device_info = test_sdk.query_devices(kind="input")
|
||||
assert microphone.sample_rate == device_info["default_samplerate"]
|
||||
np.testing.assert_array_equal(microphone.channels, np.arange(device_info["max_input_channels"]) + 1)
|
||||
|
||||
|
||||
def test_connect_already_connected(default_config, test_sdk):
|
||||
"""Test connecting when already connected."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
|
||||
with pytest.raises(DeviceAlreadyConnectedError):
|
||||
microphone.connect()
|
||||
|
||||
|
||||
def test_connect_invalid_device(test_sdk):
|
||||
"""Test connecting with invalid device (output device)."""
|
||||
device_info = test_sdk.query_devices(kind="output")
|
||||
config = PortAudioMicrophoneConfig(
|
||||
microphone_index=device_info["index"],
|
||||
sample_rate=device_info["default_samplerate"],
|
||||
channels=np.arange(device_info["max_input_channels"]) + 1,
|
||||
)
|
||||
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
microphone.connect()
|
||||
|
||||
|
||||
def test_connect_invalid_index(default_config, test_sdk):
|
||||
"""Test connecting with invalid device index."""
|
||||
config = deepcopy(default_config)
|
||||
config.microphone_index = -1
|
||||
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
microphone.connect()
|
||||
|
||||
|
||||
def test_connect_invalid_sample_rate(default_config, test_sdk):
|
||||
"""Test connecting with invalid sample rate."""
|
||||
config = deepcopy(default_config)
|
||||
config.sample_rate = -1
|
||||
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
microphone.connect()
|
||||
|
||||
|
||||
def test_connect_float_sample_rate(default_config, test_sdk):
|
||||
"""Test connecting with float sample rate."""
|
||||
config = deepcopy(default_config)
|
||||
config.sample_rate = int(config.sample_rate) - 0.5
|
||||
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
|
||||
assert isinstance(microphone.sample_rate, int)
|
||||
assert microphone.sample_rate == int(config.sample_rate)
|
||||
|
||||
|
||||
def test_connect_lower_sample_rate(default_config, test_sdk):
|
||||
"""Test connecting with lower sample rate."""
|
||||
config = deepcopy(default_config)
|
||||
config.sample_rate = 1000 # Lowest possible sample rate
|
||||
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||
|
||||
microphone.connect()
|
||||
assert microphone.sample_rate == 1000
|
||||
|
||||
|
||||
def test_connect_invalid_channels(default_config, test_sdk):
|
||||
"""Test connecting with invalid channels."""
|
||||
config = deepcopy(default_config)
|
||||
config.channels = np.append(default_config.channels, -1)
|
||||
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
microphone.connect()
|
||||
|
||||
|
||||
def test_disconnect_success(default_config, test_sdk):
|
||||
"""Test successful disconnection."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.disconnect()
|
||||
|
||||
assert not microphone.is_connected
|
||||
assert not microphone.is_recording
|
||||
assert not microphone.is_writing
|
||||
|
||||
|
||||
def test_disconnect_not_connected(default_config, test_sdk):
|
||||
"""Test disconnecting when not connected."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
microphone.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_start_recording_success(default_config, test_sdk, multiprocessing):
|
||||
"""Test successful recording start."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(multiprocessing=multiprocessing)
|
||||
|
||||
assert microphone.is_recording
|
||||
assert microphone.is_connected
|
||||
assert not microphone.is_writing
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_recording_not_connected(default_config, test_sdk, multiprocessing):
|
||||
"""Test starting recording when not connected."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
microphone.start_recording(multiprocessing=multiprocessing)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_start_recording_already_recording(default_config, test_sdk, multiprocessing):
|
||||
"""Test starting recording when already recording."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(multiprocessing=multiprocessing)
|
||||
|
||||
with pytest.raises(DeviceAlreadyRecordingError):
|
||||
microphone.start_recording(multiprocessing=multiprocessing)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_start_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
|
||||
"""Test successful writing start."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||
|
||||
assert microphone.is_recording
|
||||
assert microphone.is_connected
|
||||
assert microphone.is_writing
|
||||
assert (tmp_path / "test.wav").exists()
|
||||
|
||||
(tmp_path / "test.wav").unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_start_writing_file_already_exists_no_overwrite(tmp_path, default_config, test_sdk, multiprocessing):
|
||||
"""Test writing with file that already exists."""
|
||||
(tmp_path / "test.wav").touch()
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
|
||||
with pytest.raises(FileExistsError):
|
||||
microphone.start_recording(
|
||||
output_file=tmp_path / "test.wav", multiprocessing=multiprocessing, overwrite=False
|
||||
)
|
||||
|
||||
(tmp_path / "test.wav").unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_stop_recording_success(default_config, test_sdk, multiprocessing):
|
||||
"""Test successful recording stop."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(multiprocessing=multiprocessing)
|
||||
precise_sleep(RECORDING_DURATION)
|
||||
microphone.stop_recording()
|
||||
|
||||
assert not microphone.is_recording
|
||||
assert microphone.is_connected
|
||||
assert not microphone.is_writing
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_stop_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
|
||||
"""Test successful writing stop."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||
precise_sleep(RECORDING_DURATION)
|
||||
microphone.stop_recording()
|
||||
|
||||
assert not microphone.is_recording
|
||||
assert microphone.is_connected
|
||||
assert not microphone.is_writing
|
||||
assert (tmp_path / "test.wav").exists()
|
||||
|
||||
(tmp_path / "test.wav").unlink()
|
||||
|
||||
|
||||
def test_stop_recording_not_connected(default_config, test_sdk):
|
||||
"""Test stopping recording when not connected."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
microphone.stop_recording()
|
||||
|
||||
|
||||
def test_stop_recording_not_recording(default_config, test_sdk):
|
||||
"""Test stopping recording when not recording."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
|
||||
with pytest.raises(DeviceNotRecordingError):
|
||||
microphone.stop_recording()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_disconnect_while_recording(default_config, test_sdk, multiprocessing):
|
||||
"""Test disconnecting while recording."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(multiprocessing=multiprocessing)
|
||||
precise_sleep(RECORDING_DURATION)
|
||||
microphone.disconnect()
|
||||
|
||||
assert not microphone.is_connected
|
||||
assert not microphone.is_recording
|
||||
assert not microphone.is_writing
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_disconnect_while_writing(tmp_path, default_config, test_sdk, multiprocessing):
|
||||
"""Test disconnecting while writing."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||
precise_sleep(RECORDING_DURATION)
|
||||
microphone.disconnect()
|
||||
|
||||
assert not microphone.is_connected
|
||||
assert not microphone.is_recording
|
||||
assert not microphone.is_writing
|
||||
assert Path(tmp_path / "test.wav").exists()
|
||||
|
||||
(tmp_path / "test.wav").unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_read_success(default_config, test_sdk, multiprocessing):
|
||||
"""Test successful reading of audio data."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(multiprocessing=multiprocessing)
|
||||
|
||||
precise_sleep(RECORDING_DURATION)
|
||||
|
||||
data = microphone.read()
|
||||
|
||||
device_info = test_sdk.query_devices(kind="input")
|
||||
assert data is not None
|
||||
assert data.shape[1] == len(default_config.channels)
|
||||
assert (
|
||||
abs(data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
|
||||
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
|
||||
"""Test successful writing to file."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||
|
||||
precise_sleep(RECORDING_DURATION)
|
||||
|
||||
microphone.stop_recording()
|
||||
|
||||
data, samplerate = read(tmp_path / "test.wav")
|
||||
|
||||
device_info = test_sdk.query_devices(kind="input")
|
||||
assert samplerate == default_config.sample_rate
|
||||
assert data.shape[1] == len(default_config.channels)
|
||||
assert (
|
||||
abs(data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
|
||||
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
|
||||
)
|
||||
|
||||
(tmp_path / "test.wav").unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
|
||||
def test_read_while_writing(tmp_path, default_config, test_sdk, multiprocessing):
|
||||
"""Test reading while writing."""
|
||||
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
|
||||
microphone.connect()
|
||||
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
|
||||
|
||||
precise_sleep(RECORDING_DURATION)
|
||||
|
||||
read_data = microphone.read()
|
||||
microphone.stop_recording()
|
||||
|
||||
writing_data, _ = read(tmp_path / "test.wav")
|
||||
|
||||
device_info = test_sdk.query_devices(kind="input")
|
||||
assert (
|
||||
abs(writing_data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
|
||||
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
|
||||
)
|
||||
assert (
|
||||
abs(read_data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
|
||||
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
|
||||
)
|
||||
|
||||
(tmp_path / "test.wav").unlink()
|
||||
|
||||
|
||||
def test_async_start_recording(default_config, test_sdk):
|
||||
"""Test async recording start."""
|
||||
microphones = {
|
||||
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||
}
|
||||
for microphone in microphones.values():
|
||||
microphone.connect()
|
||||
|
||||
async_microphones_start_recording(microphones)
|
||||
|
||||
for microphone in microphones.values():
|
||||
assert microphone.is_recording
|
||||
assert microphone.is_connected
|
||||
assert not microphone.is_writing
|
||||
|
||||
|
||||
def test_async_start_writing(tmp_path, default_config, test_sdk):
|
||||
"""Test async writing start."""
|
||||
microphones = {
|
||||
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||
}
|
||||
for microphone in microphones.values():
|
||||
microphone.connect()
|
||||
|
||||
async_microphones_start_recording(
|
||||
microphones, output_files=[tmp_path / "test_1.wav", tmp_path / "test_2.wav"]
|
||||
)
|
||||
|
||||
for microphone in microphones.values():
|
||||
assert microphone.is_recording
|
||||
assert microphone.is_connected
|
||||
assert microphone.is_writing
|
||||
assert Path(tmp_path / "test_1.wav").exists()
|
||||
assert Path(tmp_path / "test_2.wav").exists()
|
||||
|
||||
(tmp_path / "test_1.wav").unlink()
|
||||
(tmp_path / "test_2.wav").unlink()
|
||||
|
||||
|
||||
def test_async_stop_recording(default_config, test_sdk):
|
||||
"""Test async recording stop."""
|
||||
microphones = {
|
||||
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||
}
|
||||
for microphone in microphones.values():
|
||||
microphone.connect()
|
||||
|
||||
async_microphones_start_recording(microphones)
|
||||
async_microphones_stop_recording(microphones)
|
||||
|
||||
for microphone in microphones.values():
|
||||
assert not microphone.is_recording
|
||||
assert microphone.is_connected
|
||||
assert not microphone.is_writing
|
||||
|
||||
|
||||
def test_async_stop_writing(tmp_path, default_config, test_sdk):
|
||||
"""Test async writing stop."""
|
||||
microphones = {
|
||||
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
|
||||
}
|
||||
for microphone in microphones.values():
|
||||
microphone.connect()
|
||||
|
||||
async_microphones_start_recording(
|
||||
microphones, output_files=[tmp_path / "test_1.wav", tmp_path / "test_2.wav"]
|
||||
)
|
||||
async_microphones_stop_recording(microphones)
|
||||
|
||||
for microphone in microphones.values():
|
||||
assert not microphone.is_recording
|
||||
assert microphone.is_connected
|
||||
assert not microphone.is_writing
|
||||
assert Path(tmp_path / "test_1.wav").exists()
|
||||
assert Path(tmp_path / "test_2.wav").exists()
|
||||
|
||||
(tmp_path / "test_1.wav").unlink()
|
||||
(tmp_path / "test_2.wav").unlink()
|
||||
@@ -0,0 +1,508 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
from multiprocessing import Event, Process, Queue
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
|
||||
def writer_process(shared_array, data_queue, stop_event, barrier, process_id):
|
||||
"""Writer process that continuously writes data to shared array."""
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
# Wait for all processes to be ready
|
||||
barrier.wait()
|
||||
|
||||
write_count = 0
|
||||
while not stop_event.is_set() and write_count < 10:
|
||||
# Generate unique data for this process and write iteration
|
||||
data = np.full((5, 2), process_id * 100 + write_count, dtype=np.float32)
|
||||
|
||||
try:
|
||||
shared_array.write(local_array, data)
|
||||
data_queue.put(f"writer_{process_id}_wrote_{write_count}")
|
||||
write_count += 1
|
||||
time.sleep(0.01) # Small delay to allow race conditions
|
||||
except IndexError:
|
||||
# Array is full, stop writing
|
||||
break
|
||||
|
||||
|
||||
def reader_process(shared_array, data_queue, stop_event, barrier, process_id):
|
||||
"""Reader process that continuously reads data from shared array."""
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
# Wait for all processes to be ready
|
||||
barrier.wait()
|
||||
|
||||
read_count = 0
|
||||
while not stop_event.is_set() and read_count < 5:
|
||||
time.sleep(0.02) # Allow some writes to accumulate
|
||||
|
||||
data = shared_array.read(local_array, flush=True)
|
||||
data_queue.put(f"reader_{process_id}_read_{len(data)}_items")
|
||||
read_count += 1
|
||||
|
||||
|
||||
def stress_writer_process(shared_array, data_queue, stop_event, barrier, process_id):
|
||||
"""High-frequency writer process for stress testing."""
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
barrier.wait()
|
||||
|
||||
write_count = 0
|
||||
while not stop_event.is_set() and write_count < 50:
|
||||
# Write single row at a time for more frequent operations
|
||||
data = np.array([[process_id, write_count]], dtype=np.float32)
|
||||
|
||||
try:
|
||||
shared_array.write(local_array, data)
|
||||
write_count += 1
|
||||
# No sleep - stress test
|
||||
except IndexError:
|
||||
break
|
||||
|
||||
data_queue.put(f"stress_writer_{process_id}_completed_{write_count}")
|
||||
|
||||
|
||||
# Basic functionality tests
|
||||
|
||||
|
||||
def test_shared_array_creation():
|
||||
"""Test basic SharedArray creation and properties."""
|
||||
shape = (100, 4)
|
||||
dtype = np.float32
|
||||
|
||||
shared_array = SharedArray(shape=shape, dtype=dtype)
|
||||
|
||||
assert shared_array.shape == shape
|
||||
assert shared_array.dtype == dtype
|
||||
assert shared_array.read_index.value == 0
|
||||
|
||||
# Clean up
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_local_array_access():
|
||||
"""Test getting local array instances."""
|
||||
shape = (50, 2)
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
assert local_array.shape == shape
|
||||
assert local_array.dtype == np.float32
|
||||
assert isinstance(local_array, np.ndarray)
|
||||
|
||||
# Test that we can get multiple local array instances
|
||||
local_array2 = shared_array.get_local_array()
|
||||
assert local_array2.shape == shape
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_write_and_read_single_process():
|
||||
"""Test basic write and read operations in single process."""
|
||||
shape = (20, 3)
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
# Write some data
|
||||
data1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
|
||||
shared_array.write(local_array, data1)
|
||||
|
||||
assert shared_array.read_index.value == 2
|
||||
|
||||
# Write more data
|
||||
data2 = np.array([[7, 8, 9]], dtype=np.float32)
|
||||
shared_array.write(local_array, data2)
|
||||
|
||||
assert shared_array.read_index.value == 3
|
||||
|
||||
# Read all data
|
||||
read_data = shared_array.read(local_array, flush=False)
|
||||
expected = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
|
||||
np.testing.assert_array_equal(read_data, expected)
|
||||
|
||||
# Read with flush
|
||||
read_data_flush = shared_array.read(local_array, flush=True)
|
||||
np.testing.assert_array_equal(read_data_flush, expected)
|
||||
assert shared_array.read_index.value == 0
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_array_overflow():
|
||||
"""Test behavior when writing more data than array capacity."""
|
||||
shape = (5, 2) # Small array
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
# Fill the array
|
||||
data = np.ones((5, 2), dtype=np.float32)
|
||||
shared_array.write(local_array, data)
|
||||
|
||||
# Try to write more data - should raise IndexError
|
||||
with pytest.raises(ValueError):
|
||||
extra_data = np.ones((2, 2), dtype=np.float32)
|
||||
shared_array.write(local_array, extra_data)
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_reset_functionality():
|
||||
"""Test the reset method."""
|
||||
shape = (10, 2)
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
# Write some data
|
||||
data = np.ones((3, 2), dtype=np.float32)
|
||||
shared_array.write(local_array, data)
|
||||
assert shared_array.read_index.value == 3
|
||||
|
||||
# Reset
|
||||
shared_array.reset()
|
||||
assert shared_array.read_index.value == 0
|
||||
|
||||
# Read should return empty array
|
||||
read_data = shared_array.read(local_array, flush=False)
|
||||
assert len(read_data) == 0
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
# Multi-process tests
|
||||
|
||||
|
||||
def test_single_writer_single_reader():
|
||||
"""Test basic writer-reader scenario with one process each."""
|
||||
shape = (100, 2)
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
|
||||
data_queue = Queue()
|
||||
stop_event = Event()
|
||||
barrier = multiprocessing.Barrier(2) # Writer + reader
|
||||
|
||||
# Start writer process
|
||||
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||
|
||||
# Start reader process
|
||||
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||
|
||||
writer.start()
|
||||
reader.start()
|
||||
|
||||
# Let them run for a bit
|
||||
time.sleep(0.5)
|
||||
stop_event.set()
|
||||
|
||||
# Wait for completion
|
||||
writer.join(timeout=2.0)
|
||||
reader.join(timeout=2.0)
|
||||
|
||||
# Verify both processes completed
|
||||
assert not writer.is_alive()
|
||||
assert not reader.is_alive()
|
||||
|
||||
# Check that we got messages from both processes
|
||||
messages = []
|
||||
while not data_queue.empty():
|
||||
messages.append(data_queue.get())
|
||||
|
||||
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
|
||||
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
|
||||
|
||||
assert len(writer_messages) > 0
|
||||
assert len(reader_messages) > 0
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_multiple_writers_single_reader():
|
||||
"""Test multiple writers with single reader - check for race conditions."""
|
||||
shape = (200, 2)
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
|
||||
data_queue = Queue()
|
||||
stop_event = Event()
|
||||
num_writers = 3
|
||||
barrier = multiprocessing.Barrier(num_writers + 1) # Writers + reader
|
||||
|
||||
processes = []
|
||||
|
||||
# Start multiple writer processes
|
||||
for i in range(num_writers):
|
||||
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
|
||||
processes.append(writer)
|
||||
writer.start()
|
||||
|
||||
# Start reader process
|
||||
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||
processes.append(reader)
|
||||
reader.start()
|
||||
|
||||
# Let them run
|
||||
time.sleep(1.0)
|
||||
stop_event.set()
|
||||
|
||||
# Wait for all processes
|
||||
for process in processes:
|
||||
process.join(timeout=3.0)
|
||||
assert not process.is_alive()
|
||||
|
||||
# Verify we got messages from all processes
|
||||
messages = []
|
||||
while not data_queue.empty():
|
||||
messages.append(data_queue.get())
|
||||
|
||||
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
|
||||
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
|
||||
|
||||
# Should have messages from all writers
|
||||
assert len(writer_messages) >= num_writers
|
||||
assert len(reader_messages) > 0
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_data_integrity_with_concurrent_access():
|
||||
"""Test that data integrity is maintained under concurrent access using standard reader/writer processes."""
|
||||
shape = (500, 2) # Use standard 2-column format
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
|
||||
data_queue = Queue()
|
||||
stop_event = Event()
|
||||
barrier = multiprocessing.Barrier(3) # 2 writers + 1 reader
|
||||
|
||||
# Start two writer processes
|
||||
writer1 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||
writer2 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 2))
|
||||
|
||||
# Start one reader process
|
||||
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
|
||||
|
||||
writer1.start()
|
||||
writer2.start()
|
||||
reader.start()
|
||||
|
||||
# Let them run for integrity test duration
|
||||
time.sleep(1.0)
|
||||
stop_event.set()
|
||||
|
||||
# Wait for completion
|
||||
writer1.join(timeout=3.0)
|
||||
writer2.join(timeout=3.0)
|
||||
reader.join(timeout=3.0)
|
||||
|
||||
# Verify all processes completed successfully
|
||||
assert not writer1.is_alive()
|
||||
assert not writer2.is_alive()
|
||||
assert not reader.is_alive()
|
||||
|
||||
# Verify data integrity by checking messages
|
||||
messages = []
|
||||
while not data_queue.empty():
|
||||
messages.append(data_queue.get())
|
||||
|
||||
writer1_messages = [msg for msg in messages if "writer_1_wrote" in msg]
|
||||
writer2_messages = [msg for msg in messages if "writer_2_wrote" in msg]
|
||||
reader_messages = [msg for msg in messages if "reader_1_read" in msg]
|
||||
|
||||
# Verify both writers wrote data
|
||||
assert len(writer1_messages) > 0
|
||||
assert len(writer2_messages) > 0
|
||||
# Verify reader read data
|
||||
assert len(reader_messages) > 0
|
||||
|
||||
# Verify the shared array is in a consistent state
|
||||
local_array = shared_array.get_local_array()
|
||||
final_data = shared_array.read(local_array, flush=False)
|
||||
|
||||
# Should have some data written by the writers
|
||||
assert len(final_data) >= 0 # Could be empty if reader flushed everything
|
||||
# Should not exceed array capacity
|
||||
assert len(final_data) <= shape[0]
|
||||
|
||||
# If there's data, verify it contains the expected writer signatures
|
||||
if len(final_data) > 0:
|
||||
# Data should contain values like 100, 101, 102... (writer 1) or 200, 201, 202... (writer 2)
|
||||
unique_values = np.unique(final_data.flatten())
|
||||
writer1_values = unique_values[(unique_values >= 100) & (unique_values < 200)]
|
||||
writer2_values = unique_values[(unique_values >= 200) & (unique_values < 300)]
|
||||
|
||||
# Should have data from at least one writer
|
||||
assert len(writer1_values) > 0 or len(writer2_values) > 0
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_stress_test_high_frequency_operations():
|
||||
"""Stress test with high frequency read/write operations."""
|
||||
shape = (1000, 2)
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
|
||||
data_queue = Queue()
|
||||
stop_event = Event()
|
||||
num_writers = 4
|
||||
barrier = multiprocessing.Barrier(num_writers)
|
||||
|
||||
processes = []
|
||||
|
||||
# Start multiple high-frequency writers
|
||||
for i in range(num_writers):
|
||||
writer = Process(
|
||||
target=stress_writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
|
||||
)
|
||||
processes.append(writer)
|
||||
writer.start()
|
||||
|
||||
# Let them run for stress test duration
|
||||
time.sleep(0.5)
|
||||
stop_event.set()
|
||||
|
||||
# Wait for completion
|
||||
for process in processes:
|
||||
process.join(timeout=3.0)
|
||||
assert not process.is_alive()
|
||||
|
||||
# Verify all writers completed successfully
|
||||
messages = []
|
||||
while not data_queue.empty():
|
||||
messages.append(data_queue.get())
|
||||
|
||||
completed_messages = [msg for msg in messages if "completed" in msg]
|
||||
assert len(completed_messages) == num_writers
|
||||
|
||||
# Verify the shared array is in a consistent state
|
||||
local_array = shared_array.get_local_array()
|
||||
final_data = shared_array.read(local_array, flush=False)
|
||||
|
||||
# Should have some data written
|
||||
assert len(final_data) > 0
|
||||
# Should not exceed array capacity
|
||||
assert len(final_data) <= shape[0]
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_concurrent_readers():
|
||||
"""Test multiple concurrent readers with writers to ensure thread safety."""
|
||||
shape = (200, 2)
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
|
||||
data_queue = Queue()
|
||||
stop_event = Event()
|
||||
num_readers = 3
|
||||
num_writers = 2
|
||||
barrier = multiprocessing.Barrier(num_readers + num_writers)
|
||||
|
||||
processes = []
|
||||
|
||||
# Start multiple writer processes to generate data
|
||||
for i in range(num_writers):
|
||||
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
|
||||
processes.append(writer)
|
||||
writer.start()
|
||||
|
||||
# Start multiple reader processes
|
||||
for i in range(num_readers):
|
||||
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
|
||||
processes.append(reader)
|
||||
reader.start()
|
||||
|
||||
# Let them run to test concurrent access
|
||||
time.sleep(1.0)
|
||||
stop_event.set()
|
||||
|
||||
# Wait for all processes to complete
|
||||
for process in processes:
|
||||
process.join(timeout=3.0)
|
||||
assert not process.is_alive()
|
||||
|
||||
# Verify all readers and writers completed
|
||||
messages = []
|
||||
while not data_queue.empty():
|
||||
messages.append(data_queue.get())
|
||||
|
||||
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
|
||||
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
|
||||
|
||||
# Should have messages from all readers and writers
|
||||
assert len(reader_messages) >= num_readers
|
||||
assert len(writer_messages) >= num_writers
|
||||
|
||||
# Verify different readers generated different messages (proving they ran concurrently)
|
||||
reader_ids = set()
|
||||
for msg in reader_messages:
|
||||
# Extract reader ID from message like "reader_1_read_5_items"
|
||||
parts = msg.split("_")
|
||||
if len(parts) >= 2:
|
||||
reader_ids.add(parts[1])
|
||||
|
||||
assert len(reader_ids) == num_readers # All readers should have participated
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_edge_case_empty_reads():
|
||||
"""Test reading from empty array and after flushes."""
|
||||
shape = (10, 2)
|
||||
shared_array = SharedArray(shape=shape, dtype=np.float32)
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
# Read from empty array
|
||||
empty_data = shared_array.read(local_array, flush=False)
|
||||
assert len(empty_data) == 0
|
||||
|
||||
# Write some data
|
||||
data = np.ones((3, 2), dtype=np.float32)
|
||||
shared_array.write(local_array, data)
|
||||
|
||||
# Read with flush
|
||||
read_data = shared_array.read(local_array, flush=True)
|
||||
assert len(read_data) == 3
|
||||
|
||||
# Read again after flush - should be empty
|
||||
empty_again = shared_array.read(local_array, flush=False)
|
||||
assert len(empty_again) == 0
|
||||
|
||||
shared_array.delete()
|
||||
|
||||
|
||||
def test_different_dtypes():
|
||||
"""Test SharedArray with different numpy dtypes."""
|
||||
dtypes_to_test = [np.float32, np.float64, np.int32, np.int16]
|
||||
|
||||
for dtype in dtypes_to_test:
|
||||
shape = (20, 2)
|
||||
shared_array = SharedArray(shape=shape, dtype=dtype)
|
||||
local_array = shared_array.get_local_array()
|
||||
|
||||
assert local_array.dtype == dtype
|
||||
|
||||
# Write and read data of this dtype
|
||||
data = np.ones((5, 2), dtype=dtype)
|
||||
shared_array.write(local_array, data)
|
||||
|
||||
read_data = shared_array.read(local_array, flush=True)
|
||||
assert read_data.dtype == dtype
|
||||
assert len(read_data) == 5
|
||||
|
||||
shared_array.delete()
|
||||
@@ -22,7 +22,7 @@ from lerobot.cameras import CameraConfig, make_cameras_from_configs
|
||||
from lerobot.motors.motors_bus import Motor, MotorNormMode
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots import Robot, RobotConfig
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from tests.mocks.mock_motors_bus import MockMotorsBus
|
||||
|
||||
|
||||
@@ -98,10 +98,8 @@ class MockRobot(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self._is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self._is_connected = True
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
@@ -110,19 +108,15 @@ class MockRobot(Robot):
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._is_calibrated
|
||||
|
||||
@check_if_not_connected
|
||||
def calibrate(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_calibrated = True
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
@@ -130,14 +124,10 @@ class MockRobot(Robot):
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_connected = False
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Any
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("mock_teleop")
|
||||
@@ -68,10 +68,8 @@ class MockTeleop(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self._is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self._is_connected = True
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
@@ -80,19 +78,15 @@ class MockTeleop(Teleoperator):
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._is_calibrated
|
||||
|
||||
@check_if_not_connected
|
||||
def calibrate(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_calibrated = True
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
@@ -100,12 +94,9 @@ class MockTeleop(Teleoperator):
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@check_if_not_connected
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None: ...
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_connected = False
|
||||
|
||||
@@ -64,7 +64,7 @@ def close_service_stub(channel, server):
|
||||
server.stop(None)
|
||||
|
||||
|
||||
@require_package("grpc")
|
||||
@require_package("grpcio", "grpc")
|
||||
def test_establish_learner_connection_success():
|
||||
from lerobot.rl.actor import establish_learner_connection
|
||||
|
||||
@@ -81,7 +81,7 @@ def test_establish_learner_connection_success():
|
||||
close_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpc")
|
||||
@require_package("grpcio", "grpc")
|
||||
def test_establish_learner_connection_failure():
|
||||
from lerobot.rl.actor import establish_learner_connection
|
||||
|
||||
@@ -100,7 +100,7 @@ def test_establish_learner_connection_failure():
|
||||
close_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpc")
|
||||
@require_package("grpcio", "grpc")
|
||||
def test_push_transitions_to_transport_queue():
|
||||
from lerobot.rl.actor import push_transitions_to_transport_queue
|
||||
from lerobot.transport.utils import bytes_to_transitions
|
||||
@@ -135,7 +135,7 @@ def test_push_transitions_to_transport_queue():
|
||||
assert_transitions_equal(deserialized_transition, transitions[i])
|
||||
|
||||
|
||||
@require_package("grpc")
|
||||
@require_package("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_transitions_stream():
|
||||
from lerobot.rl.actor import transitions_stream
|
||||
@@ -167,7 +167,7 @@ def test_transitions_stream():
|
||||
assert streamed_data[2].data == b"transition_data_3"
|
||||
|
||||
|
||||
@require_package("grpc")
|
||||
@require_package("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_interactions_stream():
|
||||
from lerobot.rl.actor import interactions_stream
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user